Skip to content

Commit 9f84791

Browse files
author
Michael Benfield
committed
AddNicheCases MirPass
This pass optimizes switches on the discriminant of a niche-optimized enum.
1 parent d5139f4 commit 9f84791

File tree

40 files changed

+366
-124
lines changed

40 files changed

+366
-124
lines changed

compiler/rustc_borrowck/src/invalidation.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -291,10 +291,10 @@ impl<'cx, 'tcx> InvalidationGenerator<'cx, 'tcx> {
291291
self.consume_operand(location, operand)
292292
}
293293

294-
Rvalue::Len(place) | Rvalue::Discriminant(place) => {
294+
Rvalue::Len(place) | Rvalue::Discriminant { place, .. } => {
295295
let af = match *rvalue {
296296
Rvalue::Len(..) => Some(ArtificialField::ArrayLength),
297-
Rvalue::Discriminant(..) => None,
297+
Rvalue::Discriminant { .. } => None,
298298
_ => unreachable!(),
299299
};
300300
self.access_place(

compiler/rustc_borrowck/src/lib.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1280,10 +1280,10 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> {
12801280
self.consume_operand(location, (operand, span), flow_state)
12811281
}
12821282

1283-
Rvalue::Len(place) | Rvalue::Discriminant(place) => {
1283+
Rvalue::Len(place) | Rvalue::Discriminant { place, .. } => {
12841284
let af = match *rvalue {
12851285
Rvalue::Len(..) => Some(ArtificialField::ArrayLength),
1286-
Rvalue::Discriminant(..) => None,
1286+
Rvalue::Discriminant { .. } => None,
12871287
_ => unreachable!(),
12881288
};
12891289
self.access_place(

compiler/rustc_borrowck/src/type_check/mod.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -2260,7 +2260,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
22602260
Rvalue::AddressOf(..)
22612261
| Rvalue::ThreadLocalRef(..)
22622262
| Rvalue::Len(..)
2263-
| Rvalue::Discriminant(..) => {}
2263+
| Rvalue::Discriminant { .. } => {}
22642264
}
22652265
}
22662266

@@ -2281,7 +2281,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
22812281
| Rvalue::CheckedBinaryOp(..)
22822282
| Rvalue::NullaryOp(..)
22832283
| Rvalue::UnaryOp(..)
2284-
| Rvalue::Discriminant(..) => None,
2284+
| Rvalue::Discriminant { .. } => None,
22852285

22862286
Rvalue::Aggregate(aggregate, _) => match **aggregate {
22872287
AggregateKind::Adt(_, _, _, user_ty, _) => user_ty,

compiler/rustc_codegen_cranelift/src/base.rs

+8-3
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,7 @@ fn codegen_stmt<'tcx>(
634634
fx,
635635
operand,
636636
fx.layout_of(operand.layout().ty.discriminant_ty(fx.tcx)),
637+
false,
637638
)
638639
.load_scalar(fx);
639640

@@ -684,11 +685,15 @@ fn codegen_stmt<'tcx>(
684685
let operand = codegen_operand(fx, operand);
685686
operand.unsize_value(fx, lval);
686687
}
687-
Rvalue::Discriminant(place) => {
688+
Rvalue::Discriminant { place, relative } => {
688689
let place = codegen_place(fx, place);
689690
let value = place.to_cvalue(fx);
690-
let discr =
691-
crate::discriminant::codegen_get_discriminant(fx, value, dest_layout);
691+
let discr = crate::discriminant::codegen_get_discriminant(
692+
fx,
693+
value,
694+
dest_layout,
695+
relative,
696+
);
692697
lval.write_cvalue(fx, discr);
693698
}
694699
Rvalue::Repeat(ref operand, times) => {

compiler/rustc_codegen_cranelift/src/discriminant.rs

+40-29
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,13 @@ pub(crate) fn codegen_set_discriminant<'tcx>(
6060
}
6161
}
6262

63+
/// If `relative` is true, we instead calculate the *relative* discriminant (see
64+
/// `RValue::Discriminant`).
6365
pub(crate) fn codegen_get_discriminant<'tcx>(
6466
fx: &mut FunctionCx<'_, '_, 'tcx>,
6567
value: CValue<'tcx>,
6668
dest_layout: TyAndLayout<'tcx>,
69+
relative: bool,
6770
) -> CValue<'tcx> {
6871
let layout = value.layout();
6972

@@ -131,38 +134,46 @@ pub(crate) fn codegen_get_discriminant<'tcx>(
131134
// FIXME handle niche_start > i64::MAX
132135
fx.bcx.ins().iadd_imm(tag, -i64::try_from(niche_start).unwrap())
133136
};
134-
let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
135-
let is_niche = {
136-
codegen_icmp_imm(
137-
fx,
138-
IntCC::UnsignedLessThanOrEqual,
139-
relative_discr,
140-
i128::from(relative_max),
141-
)
142-
};
143137

144-
// NOTE(eddyb) this addition needs to be performed on the final
145-
// type, in case the niche itself can't represent all variant
146-
// indices (e.g. `u8` niche with more than `256` variants,
147-
// but enough uninhabited variants so that the remaining variants
148-
// fit in the niche).
149-
// In other words, `niche_variants.end - niche_variants.start`
150-
// is representable in the niche, but `niche_variants.end`
151-
// might not be, in extreme cases.
152-
let niche_discr = {
153-
let relative_discr = if relative_max == 0 {
154-
// HACK(eddyb) since we have only one niche, we know which
155-
// one it is, and we can avoid having a dynamic value here.
156-
fx.bcx.ins().iconst(cast_to, 0)
157-
} else {
158-
clif_intcast(fx, relative_discr, cast_to, false)
138+
if relative {
139+
CValue::by_val(clif_intcast(fx, relative_discr, cast_to, false), dest_layout)
140+
} else {
141+
let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
142+
let is_niche = {
143+
codegen_icmp_imm(
144+
fx,
145+
IntCC::UnsignedLessThanOrEqual,
146+
relative_discr,
147+
i128::from(relative_max),
148+
)
159149
};
160-
fx.bcx.ins().iadd_imm(relative_discr, i64::from(niche_variants.start().as_u32()))
161-
};
162150

163-
let dataful_variant = fx.bcx.ins().iconst(cast_to, i64::from(dataful_variant.as_u32()));
164-
let discr = fx.bcx.ins().select(is_niche, niche_discr, dataful_variant);
165-
CValue::by_val(discr, dest_layout)
151+
// NOTE(eddyb) this addition needs to be performed on the final
152+
// type, in case the niche itself can't represent all variant
153+
// indices (e.g. `u8` niche with more than `256` variants,
154+
// but enough uninhabited variants so that the remaining variants
155+
// fit in the niche).
156+
// In other words, `niche_variants.end - niche_variants.start`
157+
// is representable in the niche, but `niche_variants.end`
158+
// might not be, in extreme cases.
159+
let niche_discr = {
160+
let relative_discr = if relative_max == 0 {
161+
// HACK(eddyb) since we have only one niche, we know which
162+
// one it is, and we can avoid having a dynamic value here.
163+
fx.bcx.ins().iconst(cast_to, 0)
164+
} else {
165+
clif_intcast(fx, relative_discr, cast_to, false)
166+
};
167+
fx.bcx
168+
.ins()
169+
.iadd_imm(relative_discr, i64::from(niche_variants.start().as_u32()))
170+
};
171+
172+
let dataful_variant =
173+
fx.bcx.ins().iconst(cast_to, i64::from(dataful_variant.as_u32()));
174+
let discr = fx.bcx.ins().select(is_niche, niche_discr, dataful_variant);
175+
CValue::by_val(discr, dest_layout)
176+
}
166177
}
167178
}
168179
}

