Skip to content

Commit 21b759a

Browse files
committed
Don't allocate trailing uninit bits in the InitMap of CTFE Allocations
1 parent 0ac4658 commit 21b759a

File tree

2 files changed

+68
-38
lines changed

2 files changed

+68
-38
lines changed

compiler/rustc_middle/src/mir/interpret/allocation.rs

+67-37
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ impl<Tag> Allocation<Tag> {
147147
Self {
148148
bytes,
149149
relocations: Relocations::new(),
150-
init_mask: InitMask::new(size, true),
150+
init_mask: InitMask::new_init(size),
151151
align,
152152
mutability,
153153
extra: (),
@@ -180,7 +180,7 @@ impl<Tag> Allocation<Tag> {
180180
Ok(Allocation {
181181
bytes,
182182
relocations: Relocations::new(),
183-
init_mask: InitMask::new(size, false),
183+
init_mask: InitMask::new_uninit(size),
184184
align,
185185
mutability: Mutability::Mut,
186186
extra: (),
@@ -629,15 +629,19 @@ impl InitMask {
629629
Size::from_bytes(block * InitMask::BLOCK_SIZE + bit)
630630
}
631631

632-
pub fn new(size: Size, state: bool) -> Self {
632+
pub fn new_init(size: Size) -> Self {
633633
let mut m = InitMask { blocks: vec![], len: Size::ZERO };
634-
m.grow(size, state);
634+
m.grow(size, true);
635635
m
636636
}
637637

638+
pub fn new_uninit(size: Size) -> Self {
639+
InitMask { blocks: vec![], len: size }
640+
}
641+
638642
pub fn set_range(&mut self, start: Size, end: Size, new_state: bool) {
639643
let len = self.len;
640-
if end > len {
644+
if end > len && new_state {
641645
self.grow(end - len, new_state);
642646
}
643647
self.set_range_inbounds(start, end, new_state);
@@ -655,14 +659,16 @@ impl InitMask {
655659
(u64::MAX << bita) & (u64::MAX >> (64 - bitb))
656660
};
657661
if new_state {
662+
self.ensure_blocks(blocka);
658663
self.blocks[blocka] |= range;
659-
} else {
660-
self.blocks[blocka] &= !range;
664+
} else if let Some(block) = self.blocks.get_mut(blocka) {
665+
*block &= !range;
661666
}
662667
return;
663668
}
664669
// across block boundaries
665670
if new_state {
671+
self.ensure_blocks(blockb);
666672
// Set `bita..64` to `1`.
667673
self.blocks[blocka] |= u64::MAX << bita;
668674
// Set `0..bitb` to `1`.
@@ -673,15 +679,17 @@ impl InitMask {
673679
for block in (blocka + 1)..blockb {
674680
self.blocks[block] = u64::MAX;
675681
}
676-
} else {
682+
} else if let Some(blocka_val) = self.blocks.get_mut(blocka) {
677683
// Set `bita..64` to `0`.
678-
self.blocks[blocka] &= !(u64::MAX << bita);
684+
*blocka_val &= !(u64::MAX << bita);
679685
// Set `0..bitb` to `0`.
680686
if bitb != 0 {
681-
self.blocks[blockb] &= !(u64::MAX >> (64 - bitb));
687+
if let Some(blockb_val) = self.blocks.get_mut(blockb) {
688+
*blockb_val &= !(u64::MAX >> (64 - bitb));
689+
}
682690
}
683691
// Fill in all the other blocks (much faster than one bit at a time).
684-
for block in (blocka + 1)..blockb {
692+
for block in (blocka + 1)..std::cmp::min(blockb, self.blocks.len()) {
685693
self.blocks[block] = 0;
686694
}
687695
}
@@ -690,7 +698,10 @@ impl InitMask {
690698
#[inline]
691699
pub fn get(&self, i: Size) -> bool {
692700
let (block, bit) = Self::bit_index(i);
693-
(self.blocks[block] & (1 << bit)) != 0
701+
match self.blocks.get(block) {
702+
Some(block) => (*block & (1 << bit)) != 0,
703+
None => false,
704+
}
694705
}
695706

696707
#[inline]
@@ -702,10 +713,22 @@ impl InitMask {
702713
#[inline]
703714
fn set_bit(&mut self, block: usize, bit: usize, new_state: bool) {
704715
if new_state {
716+
self.ensure_blocks(block);
705717
self.blocks[block] |= 1 << bit;
706-
} else {
707-
self.blocks[block] &= !(1 << bit);
718+
} else if let Some(block) = self.blocks.get_mut(block) {
719+
*block &= !(1 << bit);
720+
}
721+
}
722+
723+
fn ensure_blocks(&mut self, block: usize) {
724+
if block < self.blocks.len() {
725+
return;
708726
}
727+
let additional_blocks = block - self.blocks.len() + 1;
728+
self.blocks.extend(
729+
// FIXME(oli-obk): optimize this by repeating `new_state as Block`.
730+
iter::repeat(0).take(usize::try_from(additional_blocks).unwrap()),
731+
);
709732
}
710733

711734
pub fn grow(&mut self, amount: Size, new_state: bool) {
@@ -716,10 +739,7 @@ impl InitMask {
716739
u64::try_from(self.blocks.len()).unwrap() * Self::BLOCK_SIZE - self.len.bytes();
717740
if amount.bytes() > unused_trailing_bits {
718741
let additional_blocks = amount.bytes() / Self::BLOCK_SIZE + 1;
719-
self.blocks.extend(
720-
// FIXME(oli-obk): optimize this by repeating `new_state as Block`.
721-
iter::repeat(0).take(usize::try_from(additional_blocks).unwrap()),
722-
);
742+
self.ensure_blocks(self.blocks.len() + additional_blocks as usize - 1);
723743
}
724744
let start = self.len;
725745
self.len += amount;
@@ -821,25 +841,31 @@ impl InitMask {
821841
// (c) 01000000|00000000|00000001
822842
// ^~~~~~~~~~~~~~~~~~^
823843
// start end
824-
if let Some(i) =
825-
search_block(init_mask.blocks[start_block], start_block, start_bit, is_init)
826-
{
827-
// If the range is less than a block, we may find a matching bit after `end`.
828-
//
829-
// For example, we shouldn't successfully find bit (2), because it's after `end`:
830-
//
831-
// (2)
832-
// -------|
833-
// (d) 00000001|00000000|00000001
834-
// ^~~~~^
835-
// start end
836-
//
837-
// An alternative would be to mask off end bits in the same way as we do for start bits,
838-
// but performing this check afterwards is faster and simpler to implement.
839-
if i < end {
840-
return Some(i);
841-
} else {
844+
if let Some(&bits) = init_mask.blocks.get(start_block) {
845+
if let Some(i) = search_block(bits, start_block, start_bit, is_init) {
846+
// If the range is less than a block, we may find a matching bit after `end`.
847+
//
848+
// For example, we shouldn't successfully find bit (2), because it's after `end`:
849+
//
850+
// (2)
851+
// -------|
852+
// (d) 00000001|00000000|00000001
853+
// ^~~~~^
854+
// start end
855+
//
856+
// An alternative would be to mask off end bits in the same way as we do for start bits,
857+
// but performing this check afterwards is faster and simpler to implement.
858+
if i < end {
859+
return Some(i);
860+
} else {
861+
return None;
862+
}
863+
}
864+
} else {
865+
if is_init {
842866
return None;
867+
} else {
868+
return Some(start);
843869
}
844870
}
845871

@@ -861,7 +887,8 @@ impl InitMask {
861887
// because both alternatives result in significantly worse codegen.
862888
// `end_block_inclusive + 1` is guaranteed not to wrap, because `end_block_inclusive <= end / BLOCK_SIZE`,
863889
// and `BLOCK_SIZE` (the number of bits per block) will always be at least 8 (1 byte).
864-
for (&bits, block) in init_mask.blocks[start_block + 1..end_block_inclusive + 1]
890+
for (&bits, block) in init_mask.blocks[start_block + 1
891+
..std::cmp::min(end_block_inclusive + 1, init_mask.blocks.len())]
865892
.iter()
866893
.zip(start_block + 1..)
867894
{
@@ -886,6 +913,9 @@ impl InitMask {
886913
}
887914
}
888915
}
916+
if !is_init && end_block_inclusive >= init_mask.blocks.len() {
917+
return Some(InitMask::size_from_bit_index(init_mask.blocks.len(), 0));
918+
}
889919
}
890920

891921
None

src/test/ui-fulldeps/uninit_mask.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use rustc_middle::mir::interpret::InitMask;
1111
use rustc_target::abi::Size;
1212

1313
fn main() {
14-
let mut mask = InitMask::new(Size::from_bytes(500), false);
14+
let mut mask = InitMask::new_uninit(Size::from_bytes(500));
1515
assert!(!mask.get(Size::from_bytes(499)));
1616
mask.set(Size::from_bytes(499), true);
1717
assert!(mask.get(Size::from_bytes(499)));

0 commit comments

Comments
 (0)