Skip to content

Commit 034256c

Browse files
committed
remove unsafe
1 parent 07d1254 commit 034256c

File tree

2 files changed

+33
-32
lines changed

2 files changed

+33
-32
lines changed

benches/micro.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ fn bench_compress(c: &mut Criterion) {
2121
let compressor = compressor.build();
2222

2323
let word = u64::from_le_bytes([b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h']);
24-
b.iter(|| unsafe { compressor.compress_word(word, output_buf.as_mut_ptr()) });
24+
b.iter(|| unsafe { compressor.compress_word(word, output_buf.spare_capacity_mut()) });
2525
});
2626

2727
// We create a symbol table that is able to short-circuit the decompression
@@ -31,7 +31,7 @@ fn bench_compress(c: &mut Criterion) {
3131
let compressor = compressor.build();
3232

3333
let word = u64::from_le_bytes([b'a', b'b', 0, 0, 0, 0, 0, 0]);
34-
b.iter(|| unsafe { compressor.compress_word(word, output_buf.as_mut_ptr()) });
34+
b.iter(|| unsafe { compressor.compress_word(word, output_buf.spare_capacity_mut()) });
3535
});
3636
group.finish();
3737

src/lib.rs

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#![allow(unsafe_op_in_unsafe_fn)]
1+
// #![allow(unsafe_op_in_unsafe_fn)]
22
#![doc = include_str!("../README.md")]
33
#![cfg(target_endian = "little")]
44

