Skip to content

Commit 19ff801

Browse files
committed
Auto merge of rust-lang#121174 - saethlin:codegen-niche-checks, r=<try>
Check for occupied niches This is a replacement for rust-lang#104862 r? `@ghost`
2 parents 915e7eb + 8e99635 commit 19ff801

File tree

24 files changed

+484
-13
lines changed

24 files changed

+484
-13
lines changed

compiler/rustc_codegen_ssa/src/mir/block.rs

+12-1
Original file line numberDiff line numberDiff line change
@@ -1279,6 +1279,17 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
12791279
) -> MergingSucc {
12801280
debug!("codegen_terminator: {:?}", terminator);
12811281

1282+
if bx.tcx().may_insert_niche_checks() {
1283+
if let mir::TerminatorKind::Return = terminator.kind {
1284+
let op = mir::Operand::Copy(mir::Place::return_place());
1285+
let ty = op.ty(self.mir, bx.tcx());
1286+
let ty = self.monomorphize(ty);
1287+
if let Some(niche) = bx.layout_of(ty).largest_niche {
1288+
self.codegen_niche_check(bx, op, niche, terminator.source_info);
1289+
}
1290+
}
1291+
}
1292+
12821293
let helper = TerminatorCodegenHelper { bb, terminator };
12831294

12841295
let mergeable_succ = || {
@@ -1583,7 +1594,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
15831594
tuple.layout.fields.count()
15841595
}
15851596

1586-
fn get_caller_location(
1597+
pub fn get_caller_location(
15871598
&mut self,
15881599
bx: &mut Bx,
15891600
source_info: mir::SourceInfo,

compiler/rustc_codegen_ssa/src/mir/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ pub mod debuginfo;
2121
mod intrinsic;
2222
mod locals;
2323
mod naked_asm;
24+
mod niche_check;
2425
pub mod operand;
2526
pub mod place;
2627
mod rvalue;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,286 @@
1+
use rustc_abi::BackendRepr;
2+
use rustc_hir::LangItem;
3+
use rustc_middle::mir;
4+
use rustc_middle::mir::visit::{NonMutatingUseContext, PlaceContext, Visitor};
5+
use rustc_middle::ty::{Mutability, Ty, TyCtxt};
6+
use rustc_span::Span;
7+
use rustc_span::def_id::LOCAL_CRATE;
8+
use rustc_target::abi::{Float, Integer, Niche, Primitive, Scalar, Size, WrappingRange};
9+
use tracing::instrument;
10+
11+
use super::FunctionCx;
12+
use crate::mir::OperandValue;
13+
use crate::mir::place::PlaceValue;
14+
use crate::traits::*;
15+
use crate::{base, common};
16+
17+
pub(super) struct NicheFinder<'s, 'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> {
18+
pub(super) fx: &'s mut FunctionCx<'a, 'tcx, Bx>,
19+
pub(super) bx: &'s mut Bx,
20+
pub(super) places: Vec<(mir::Operand<'tcx>, Niche)>,
21+
}
22+
23+
impl<'s, 'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> Visitor<'tcx> for NicheFinder<'s, 'a, 'tcx, Bx> {
24+
fn visit_rvalue(&mut self, rvalue: &mir::Rvalue<'tcx>, location: mir::Location) {
25+
match rvalue {
26+
mir::Rvalue::Cast(mir::CastKind::Transmute, op, ty) => {
27+
let ty = self.fx.monomorphize(*ty);
28+
if let Some(niche) = self.bx.layout_of(ty).largest_niche {
29+
self.places.push((op.clone(), niche));
30+
}
31+
}
32+
_ => self.super_rvalue(rvalue, location),
33+
}
34+
}
35+
36+
fn visit_terminator(&mut self, terminator: &mir::Terminator<'tcx>, _location: mir::Location) {
37+
if let mir::TerminatorKind::Return = terminator.kind {
38+
let op = mir::Operand::Copy(mir::Place::return_place());
39+
let ty = op.ty(self.fx.mir, self.bx.tcx());
40+
let ty = self.fx.monomorphize(ty);
41+
if let Some(niche) = self.bx.layout_of(ty).largest_niche {
42+
self.places.push((op, niche));
43+
}
44+
}
45+
}
46+
47+
fn visit_place(
48+
&mut self,
49+
place: &mir::Place<'tcx>,
50+
context: PlaceContext,
51+
_location: mir::Location,
52+
) {
53+
match context {
54+
PlaceContext::NonMutatingUse(
55+
NonMutatingUseContext::Copy | NonMutatingUseContext::Move,
56+
) => {}
57+
_ => {
58+
return;
59+
}
60+
}
61+
62+
let ty = place.ty(self.fx.mir, self.bx.tcx()).ty;
63+
let ty = self.fx.monomorphize(ty);
64+
if let Some(niche) = self.bx.layout_of(ty).largest_niche {
65+
self.places.push((mir::Operand::Copy(*place), niche));
66+
};
67+
}
68+
}
69+
70+
impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
71+
fn value_in_niche(
72+
&mut self,
73+
bx: &mut Bx,
74+
op: crate::mir::OperandRef<'tcx, Bx::Value>,
75+
niche: Niche,
76+
) -> Option<Bx::Value> {
77+
let niche_ty = niche.ty(bx.tcx());
78+
let niche_layout = bx.layout_of(niche_ty);
79+
80+
let (imm, from_scalar, from_backend_ty) = match op.val {
81+
OperandValue::Immediate(imm) => {
82+
let BackendRepr::Scalar(from_scalar) = op.layout.backend_repr else {
83+
unreachable!()
84+
};
85+
let from_backend_ty = bx.backend_type(op.layout);
86+
(imm, from_scalar, from_backend_ty)
87+
}
88+
OperandValue::Pair(first, second) => {
89+
let BackendRepr::ScalarPair(first_scalar, second_scalar) = op.layout.backend_repr
90+
else {
91+
unreachable!()
92+
};
93+
if niche.offset == Size::ZERO {
94+
(first, first_scalar, bx.scalar_pair_element_backend_type(op.layout, 0, true))
95+
} else {
96+
// yolo
97+
(second, second_scalar, bx.scalar_pair_element_backend_type(op.layout, 1, true))
98+
}
99+
}
100+
OperandValue::ZeroSized => unreachable!(),
101+
OperandValue::Ref(PlaceValue { llval: ptr, .. }) => {
102+
// General case: Load the niche primitive via pointer arithmetic.
103+
let niche_ptr_ty = Ty::new_ptr(bx.tcx(), niche_ty, Mutability::Not);
104+
let ptr = bx.pointercast(ptr, bx.backend_type(bx.layout_of(niche_ptr_ty)));
105+
106+
let offset = niche.offset.bytes() / niche_layout.size.bytes();
107+
let niche_backend_ty = bx.backend_type(bx.layout_of(niche_ty));
108+
let ptr = bx.inbounds_gep(niche_backend_ty, ptr, &[bx.const_usize(offset)]);
109+
let value = bx.load(niche_backend_ty, ptr, rustc_target::abi::Align::ONE);
110+
return Some(value);
111+
}
112+
};
113+
114+
// Any type whose ABI is a Scalar bool is turned into an i1, so it cannot contain a value
115+
// outside of its niche.
116+
if from_scalar.is_bool() {
117+
return None;
118+
}
119+
120+
let to_scalar = Scalar::Initialized {
121+
value: niche.value,
122+
valid_range: WrappingRange::full(niche.size(bx.tcx())),
123+
};
124+
let to_backend_ty = bx.backend_type(niche_layout);
125+
if from_backend_ty == to_backend_ty {
126+
return Some(imm);
127+
}
128+
let value = self.transmute_immediate(
129+
bx,
130+
imm,
131+
from_scalar,
132+
from_backend_ty,
133+
to_scalar,
134+
to_backend_ty,
135+
);
136+
Some(value)
137+
}
138+
139+
#[instrument(level = "debug", skip(self, bx))]
140+
pub fn codegen_niche_check(
141+
&mut self,
142+
bx: &mut Bx,
143+
mir_op: mir::Operand<'tcx>,
144+
niche: Niche,
145+
source_info: mir::SourceInfo,
146+
) {
147+
let tcx = bx.tcx();
148+
let op_ty = self.monomorphize(mir_op.ty(self.mir, tcx));
149+
if op_ty == tcx.types.bool {
150+
return;
151+
}
152+
153+
let op = self.codegen_operand(bx, &mir_op);
154+
155+
let Some(value_in_niche) = self.value_in_niche(bx, op, niche) else {
156+
return;
157+
};
158+
let size = niche.size(tcx);
159+
160+
let start = niche.scalar(niche.valid_range.start, bx);
161+
let end = niche.scalar(niche.valid_range.end, bx);
162+
163+
let binop_le = base::bin_op_to_icmp_predicate(mir::BinOp::Le, false);
164+
let binop_ge = base::bin_op_to_icmp_predicate(mir::BinOp::Ge, false);
165+
let is_valid = if niche.valid_range.start == 0 {
166+
bx.icmp(binop_le, value_in_niche, end)
167+
} else if niche.valid_range.end == (u128::MAX >> 128 - size.bits()) {
168+
bx.icmp(binop_ge, value_in_niche, start)
169+
} else {
170+
// We need to check if the value is within a *wrapping* range. We could do this:
171+
// (niche >= start) && (niche <= end)
172+
// But what we're going to actually do is this:
173+
// max = end - start
174+
// (niche - start) <= max
175+
// The latter is much more complicated conceptually, but is actually less operations
176+
// because we can compute max in codegen.
177+
let mut max = niche.valid_range.end.wrapping_sub(niche.valid_range.start);
178+
let size = niche.size(tcx);
179+
if size.bits() < 128 {
180+
let mask = (1 << size.bits()) - 1;
181+
max &= mask;
182+
}
183+
let max_adjusted_allowed_value = niche.scalar(max, bx);
184+
185+
let biased = bx.sub(value_in_niche, start);
186+
bx.icmp(binop_le, biased, max_adjusted_allowed_value)
187+
};
188+
189+
// Create destination blocks, branching on is_valid
190+
let panic = bx.append_sibling_block("panic");
191+
let success = bx.append_sibling_block("success");
192+
bx.cond_br(is_valid, success, panic);
193+
194+
// Switch to the failure block and codegen a call to the panic intrinsic
195+
bx.switch_to_block(panic);
196+
self.set_debug_loc(bx, source_info);
197+
let location = self.get_caller_location(bx, source_info).immediate();
198+
self.codegen_panic(
199+
bx,
200+
niche.lang_item(),
201+
&[value_in_niche, start, end, location],
202+
source_info.span,
203+
);
204+
205+
// Continue codegen in the success block.
206+
bx.switch_to_block(success);
207+
self.set_debug_loc(bx, source_info);
208+
}
209+
210+
#[instrument(level = "debug", skip(self, bx))]
211+
fn codegen_panic(&mut self, bx: &mut Bx, lang_item: LangItem, args: &[Bx::Value], span: Span) {
212+
if bx.tcx().is_compiler_builtins(LOCAL_CRATE) {
213+
bx.abort()
214+
} else {
215+
let (fn_abi, fn_ptr, instance) = common::build_langcall(bx, Some(span), lang_item);
216+
let fn_ty = bx.fn_decl_backend_type(&fn_abi);
217+
let fn_attrs = if bx.tcx().def_kind(self.instance.def_id()).has_codegen_attrs() {
218+
Some(bx.tcx().codegen_fn_attrs(self.instance.def_id()))
219+
} else {
220+
None
221+
};
222+
bx.call(fn_ty, fn_attrs, Some(&fn_abi), fn_ptr, args, None, Some(instance));
223+
}
224+
bx.unreachable();
225+
}
226+
}
227+
228+
trait NicheExt {
229+
fn ty<'tcx>(&self, tcx: TyCtxt<'tcx>) -> Ty<'tcx>;
230+
fn size<'tcx>(&self, tcx: TyCtxt<'tcx>) -> Size;
231+
fn scalar<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(&self, val: u128, bx: &mut Bx) -> Bx::Value;
232+
fn lang_item(&self) -> LangItem;
233+
}
234+
235+
impl NicheExt for Niche {
236+
fn lang_item(&self) -> LangItem {
237+
match self.value {
238+
Primitive::Int(Integer::I8, _) => LangItem::PanicOccupiedNicheU8,
239+
Primitive::Int(Integer::I16, _) => LangItem::PanicOccupiedNicheU16,
240+
Primitive::Int(Integer::I32, _) => LangItem::PanicOccupiedNicheU32,
241+
Primitive::Int(Integer::I64, _) => LangItem::PanicOccupiedNicheU64,
242+
Primitive::Int(Integer::I128, _) => LangItem::PanicOccupiedNicheU128,
243+
Primitive::Pointer(_) => LangItem::PanicOccupiedNichePtr,
244+
Primitive::Float(Float::F16) => LangItem::PanicOccupiedNicheU16,
245+
Primitive::Float(Float::F32) => LangItem::PanicOccupiedNicheU32,
246+
Primitive::Float(Float::F64) => LangItem::PanicOccupiedNicheU64,
247+
Primitive::Float(Float::F128) => LangItem::PanicOccupiedNicheU128,
248+
}
249+
}
250+
251+
fn ty<'tcx>(&self, tcx: TyCtxt<'tcx>) -> Ty<'tcx> {
252+
let types = &tcx.types;
253+
match self.value {
254+
Primitive::Int(Integer::I8, _) => types.u8,
255+
Primitive::Int(Integer::I16, _) => types.u16,
256+
Primitive::Int(Integer::I32, _) => types.u32,
257+
Primitive::Int(Integer::I64, _) => types.u64,
258+
Primitive::Int(Integer::I128, _) => types.u128,
259+
Primitive::Pointer(_) => Ty::new_ptr(tcx, types.unit, Mutability::Not),
260+
Primitive::Float(Float::F16) => types.u16,
261+
Primitive::Float(Float::F32) => types.u32,
262+
Primitive::Float(Float::F64) => types.u64,
263+
Primitive::Float(Float::F128) => types.u128,
264+
}
265+
}
266+
267+
fn size<'tcx>(&self, tcx: TyCtxt<'tcx>) -> Size {
268+
self.value.size(&tcx)
269+
}
270+
271+
fn scalar<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(&self, val: u128, bx: &mut Bx) -> Bx::Value {
272+
use rustc_middle::mir::interpret::{Pointer, Scalar};
273+
let tcx = bx.tcx();
274+
let niche_ty = self.ty(tcx);
275+
let value = if niche_ty.is_any_ptr() {
276+
Scalar::from_maybe_pointer(Pointer::from_addr_invalid(val as u64), &tcx)
277+
} else {
278+
Scalar::from_uint(val, self.size(tcx))
279+
};
280+
let layout = rustc_target::abi::Scalar::Initialized {
281+
value: self.value,
282+
valid_range: WrappingRange::full(self.size(tcx)),
283+
};
284+
bx.scalar_to_backend(value, layout, bx.backend_type(bx.layout_of(self.ty(tcx))))
285+
}
286+
}