compiler/rustc_codegen_ssa/src/mir/intrinsic.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
359359

360360
sym::discriminant_value => {
361361
if ret_ty.is_integral() {
362-
args[0].deref(bx.cx()).codegen_get_discr(bx, ret_ty)
362+
args[0].deref(bx.cx()).codegen_get_discr(bx, ret_ty, false)
363363
} else {
364364
span_bug!(span, "Invalid discriminant type for `{:?}`", arg_tys[0])
365365
}

compiler/rustc_codegen_ssa/src/mir/place.rs

+42-33
Original file line numberDiff line numberDiff line change
@@ -204,10 +204,14 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
204204
}
205205

206206
/// Obtain the actual discriminant of a value.
207+
///
208+
/// If `relative` is true, instead calculate the *relative* discriminant (see
209+
/// `RValue::Discriminant`).
207210
pub fn codegen_get_discr<Bx: BuilderMethods<'a, 'tcx, Value = V>>(
208211
self,
209212
bx: &mut Bx,
210213
cast_to: Ty<'tcx>,
214+
relative: bool,
211215
) -> V {
212216
let cast_to = bx.cx().immediate_backend_type(bx.cx().layout_of(cast_to));
213217
if self.layout.abi.is_uninhabited() {
@@ -266,44 +270,49 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
266270
} else {
267271
bx.sub(tag, bx.cx().const_uint_big(niche_llty, niche_start))
268272
};
269-
let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
270-
let is_niche = if relative_max == 0 {
271-
// Avoid calling `const_uint`, which wouldn't work for pointers.
272-
// Also use canonical == 0 instead of non-canonical u<= 0.
273-
// FIXME(eddyb) check the actual primitive type here.
274-
bx.icmp(IntPredicate::IntEQ, relative_discr, bx.cx().const_null(niche_llty))
273+
let relative_max = niche_variants.size_hint().1.unwrap() - 1;
274+
275+
if relative {
276+
bx.intcast(relative_discr, cast_to, false)
275277
} else {
276-
let relative_max = bx.cx().const_uint(niche_llty, relative_max as u64);
277-
bx.icmp(IntPredicate::IntULE, relative_discr, relative_max)
278-
};
278+
// NOTE(eddyb) this addition needs to be performed on the final
279+
// type, in case the niche itself can't represent all variant
280+
// indices (e.g. `u8` niche with more than `256` variants,
281+
// but enough uninhabited variants so that the remaining variants
282+
// fit in the niche).
283+
// In other words, `niche_variants.end - niche_variants.start`
284+
// is representable in the niche, but `niche_variants.end`
285+
// might not be, in extreme cases.
286+
let niche_discr = {
287+
let relative_discr = if relative_max == 0 {
288+
// HACK(eddyb) since we have only one niche, we know which
289+
// one it is, and we can avoid having a dynamic value here.
290+
bx.cx().const_uint(cast_to, 0)
291+
} else {
292+
bx.intcast(relative_discr, cast_to, false)
293+
};
294+
bx.add(
295+
relative_discr,
296+
bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64),
297+
)
298+
};
279299

280-
// NOTE(eddyb) this addition needs to be performed on the final
281-
// type, in case the niche itself can't represent all variant
282-
// indices (e.g. `u8` niche with more than `256` variants,
283-
// but enough uninhabited variants so that the remaining variants
284-
// fit in the niche).
285-
// In other words, `niche_variants.end - niche_variants.start`
286-
// is representable in the niche, but `niche_variants.end`
287-
// might not be, in extreme cases.
288-
let niche_discr = {
289-
let relative_discr = if relative_max == 0 {
290-
// HACK(eddyb) since we have only one niche, we know which
291-
// one it is, and we can avoid having a dynamic value here.
292-
bx.cx().const_uint(cast_to, 0)
300+
let is_niche = if relative_max == 0 {
301+
// Avoid calling `const_uint`, which wouldn't work for pointers.
302+
// Also use canonical == 0 instead of non-canonical u<= 0.
303+
// FIXME(eddyb) check the actual primitive type here.
304+
bx.icmp(IntPredicate::IntEQ, relative_discr, bx.cx().const_null(niche_llty))
293305
} else {
294-
bx.intcast(relative_discr, cast_to, false)
306+
let relative_max = bx.cx().const_uint(niche_llty, relative_max as u64);
307+
bx.icmp(IntPredicate::IntULE, relative_discr, relative_max)
295308
};
296-
bx.add(
297-
relative_discr,
298-
bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64),
299-
)
300-
};
301309

302-
bx.select(
303-
is_niche,
304-
niche_discr,
305-
bx.cx().const_uint(cast_to, dataful_variant.as_u32() as u64),
306-
)
310+
bx.select(
311+
is_niche,
312+
niche_discr,
313+
bx.cx().const_uint(cast_to, dataful_variant.as_u32() as u64),
314+
)
315+
}
307316
}
308317
}
309318
}

