diff --git a/fuzz/fuzz_targets/compress.rs b/fuzz/fuzz_targets/compress.rs index 1f52651c0..614cdc698 100644 --- a/fuzz/fuzz_targets/compress.rs +++ b/fuzz/fuzz_targets/compress.rs @@ -14,18 +14,14 @@ fuzz_target!(|data: String| { assert_eq!(error, BZ_OK); - let mut output = [0u8; 1 << 10]; - let mut output_len = output.len() as _; - let error = unsafe { - test_libbz2_rs_sys::decompress_rs( - output.as_mut_ptr(), - &mut output_len, + let (error, output) = unsafe { + test_libbz2_rs_sys::decompress_rs_with_capacity( + 1 << 10, deflated.as_ptr(), deflated.len() as _, ) }; assert_eq!(error, BZ_OK); - let output = &output[..output_len as usize]; if output != data.as_bytes() { let path = std::env::temp_dir().join("compressed.txt"); diff --git a/fuzz/fuzz_targets/decompress_random_input.rs b/fuzz/fuzz_targets/decompress_random_input.rs index e89649646..8b36a05dd 100644 --- a/fuzz/fuzz_targets/decompress_random_input.rs +++ b/fuzz/fuzz_targets/decompress_random_input.rs @@ -3,36 +3,17 @@ use libbz2_rs_sys::BZ_OK; use libfuzzer_sys::fuzz_target; fuzz_target!(|source: Vec| { - let mut dest_c = vec![0u8; 1 << 16]; - let mut dest_rs = vec![0u8; 1 << 16]; - - let mut dest_len_c = dest_c.len() as _; - let mut dest_len_rs = dest_rs.len() as _; - - let err_c = unsafe { - test_libbz2_rs_sys::decompress_c( - dest_c.as_mut_ptr(), - &mut dest_len_c, - source.as_ptr(), - source.len() as _, - ) + let (err_c, dest_c) = unsafe { + test_libbz2_rs_sys::decompress_c_with_capacity(1 << 16, source.as_ptr(), source.len() as _) }; - let err_rs = unsafe { - test_libbz2_rs_sys::decompress_rs( - dest_rs.as_mut_ptr(), - &mut dest_len_rs, - source.as_ptr(), - source.len() as _, - ) + let (err_rs, dest_rs) = unsafe { + test_libbz2_rs_sys::decompress_rs_with_capacity(1 << 16, source.as_ptr(), source.len() as _) }; assert_eq!(err_c, err_rs); if err_c == BZ_OK { - dest_c.truncate(dest_len_c as usize); - dest_rs.truncate(dest_len_rs as usize); - assert_eq!(dest_c, dest_rs); } }); diff --git a/fuzz/fuzz_targets/end_to_end.rs b/fuzz/fuzz_targets/end_to_end.rs index 552c0e8e1..0c99b5d6c 100644 --- a/fuzz/fuzz_targets/end_to_end.rs +++ b/fuzz/fuzz_targets/end_to_end.rs @@ -3,22 +3,16 @@ use libbz2_rs_sys::BZ_OK; use libfuzzer_sys::fuzz_target; fn decompress_help(input: &[u8]) -> Vec { - let mut dest_vec = vec![0u8; 1 << 16]; - - let mut dest_len = dest_vec.len() as _; - let dest = dest_vec.as_mut_ptr(); - let source = input.as_ptr(); let source_len = input.len() as _; - let err = unsafe { test_libbz2_rs_sys::decompress_rs(dest, &mut dest_len, source, source_len) }; + let (err, dest_vec) = + unsafe { test_libbz2_rs_sys::decompress_rs_with_capacity(1 << 16, source, source_len) }; if err != BZ_OK { panic!("error {:?}", err); } - dest_vec.truncate(dest_len as usize); - dest_vec } diff --git a/test-libbz2-rs-sys/examples/decompress.rs b/test-libbz2-rs-sys/examples/decompress.rs index 3b8906d4c..f21281279 100644 --- a/test-libbz2-rs-sys/examples/decompress.rs +++ b/test-libbz2-rs-sys/examples/decompress.rs @@ -1,6 +1,4 @@ -use core::ffi::c_uint; - -use test_libbz2_rs_sys::{decompress_c, decompress_rs}; +use test_libbz2_rs_sys::{decompress_c_with_capacity, decompress_rs_with_capacity}; fn main() { let mut it = std::env::args(); @@ -12,44 +10,32 @@ fn main() { let path = it.next().unwrap(); let input = std::fs::read(&path).unwrap(); - let mut dest_vec = vec![0u8; 1 << 28]; - - let mut dest_len = dest_vec.len() as c_uint; - let dest = dest_vec.as_mut_ptr(); - let source = input.as_ptr(); let source_len = input.len() as _; - let err = unsafe { decompress_c(dest, &mut dest_len, source, source_len) }; + let (err, dest_vec) = + unsafe { decompress_c_with_capacity(1 << 28, source, source_len) }; if err != 0 { panic!("error {err}"); } - dest_vec.truncate(dest_len as usize); - drop(dest_vec) } "rs" => { let path = it.next().unwrap(); let input = std::fs::read(&path).unwrap(); - let mut dest_vec = vec![0u8; 1 << 28]; - - let mut dest_len = dest_vec.len() as std::ffi::c_uint; - let dest = dest_vec.as_mut_ptr(); - let source = input.as_ptr(); let source_len = input.len() as _; - let err = unsafe { decompress_rs(dest, &mut dest_len, source, source_len) }; + let (err, dest_vec) = + unsafe { decompress_rs_with_capacity(1 << 28, source, source_len) }; if err != 0 { panic!("error {err}"); } - dest_vec.truncate(dest_len as usize); - drop(dest_vec) } other => panic!("invalid option '{other}', expected one of 'c' or 'rs'"), diff --git a/test-libbz2-rs-sys/src/chunked.rs b/test-libbz2-rs-sys/src/chunked.rs index 81ee8b863..75060e828 100644 --- a/test-libbz2-rs-sys/src/chunked.rs +++ b/test-libbz2-rs-sys/src/chunked.rs @@ -1,4 +1,4 @@ -use crate::{compress_c, decompress_c, SAMPLE1_BZ2, SAMPLE1_REF}; +use crate::{compress_c, decompress_c, decompress_c_with_capacity, SAMPLE1_BZ2, SAMPLE1_REF}; fn decompress_rs_chunked_input<'a>( dest: &'a mut [u8], @@ -55,18 +55,10 @@ fn decompress_chunked_input() { let chunked = decompress_rs_chunked_input(&mut dest_chunked, SAMPLE1_BZ2, 1).unwrap(); if !cfg!(miri) { - let mut dest = vec![0; 1 << 18]; - let mut dest_len = dest.len() as _; - let err = unsafe { - decompress_c( - dest.as_mut_ptr(), - &mut dest_len, - SAMPLE1_BZ2.as_ptr(), - SAMPLE1_BZ2.len() as _, - ) + let (err, dest) = unsafe { + decompress_c_with_capacity(1 << 18, SAMPLE1_BZ2.as_ptr(), SAMPLE1_BZ2.len() as _) }; assert_eq!(err, 0); - dest.truncate(dest_len as usize); assert_eq!(chunked.len(), dest.len()); assert_eq!(chunked, dest); @@ -201,18 +193,8 @@ fn decompress_rs_chunked_output<'a>( #[test] fn decompress_chunked_output() { - let mut dest = vec![0; 1 << 18]; - let mut dest_len = dest.len() as _; - let err = unsafe { - decompress_c( - dest.as_mut_ptr(), - &mut dest_len, - SAMPLE1_BZ2.as_ptr(), - SAMPLE1_BZ2.len() as _, - ) - }; + let (err, dest) = unsafe { decompress_c(SAMPLE1_BZ2.as_ptr(), SAMPLE1_BZ2.len() as _) }; assert_eq!(err, 0); - dest.truncate(dest_len as usize); let mut dest_chunked = vec![0; 1 << 18]; let chunked = decompress_rs_chunked_input(&mut dest_chunked, SAMPLE1_BZ2, 1).unwrap(); diff --git a/test-libbz2-rs-sys/src/lib.rs b/test-libbz2-rs-sys/src/lib.rs index f18608501..99d750a1a 100644 --- a/test-libbz2-rs-sys/src/lib.rs +++ b/test-libbz2-rs-sys/src/lib.rs @@ -27,7 +27,9 @@ macro_rules! assert_eq_rs_c { let _ng = unsafe { use bzip2_sys::*; use compress_c as compress; + use compress_c_with_capacity as compress_with_capacity; use decompress_c as decompress; + use decompress_c_with_capacity as decompress_with_capacity; $tt }; @@ -35,7 +37,9 @@ macro_rules! assert_eq_rs_c { #[allow(clippy::macro_metavars_in_unsafe)] let _rs = unsafe { use compress_rs as compress; + use compress_rs_with_capacity as compress_with_capacity; use decompress_rs as decompress; + use decompress_rs_with_capacity as decompress_with_capacity; use libbz2_rs_sys::*; $tt @@ -53,19 +57,7 @@ macro_rules! assert_eq_decompress { let input = include_bytes!($input); assert_eq_rs_c!({ - let mut dest = vec![0; 2 * input.len()]; - let mut dest_len = dest.len() as core::ffi::c_uint; - - decompress( - dest.as_mut_ptr(), - &mut dest_len, - input.as_ptr(), - input.len() as core::ffi::c_uint, - ); - - dest.truncate(dest_len as usize); - - dest + decompress_with_capacity(1 << 28, input.as_ptr(), input.len() as core::ffi::c_uint) }); }; } @@ -74,7 +66,9 @@ macro_rules! assert_eq_compress { ($input:literal) => { let input = include_bytes!($input); - assert_eq_rs_c!({ compress(input.as_ptr(), input.len() as core::ffi::c_uint, 9) }); + assert_eq_rs_c!({ + compress_with_capacity(1 << 28, input.as_ptr(), input.len() as core::ffi::c_uint, 9) + }); }; } @@ -240,12 +234,15 @@ fn miri_compress_sample3() { assert_eq_compress!("../../tests/input/quick/sample3.bz2"); } -pub unsafe fn decompress_c( - dest: *mut u8, - dest_len: *mut libc::c_uint, +unsafe fn decompress_c(source: *const u8, source_len: libc::c_uint) -> (i32, Vec) { + decompress_c_with_capacity(1024, source, source_len) +} + +pub unsafe fn decompress_c_with_capacity( + capacity: usize, source: *const u8, source_len: libc::c_uint, -) -> i32 { +) -> (i32, Vec) { use bzip2_sys::*; let mut strm: bz_stream = bz_stream { @@ -262,48 +259,71 @@ pub unsafe fn decompress_c( bzfree: None, opaque: std::ptr::null_mut::(), }; - let mut ret: libc::c_int; - if dest.is_null() || dest_len.is_null() || source.is_null() { - return -(2 as libc::c_int); - } + + let mut dest = vec![0u8; capacity]; + strm.bzalloc = None; strm.bzfree = None; strm.opaque = std::ptr::null_mut::(); unsafe { - ret = BZ2_bzDecompressInit(&mut strm, 0, 0); - if ret != 0 as libc::c_int { - return ret; + let ret = BZ2_bzDecompressInit(&mut strm, 0, 0); + if ret != 0 { + return (ret, vec![]); } - strm.next_in = source as *mut libc::c_char; - strm.next_out = dest.cast::(); strm.avail_in = source_len; - strm.avail_out = *dest_len; - ret = BZ2_bzDecompress(&mut strm); - if ret == 0 as libc::c_int { - if strm.avail_out > 0 as libc::c_int as libc::c_uint { - BZ2_bzDecompressEnd(&mut strm); - -(7 as libc::c_int) - } else { - BZ2_bzDecompressEnd(&mut strm); - -(8 as libc::c_int) + strm.avail_out = dest.len() as _; + strm.next_in = source as *mut libc::c_char; + strm.next_out = dest.as_mut_ptr().cast::(); + + let ret = loop { + match BZ2_bzDecompress(&mut strm) { + BZ_OK => { + if strm.avail_out > 0 { + BZ2_bzDecompressEnd(&mut strm); + break BZ_UNEXPECTED_EOF; + } else { + let used = dest.len() - strm.avail_out as usize; + // The dest buffer is full. + let add_space: u32 = Ord::max(1024, dest.len().try_into().unwrap()); + dest.resize(dest.len() + add_space as usize, 0); + + // If resize() reallocates, it may have moved in memory. + strm.next_out = dest.as_mut_ptr().cast::().wrapping_add(used); + strm.avail_out += add_space; + + continue; + } + } + BZ_STREAM_END => { + BZ2_bzDecompressEnd(&mut strm); + break BZ_OK; + } + ret => { + BZ2_bzDecompressEnd(&mut strm); + break ret; + } } - } else if ret != 4 as libc::c_int { - BZ2_bzDecompressEnd(&mut strm); - return ret; - } else { - *dest_len = (*dest_len).wrapping_sub(strm.avail_out); - BZ2_bzDecompressEnd(&mut strm); - return 0 as libc::c_int; - } + }; + + dest.truncate( + ((u64::from(strm.total_out_hi32) << 32) + u64::from(strm.total_out_lo32)) + .try_into() + .unwrap(), + ); + + (ret, dest) } } -pub unsafe fn decompress_rs( - dest: *mut u8, - dest_len: *mut libc::c_uint, +unsafe fn decompress_rs(source: *const u8, source_len: libc::c_uint) -> (i32, Vec) { + decompress_rs_with_capacity(1024, source, source_len) +} + +pub unsafe fn decompress_rs_with_capacity( + capacity: usize, source: *const u8, source_len: libc::c_uint, -) -> i32 { +) -> (i32, Vec) { use libbz2_rs_sys::*; let mut strm: bz_stream = bz_stream { @@ -320,39 +340,59 @@ pub unsafe fn decompress_rs( bzfree: None, opaque: std::ptr::null_mut::(), }; - let mut ret: libc::c_int; - if dest.is_null() || dest_len.is_null() || source.is_null() { - return -(2 as libc::c_int); - } + + let mut dest = vec![0u8; capacity]; + strm.bzalloc = None; strm.bzfree = None; strm.opaque = std::ptr::null_mut::(); unsafe { - ret = BZ2_bzDecompressInit(&mut strm, 0, 0); - if ret != 0 as libc::c_int { - return ret; + let ret = BZ2_bzDecompressInit(&mut strm, 0, 0); + if ret != 0 { + return (ret, vec![]); } - strm.next_in = source as *mut libc::c_char; - strm.next_out = dest.cast::(); strm.avail_in = source_len; - strm.avail_out = *dest_len; - ret = BZ2_bzDecompress(&mut strm); - if ret == 0 as libc::c_int { - if strm.avail_out > 0 as libc::c_int as libc::c_uint { - BZ2_bzDecompressEnd(&mut strm); - -(7 as libc::c_int) - } else { - BZ2_bzDecompressEnd(&mut strm); - -(8 as libc::c_int) + strm.avail_out = dest.len() as _; + strm.next_in = source as *mut libc::c_char; + strm.next_out = dest.as_mut_ptr().cast::(); + + let ret = loop { + match BZ2_bzDecompress(&mut strm) { + BZ_OK => { + if strm.avail_out > 0 { + BZ2_bzDecompressEnd(&mut strm); + break BZ_UNEXPECTED_EOF; + } else { + let used = dest.len() - strm.avail_out as usize; + // The dest buffer is full. + let add_space: u32 = Ord::max(1024, dest.len().try_into().unwrap()); + dest.resize(dest.len() + add_space as usize, 0); + + // If resize() reallocates, it may have moved in memory. + strm.next_out = dest.as_mut_ptr().cast::().wrapping_add(used); + strm.avail_out += add_space; + + continue; + } + } + BZ_STREAM_END => { + BZ2_bzDecompressEnd(&mut strm); + break BZ_OK; + } + ret => { + BZ2_bzDecompressEnd(&mut strm); + break ret; + } } - } else if ret != 4 as libc::c_int { - BZ2_bzDecompressEnd(&mut strm); - return ret; - } else { - *dest_len = (*dest_len).wrapping_sub(strm.avail_out); - BZ2_bzDecompressEnd(&mut strm); - return 0 as libc::c_int; - } + }; + + dest.truncate( + ((u64::from(strm.total_out_hi32) << 32) + u64::from(strm.total_out_lo32)) + .try_into() + .unwrap(), + ); + + (ret, dest) } } @@ -1220,16 +1260,8 @@ mod high_level_interface { let p = std::env::current_dir().unwrap(); let input = std::fs::read(p.join("../tests/input/quick/sample1.bz2")).unwrap(); - let mut expected = vec![0u8; 256 * 1024]; - let mut expected_len = expected.len() as _; - let err = unsafe { - decompress_c( - expected.as_mut_ptr(), - &mut expected_len, - input.as_ptr(), - input.len() as _, - ) - }; + let (err, expected) = + unsafe { decompress_c_with_capacity(256 * 1024, input.as_ptr(), input.len() as _) }; assert_eq!(err, 0); let p = p.join("../tests/input/quick/sample1.bz2\0"); @@ -1276,7 +1308,7 @@ mod high_level_interface { assert_eq!(bzerror, BZ_OK); - assert_eq!(&expected[..expected_len as usize], output); + assert_eq!(expected, output); } #[test] @@ -1343,8 +1375,9 @@ mod high_level_interface { assert_eq!(bzerror, BZ_OK); - let (err, expected) = - unsafe { compress_c(SAMPLE1_BZ2.as_ptr(), SAMPLE1_BZ2.len() as _, 9) }; + let (err, expected) = unsafe { + compress_c_with_capacity(1 << 18, SAMPLE1_BZ2.as_ptr(), SAMPLE1_BZ2.len() as _, 9) + }; assert_eq!(err, 0); assert_eq!(std::fs::read(p).unwrap(), expected,);