compiler/rustc_codegen_ssa/src/mir/rvalue.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
160160
}
161161
}
162162

163-
fn codegen_transmute(
163+
pub fn codegen_transmute(
164164
&mut self,
165165
bx: &mut Bx,
166166
src: OperandRef<'tcx, Bx::Value>,
@@ -195,7 +195,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
195195
///
196196
/// Returns `None` for cases that can't work in that framework, such as for
197197
/// `Immediate`->`Ref` that needs an `alloc` to get the location.
198-
fn codegen_transmute_operand(
198+
pub fn codegen_transmute_operand(
199199
&mut self,
200200
bx: &mut Bx,
201201
operand: OperandRef<'tcx, Bx::Value>,
@@ -337,7 +337,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
337337
///
338338
/// `to_backend_ty` must be the *non*-immediate backend type (so it will be
339339
/// `i8`, not `i1`, for `bool`-like types.)
340-
fn transmute_immediate(
340+
pub fn transmute_immediate(
341341
&self,
342342
bx: &mut Bx,
343343
mut imm: Bx::Value,

compiler/rustc_codegen_ssa/src/mir/statement.rs

+19
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,33 @@
1+
use rustc_middle::mir::visit::Visitor;
12
use rustc_middle::mir::{self, NonDivergingIntrinsic};
23
use rustc_middle::span_bug;
34
use tracing::instrument;
45

56
use super::{FunctionCx, LocalRef};
7+
use crate::mir::niche_check::NicheFinder;
68
use crate::traits::*;
79

810
impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
11+
fn niches_to_check(
12+
&mut self,
13+
bx: &mut Bx,
14+
statement: &mir::Statement<'tcx>,
15+
) -> Vec<(mir::Operand<'tcx>, rustc_target::abi::Niche)> {
16+
let mut finder = NicheFinder { fx: self, bx, places: Vec::new() };
17+
finder.visit_statement(statement, rustc_middle::mir::Location::START);
18+
finder.places
19+
}
20+
921
#[instrument(level = "debug", skip(self, bx))]
1022
pub(crate) fn codegen_statement(&mut self, bx: &mut Bx, statement: &mir::Statement<'tcx>) {
1123
self.set_debug_loc(bx, statement.source_info);
24+
25+
if bx.tcx().may_insert_niche_checks() {
26+
for (op, niche) in self.niches_to_check(bx, statement) {
27+
self.codegen_niche_check(bx, op, niche, statement.source_info);
28+
}
29+
}
30+
1231
match statement.kind {
1332
mir::StatementKind::Assign(box (ref place, ref rvalue)) => {
1433
if let Some(index) = place.as_local() {

compiler/rustc_hir/src/lang_items.rs

+6
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,12 @@ language_item_table! {
293293
ConstPanicFmt, sym::const_panic_fmt, const_panic_fmt, Target::Fn, GenericRequirement::None;
294294
PanicBoundsCheck, sym::panic_bounds_check, panic_bounds_check_fn, Target::Fn, GenericRequirement::Exact(0);
295295
PanicMisalignedPointerDereference, sym::panic_misaligned_pointer_dereference, panic_misaligned_pointer_dereference_fn, Target::Fn, GenericRequirement::Exact(0);
296+
PanicOccupiedNicheU8, sym::panic_occupied_niche_u8, panic_occupied_niche_u8, Target::Fn, GenericRequirement::None;
297+
PanicOccupiedNicheU16, sym::panic_occupied_niche_u16, panic_occupied_niche_u16, Target::Fn, GenericRequirement::None;
298+
PanicOccupiedNicheU32, sym::panic_occupied_niche_u32, panic_occupied_niche_u32, Target::Fn, GenericRequirement::None;
299+
PanicOccupiedNicheU64, sym::panic_occupied_niche_u64, panic_occupied_niche_u64, Target::Fn, GenericRequirement::None;
300+
PanicOccupiedNicheU128, sym::panic_occupied_niche_u128, panic_occupied_niche_u128, Target::Fn, GenericRequirement::None;
301+
PanicOccupiedNichePtr, sym::panic_occupied_niche_ptr, panic_occupied_niche_ptr, Target::Fn, GenericRequirement::None;
296302
PanicInfo, sym::panic_info, panic_info, Target::Struct, GenericRequirement::None;
297303
PanicLocation, sym::panic_location, panic_location, Target::Struct, GenericRequirement::None;
298304
PanicImpl, sym::panic_impl, panic_impl, Target::Fn, GenericRequirement::None;

0 commit comments

Comments
 (0)