compiler/rustc_codegen_ssa/src/mir/rvalue.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -469,12 +469,12 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
469469
(bx, OperandRef { val: OperandValue::Immediate(llval), layout: operand.layout })
470470
}
471471

472-
mir::Rvalue::Discriminant(ref place) => {
472+
mir::Rvalue::Discriminant { ref place, relative } => {
473473
let discr_ty = rvalue.ty(self.mir, bx.tcx());
474474
let discr_ty = self.monomorphize(discr_ty);
475475
let discr = self
476476
.codegen_place(&mut bx, place.as_ref())
477-
.codegen_get_discr(&mut bx, discr_ty);
477+
.codegen_get_discr(&mut bx, discr_ty, relative);
478478
(
479479
bx,
480480
OperandRef {
@@ -751,7 +751,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
751751
mir::Rvalue::BinaryOp(..) |
752752
mir::Rvalue::CheckedBinaryOp(..) |
753753
mir::Rvalue::UnaryOp(..) |
754-
mir::Rvalue::Discriminant(..) |
754+
mir::Rvalue::Discriminant { .. } |
755755
mir::Rvalue::NullaryOp(..) |
756756
mir::Rvalue::ThreadLocalRef(_) |
757757
mir::Rvalue::Use(..) => // (*)

compiler/rustc_const_eval/src/const_eval/mod.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ fn const_to_valtree_inner<'tcx>(
109109
bug!("uninhabited types should have errored and never gotten converted to valtree")
110110
}
111111

112-
let variant = ecx.read_discriminant(&place.into()).unwrap().1;
112+
let variant = ecx.read_discriminant(&place.into(), false).unwrap().1;
113113

114114
branches(def.variant(variant).fields.len(), def.is_enum().then_some(variant))
115115
}
@@ -152,7 +152,7 @@ pub(crate) fn try_destructure_const<'tcx>(
152152
// index).
153153
ty::Adt(def, _) if def.variants().is_empty() => throw_ub!(Unreachable),
154154
ty::Adt(def, _) => {
155-
let variant = ecx.read_discriminant(&op)?.1;
155+
let variant = ecx.read_discriminant(&op, false)?.1;
156156
let down = ecx.operand_downcast(&op, variant)?;
157157
(def.variant(variant).fields.len(), Some(variant), down)
158158
}

compiler/rustc_const_eval/src/interpret/intrinsics.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
230230
}
231231
sym::discriminant_value => {
232232
let place = self.deref_operand(&args[0])?;
233-
let discr_val = self.read_discriminant(&place.into())?.0;
233+
let discr_val = self.read_discriminant(&place.into(), false)?.0;
234234
self.write_scalar(discr_val, dest)?;
235235
}
236236
sym::unchecked_shl

compiler/rustc_const_eval/src/interpret/operand.rs

+20-8
Original file line numberDiff line numberDiff line change
@@ -631,9 +631,12 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
631631

632632
/// Read discriminant, return the runtime value as well as the variant index.
633633
/// Can also legally be called on non-enums (e.g. through the discriminant_value intrinsic)!
634+
/// If `relative` is true, we instead calculate the *relative* discriminant
635+
/// (See the doc comment on `RValue::Discriminant`).
634636
pub fn read_discriminant(
635637
&self,
636638
op: &OpTy<'tcx, M::PointerTag>,
639+
relative: bool,
637640
) -> InterpResult<'tcx, (Scalar<M::PointerTag>, VariantIdx)> {
638641
trace!("read_discriminant_value {:#?}", op.layout);
639642
// Get type and layout of the discriminant.
@@ -722,7 +725,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
722725
// discriminant (encoded in niche/tag) and variant index are the same.
723726
let variants_start = niche_variants.start().as_u32();
724727
let variants_end = niche_variants.end().as_u32();
725-
let variant = match tag_val.try_to_int() {
728+
match tag_val.try_to_int() {
726729
Err(dbg_val) => {
727730
// So this is a pointer then, and casting to an int failed.
728731
// Can only happen during CTFE.
@@ -734,7 +737,10 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
734737
if !ptr_valid {
735738
throw_ub!(InvalidTag(dbg_val))
736739
}
737-
dataful_variant
740+
(
741+
Scalar::from_uint(dataful_variant.as_u32(), discr_layout.size),
742+
dataful_variant,
743+
)
738744
}
739745
Ok(tag_bits) => {
740746
let tag_bits = tag_bits.assert_bits(tag_layout.size);
@@ -748,7 +754,9 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
748754
.to_scalar()?
749755
.assert_bits(tag_val.layout.size);
750756
// Check if this is in the range that indicates an actual discriminant.
751-
if variant_index_relative <= u128::from(variants_end - variants_start) {
757+
let variant = if variant_index_relative
758+
<= u128::from(variants_end - variants_start)
759+
{
752760
let variant_index_relative = u32::try_from(variant_index_relative)
753761
.expect("we checked that this fits into a u32");
754762
// Then computing the absolute variant idx should not overflow any more.
@@ -766,13 +774,17 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
766774
VariantIdx::from_u32(variant_index)
767775
} else {
768776
dataful_variant
777+
};
778+
779+
// No need to cast, because the variant index directly serves as
780+
// discriminant and is encoded in the tag.
781+
if relative {
782+
(Scalar::from_uint(variant_index_relative, discr_layout.size), variant)
783+
} else {
784+
(Scalar::from_uint(variant.as_u32(), discr_layout.size), variant)
769785
}
770786
}
771-
};
772-
// Compute the size of the scalar we need to return.
773-
// No need to cast, because the variant index directly serves as discriminant and is
774-
// encoded in the tag.
775-
(Scalar::from_uint(variant.as_u32(), discr_layout.size), variant)
787+
}
776788
}
777789
})
778790
}

compiler/rustc_const_eval/src/interpret/step.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -297,9 +297,9 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
297297
self.cast(&src, cast_kind, cast_ty, &dest)?;
298298
}
299299

300-
Discriminant(place) => {
300+
Discriminant { place, relative } => {
301301
let op = self.eval_place_to_op(place, None)?;
302-
let discr_val = self.read_discriminant(&op)?.0;
302+
let discr_val = self.read_discriminant(&op, relative)?.0;
303303
self.write_scalar(discr_val, &dest)?;
304304
}
305305
}

0 commit comments

Comments
 (0)