@@ -558,25 +558,30 @@ impl Compressor {
558558
///
559559
/// `advance_in` is the number of bytes to advance the input pointer before the next call.
560560
///
561-
/// `advance_out` is the number of bytes to advance `out_ptr` before the next call.
561+
/// `advance_out` is the number of bytes to advance `out_ptr` before the next call. Will
562+
/// be either 1 (if a code is emitted) or 2 (if a symbol is emitted).
562563
///
563564
/// # Safety
564565
///
565566
/// `out_ptr` must never be NULL or otherwise point to invalid memory.
566-
pub unsafe fn compress_word(&self, word: u64, out_ptr: *mut u8) -> (usize, usize) {
567+
pub unsafe fn compress_word(
568+
&self,
569+
word: u64,
570+
out_ptr: &mut [MaybeUninit<u8>],
571+
) -> (usize, usize) {
567572
// Speculatively write the first byte of `word` at offset 1. This is necessary if it is an escape, and
568573
// if it isn't, it will be overwritten anyway.
569574
//
570575
// SAFETY: caller ensures out_ptr is not null
571576
let first_byte = word as u8;
572-
out_ptr.byte_add(1).write_unaligned(first_byte);
577+
out_ptr[1].write(first_byte);
573578

574579
// First, check the two_bytes table
575580
let code_twobyte = self.codes_two_byte[word as u16 as usize];
576581

577582
if code_twobyte.code() < self.has_suffix_code {
578583
// 2 byte code without having to worry about longer matches.
579-
std::ptr::write(out_ptr, code_twobyte.code());
584+
out_ptr[0].write(code_twobyte.code());
580585

581586
// Advance input by symbol length (2) and output by a single code byte
582587
(2, 1)
@@ -590,10 +595,10 @@ impl Compressor {
590595
&& compare_masked(word, entry.symbol.as_u64(), ignored_bits)
591596
{
592597
// Advance the input by the symbol length (variable) and the output by one code byte
593-
std::ptr::write(out_ptr, entry.code.code());
598+
out_ptr[0].write(entry.code.code());
594599
(entry.code.len() as usize, 1)
595600
} else {
596-
std::ptr::write(out_ptr, code_twobyte.code());
601+
out_ptr[0].write(code_twobyte.code());
597602

598603
// Advance the input by the symbol length (variable) and the output by either 1
599604
// byte (if was one-byte code) or two bytes (escape).
@@ -655,47 +660,47 @@ impl Compressor {
655660
/// all encoded data.
656661
pub unsafe fn compress_into(&self, plaintext: &[u8], values: &mut Vec<u8>) {
657662
let mut in_ptr = plaintext.as_ptr();
658-
let mut out_ptr = values.as_mut_ptr();
663+
// let mut out_ptr = values.as_mut_ptr();
664+
let out_values = values.spare_capacity_mut();
665+
let mut out_ptr = 0;
659666

660667
// SAFETY: `end` will point just after the end of the `plaintext` slice.
661668
let in_end = unsafe { in_ptr.byte_add(plaintext.len()) };
662669
let in_end_sub8 = in_end as usize - 8;
663-
// SAFETY: `end` will point just after the end of the `values` allocation.
664-
let out_end = unsafe { out_ptr.byte_add(values.capacity()) };
665670

666-
while (in_ptr as usize) <= in_end_sub8 && out_ptr < out_end {
671+
while (in_ptr as usize) <= in_end_sub8 {
667672
// SAFETY: pointer ranges are checked in the loop condition
668673
unsafe {
669674
// Load a full 8-byte word of data from in_ptr.
670-
// SAFETY: caller asserts in_ptr is not null. we may read past end of pointer though.
675+
// SAFETY: we check above that in_ptr points to at least 8 bytes of valid allocation
671676
let word: u64 = std::ptr::read_unaligned(in_ptr as *const u64);
672-
let (advance_in, advance_out) = self.compress_word(word, out_ptr);
677+
let (advance_in, advance_out) =
678+
self.compress_word(word, &mut out_values[out_ptr..]);
673679
in_ptr = in_ptr.byte_add(advance_in);
674-
out_ptr = out_ptr.byte_add(advance_out);
680+
out_ptr += advance_out;
675681
};
676682
}
677683

678684
let remaining_bytes = unsafe { in_end.byte_offset_from(in_ptr) };
679-
assert!(
680-
out_ptr < out_end || remaining_bytes == 0,
681-
"output buffer sized too small"
682-
);
683685

684686
let remaining_bytes = remaining_bytes as usize;
685687

686688
// Load the last `remaining_byte`s of data into a final world. We then replicate the loop above,
687689
// but shift data out of this word rather than advancing an input pointer and potentially reading
688690
// unowned memory.
689691
let mut bytes = [0u8; 8];
690-
std::ptr::copy_nonoverlapping(in_ptr, bytes.as_mut_ptr(), remaining_bytes);
692+
// SAFETY: we know that `remaining_bytes` <= 8.
693+
unsafe { std::ptr::copy_nonoverlapping(in_ptr, bytes.as_mut_ptr(), remaining_bytes) };
691694
let mut last_word = u64::from_le_bytes(bytes);
692695

693-
while in_ptr < in_end && out_ptr < out_end {
696+
while in_ptr < in_end {
694697
// Load a full 8-byte word of data from in_ptr.
695-
// SAFETY: caller asserts in_ptr is not null. we may read past end of pointer though.
696-
let (advance_in, advance_out) = self.compress_word(last_word, out_ptr);
697-
in_ptr = in_ptr.byte_add(advance_in);
698-
out_ptr = out_ptr.byte_add(advance_out);
698+
// SAFETY: we check that the out_ptr is not more than 2 bytes from the end of the allocation.
699+
let (advance_in, advance_out) =
700+
unsafe { self.compress_word(last_word, &mut out_values[out_ptr..]) };
701+
in_ptr = unsafe { in_ptr.byte_add(advance_in) };
702+
// out_values = out_values.byte_add(advance_out);
703+
out_ptr += advance_out;
699704

700705
last_word = advance_8byte_word(last_word, advance_in);
701706
}
@@ -708,13 +713,9 @@ impl Compressor {
708713

709714
// Count the number of bytes written
710715
// SAFETY: assertion
711-
let bytes_written = out_ptr.offset_from(values.as_ptr());
712-
assert!(
713-
bytes_written >= 0,
714-
"out_ptr ended before it started, not possible"
715-
);
716+
let bytes_written = out_ptr;
716717

717-
values.set_len(bytes_written as usize);
718+
unsafe { values.set_len(bytes_written) };
718719
}
719720

720721
/// Use the symbol table to compress the plaintext into a sequence of codes and escapes.

0 commit comments

Comments
 (0)