Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion libbz2-rs-sys/src/bzlib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,50 @@ mod stream {
unsafe { Allocator::from_bz_stream(self) }
}

#[must_use]
#[inline(always)]
pub(crate) fn pull_u32(
&mut self,
mut bit_buffer: u64,
bits_used: i32,
) -> Option<(u64, i32)> {
if self.avail_in < 4 {
return None;
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a debug assertion that no more than 7 bytes are requested? And no more than 1 byte in the case of pull_u8.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, I added an assert that we can consume at least 1 additional byte; in other words, that we don't ask for more input if we don't have space in the bit buffer.

The check you commented on is a bounds check: we just skip this function if there are not 8 additional bytes in the input, so there is nothing to assert there, right?

// of course this uses big endian values
let read = unsafe { self.next_in.cast::<u32>().read_unaligned().to_be() };

// because of the endianness, we can only shift in whole bytes.
let increment_bytes = (31 - bits_used) / 8;
let increment_bits = 8 * increment_bytes;

bit_buffer <<= increment_bits;
bit_buffer |= (read >> (32 - increment_bits)) as u64;

self.next_in = unsafe { (self.next_in).add(increment_bytes as usize) };
self.avail_in -= increment_bytes as u32;

// skips updating `self.total_in`: the caller is responsible for keeping it updated

Some((bit_buffer, bits_used + increment_bits))
}

/// Takes the next input byte without touching the `total_in` counters.
///
/// Returns `None` when `avail_in` is zero. Otherwise yields the byte at `next_in`, advances
/// `next_in` by one byte and decrements `avail_in` by one.
///
/// `self.total_in` is deliberately not updated here: the caller is responsible for keeping
/// it updated.
#[must_use]
#[inline(always)]
pub(crate) fn read_byte_fast(&mut self) -> Option<u8> {
    // bail out with `None` when there is no input left to consume
    let remaining = self.avail_in.checked_sub(1)?;

    // SAFETY: assumes a non-zero `avail_in` means `next_in` points to at least one readable
    // byte, so both the read and the one-byte advance stay within the input buffer.
    let byte = unsafe { self.next_in.cast::<u8>().read() };
    self.next_in = unsafe { self.next_in.add(1) };
    self.avail_in = remaining;

    Some(byte)
}

#[must_use]
pub(crate) fn read_byte(&mut self) -> Option<u8> {
if self.avail_in == 0 {
Expand Down Expand Up @@ -533,7 +577,7 @@ pub(crate) struct DState {
pub k0: u8,
pub rNToGo: i32,
pub rTPos: i32,
pub bsBuff: u32,
pub bsBuff: u64,
pub bsLive: i32,
pub smallDecompress: DecompressMode,
pub currBlockNo: i32,
Expand Down
19 changes: 15 additions & 4 deletions libbz2-rs-sys/src/decompress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ pub(crate) fn decompress(
let mut current_block: Block;
let mut uc: u8;

let old_avail_in = strm.avail_in;

if let State::BZ_X_MAGIC_1 = s.state {
/*zero out the save area*/
s.save = SaveArea::default();
Expand Down Expand Up @@ -180,13 +182,16 @@ pub(crate) fn decompress(
($strm:expr, $s:expr, $nnn:expr) => {
loop {
if $s.bsLive >= $nnn {
let v: u32 = ($s.bsBuff >> ($s.bsLive - $nnn)) & ((1 << $nnn) - 1);
let v: u64 = ($s.bsBuff >> ($s.bsLive - $nnn)) & ((1 << $nnn) - 1);
$s.bsLive -= $nnn;
break v;
break v as u32;
}

if let Some(next_byte) = strm.read_byte() {
$s.bsBuff = $s.bsBuff << 8 | next_byte as u32;
if let Some((bit_buffer, bits_used)) = strm.pull_u32($s.bsBuff, $s.bsLive) {
$s.bsBuff = bit_buffer;
$s.bsLive = bits_used;
} else if let Some(next_byte) = strm.read_byte_fast() {
$s.bsBuff = $s.bsBuff << 8 | next_byte as u64;
$s.bsLive += 8;
} else {
break 'save_state_and_return ReturnCode::BZ_OK;
Expand Down Expand Up @@ -1165,6 +1170,12 @@ pub(crate) fn decompress(
gMinlen,
};

// update total_in with how many bytes were read during this call
let bytes_read = old_avail_in - strm.avail_in;
let old_total_in_lo32 = strm.total_in_lo32;
strm.total_in_lo32 = strm.total_in_lo32.wrapping_add(bytes_read);
strm.total_in_hi32 += (strm.total_in_lo32 < old_total_in_lo32) as u32;

ret_val
}

Expand Down