Skip to content

AddNicheCases MirPass #95652

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions compiler/rustc_borrowck/src/invalidation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,10 +291,10 @@ impl<'cx, 'tcx> InvalidationGenerator<'cx, 'tcx> {
self.consume_operand(location, operand)
}

Rvalue::Len(place) | Rvalue::Discriminant(place) => {
Rvalue::Len(place) | Rvalue::Discriminant { place, .. } => {
let af = match *rvalue {
Rvalue::Len(..) => Some(ArtificialField::ArrayLength),
Rvalue::Discriminant(..) => None,
Rvalue::Discriminant { .. } => None,
_ => unreachable!(),
};
self.access_place(
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_borrowck/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1280,10 +1280,10 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> {
self.consume_operand(location, (operand, span), flow_state)
}

Rvalue::Len(place) | Rvalue::Discriminant(place) => {
Rvalue::Len(place) | Rvalue::Discriminant { place, .. } => {
let af = match *rvalue {
Rvalue::Len(..) => Some(ArtificialField::ArrayLength),
Rvalue::Discriminant(..) => None,
Rvalue::Discriminant { .. } => None,
_ => unreachable!(),
};
self.access_place(
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_borrowck/src/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2260,7 +2260,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
Rvalue::AddressOf(..)
| Rvalue::ThreadLocalRef(..)
| Rvalue::Len(..)
| Rvalue::Discriminant(..) => {}
| Rvalue::Discriminant { .. } => {}
}
}

Expand All @@ -2281,7 +2281,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
| Rvalue::CheckedBinaryOp(..)
| Rvalue::NullaryOp(..)
| Rvalue::UnaryOp(..)
| Rvalue::Discriminant(..) => None,
| Rvalue::Discriminant { .. } => None,

Rvalue::Aggregate(aggregate, _) => match **aggregate {
AggregateKind::Adt(_, _, _, user_ty, _) => user_ty,
Expand Down
11 changes: 8 additions & 3 deletions compiler/rustc_codegen_cranelift/src/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,7 @@ fn codegen_stmt<'tcx>(
fx,
operand,
fx.layout_of(operand.layout().ty.discriminant_ty(fx.tcx)),
false,
)
.load_scalar(fx);

Expand Down Expand Up @@ -684,11 +685,15 @@ fn codegen_stmt<'tcx>(
let operand = codegen_operand(fx, operand);
operand.unsize_value(fx, lval);
}
Rvalue::Discriminant(place) => {
Rvalue::Discriminant { place, relative } => {
let place = codegen_place(fx, place);
let value = place.to_cvalue(fx);
let discr =
crate::discriminant::codegen_get_discriminant(fx, value, dest_layout);
let discr = crate::discriminant::codegen_get_discriminant(
fx,
value,
dest_layout,
relative,
);
lval.write_cvalue(fx, discr);
}
Rvalue::Repeat(ref operand, times) => {
Expand Down
69 changes: 40 additions & 29 deletions compiler/rustc_codegen_cranelift/src/discriminant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,13 @@ pub(crate) fn codegen_set_discriminant<'tcx>(
}
}

/// If `relative` is true, we instead calculate the *relative* discriminant (see
/// `RValue::Discriminant`).
pub(crate) fn codegen_get_discriminant<'tcx>(
fx: &mut FunctionCx<'_, '_, 'tcx>,
value: CValue<'tcx>,
dest_layout: TyAndLayout<'tcx>,
relative: bool,
) -> CValue<'tcx> {
let layout = value.layout();

Expand Down Expand Up @@ -131,38 +134,46 @@ pub(crate) fn codegen_get_discriminant<'tcx>(
// FIXME handle niche_start > i64::MAX
fx.bcx.ins().iadd_imm(tag, -i64::try_from(niche_start).unwrap())
};
let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
let is_niche = {
codegen_icmp_imm(
fx,
IntCC::UnsignedLessThanOrEqual,
relative_discr,
i128::from(relative_max),
)
};

// NOTE(eddyb) this addition needs to be performed on the final
// type, in case the niche itself can't represent all variant
// indices (e.g. `u8` niche with more than `256` variants,
// but enough uninhabited variants so that the remaining variants
// fit in the niche).
// In other words, `niche_variants.end - niche_variants.start`
// is representable in the niche, but `niche_variants.end`
// might not be, in extreme cases.
let niche_discr = {
let relative_discr = if relative_max == 0 {
// HACK(eddyb) since we have only one niche, we know which
// one it is, and we can avoid having a dynamic value here.
fx.bcx.ins().iconst(cast_to, 0)
} else {
clif_intcast(fx, relative_discr, cast_to, false)
if relative {
CValue::by_val(clif_intcast(fx, relative_discr, cast_to, false), dest_layout)
} else {
let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
let is_niche = {
codegen_icmp_imm(
fx,
IntCC::UnsignedLessThanOrEqual,
relative_discr,
i128::from(relative_max),
)
};
fx.bcx.ins().iadd_imm(relative_discr, i64::from(niche_variants.start().as_u32()))
};

let dataful_variant = fx.bcx.ins().iconst(cast_to, i64::from(dataful_variant.as_u32()));
let discr = fx.bcx.ins().select(is_niche, niche_discr, dataful_variant);
CValue::by_val(discr, dest_layout)
// NOTE(eddyb) this addition needs to be performed on the final
// type, in case the niche itself can't represent all variant
// indices (e.g. `u8` niche with more than `256` variants,
// but enough uninhabited variants so that the remaining variants
// fit in the niche).
// In other words, `niche_variants.end - niche_variants.start`
// is representable in the niche, but `niche_variants.end`
// might not be, in extreme cases.
let niche_discr = {
let relative_discr = if relative_max == 0 {
// HACK(eddyb) since we have only one niche, we know which
// one it is, and we can avoid having a dynamic value here.
fx.bcx.ins().iconst(cast_to, 0)
} else {
clif_intcast(fx, relative_discr, cast_to, false)
};
fx.bcx
.ins()
.iadd_imm(relative_discr, i64::from(niche_variants.start().as_u32()))
};

let dataful_variant =
fx.bcx.ins().iconst(cast_to, i64::from(dataful_variant.as_u32()));
let discr = fx.bcx.ins().select(is_niche, niche_discr, dataful_variant);
CValue::by_val(discr, dest_layout)
}
}
}
}
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_ssa/src/mir/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {

sym::discriminant_value => {
if ret_ty.is_integral() {
args[0].deref(bx.cx()).codegen_get_discr(bx, ret_ty)
args[0].deref(bx.cx()).codegen_get_discr(bx, ret_ty, false)
} else {
span_bug!(span, "Invalid discriminant type for `{:?}`", arg_tys[0])
}
Expand Down
75 changes: 42 additions & 33 deletions compiler/rustc_codegen_ssa/src/mir/place.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,10 +204,14 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
}

/// Obtain the actual discriminant of a value.
///
/// If `relative` is true, instead calculate the *relative* discriminant (see
/// `RValue::Discriminant`).
pub fn codegen_get_discr<Bx: BuilderMethods<'a, 'tcx, Value = V>>(
self,
bx: &mut Bx,
cast_to: Ty<'tcx>,
relative: bool,
) -> V {
let cast_to = bx.cx().immediate_backend_type(bx.cx().layout_of(cast_to));
if self.layout.abi.is_uninhabited() {
Expand Down Expand Up @@ -266,44 +270,49 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
} else {
bx.sub(tag, bx.cx().const_uint_big(niche_llty, niche_start))
};
let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
let is_niche = if relative_max == 0 {
// Avoid calling `const_uint`, which wouldn't work for pointers.
// Also use canonical == 0 instead of non-canonical u<= 0.
// FIXME(eddyb) check the actual primitive type here.
bx.icmp(IntPredicate::IntEQ, relative_discr, bx.cx().const_null(niche_llty))
let relative_max = niche_variants.size_hint().1.unwrap() - 1;

if relative {
bx.intcast(relative_discr, cast_to, false)
} else {
let relative_max = bx.cx().const_uint(niche_llty, relative_max as u64);
bx.icmp(IntPredicate::IntULE, relative_discr, relative_max)
};
// NOTE(eddyb) this addition needs to be performed on the final
// type, in case the niche itself can't represent all variant
// indices (e.g. `u8` niche with more than `256` variants,
// but enough uninhabited variants so that the remaining variants
// fit in the niche).
// In other words, `niche_variants.end - niche_variants.start`
// is representable in the niche, but `niche_variants.end`
// might not be, in extreme cases.
let niche_discr = {
let relative_discr = if relative_max == 0 {
// HACK(eddyb) since we have only one niche, we know which
// one it is, and we can avoid having a dynamic value here.
bx.cx().const_uint(cast_to, 0)
} else {
bx.intcast(relative_discr, cast_to, false)
};
bx.add(
relative_discr,
bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64),
)
};

// NOTE(eddyb) this addition needs to be performed on the final
// type, in case the niche itself can't represent all variant
// indices (e.g. `u8` niche with more than `256` variants,
// but enough uninhabited variants so that the remaining variants
// fit in the niche).
// In other words, `niche_variants.end - niche_variants.start`
// is representable in the niche, but `niche_variants.end`
// might not be, in extreme cases.
let niche_discr = {
let relative_discr = if relative_max == 0 {
// HACK(eddyb) since we have only one niche, we know which
// one it is, and we can avoid having a dynamic value here.
bx.cx().const_uint(cast_to, 0)
let is_niche = if relative_max == 0 {
// Avoid calling `const_uint`, which wouldn't work for pointers.
// Also use canonical == 0 instead of non-canonical u<= 0.
// FIXME(eddyb) check the actual primitive type here.
bx.icmp(IntPredicate::IntEQ, relative_discr, bx.cx().const_null(niche_llty))
} else {
bx.intcast(relative_discr, cast_to, false)
let relative_max = bx.cx().const_uint(niche_llty, relative_max as u64);
bx.icmp(IntPredicate::IntULE, relative_discr, relative_max)
};
bx.add(
relative_discr,
bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64),
)
};

bx.select(
is_niche,
niche_discr,
bx.cx().const_uint(cast_to, dataful_variant.as_u32() as u64),
)
bx.select(
is_niche,
niche_discr,
bx.cx().const_uint(cast_to, dataful_variant.as_u32() as u64),
)
}
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions compiler/rustc_codegen_ssa/src/mir/rvalue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -469,12 +469,12 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
(bx, OperandRef { val: OperandValue::Immediate(llval), layout: operand.layout })
}

mir::Rvalue::Discriminant(ref place) => {
mir::Rvalue::Discriminant { ref place, relative } => {
let discr_ty = rvalue.ty(self.mir, bx.tcx());
let discr_ty = self.monomorphize(discr_ty);
let discr = self
.codegen_place(&mut bx, place.as_ref())
.codegen_get_discr(&mut bx, discr_ty);
.codegen_get_discr(&mut bx, discr_ty, relative);
(
bx,
OperandRef {
Expand Down Expand Up @@ -751,7 +751,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
mir::Rvalue::BinaryOp(..) |
mir::Rvalue::CheckedBinaryOp(..) |
mir::Rvalue::UnaryOp(..) |
mir::Rvalue::Discriminant(..) |
mir::Rvalue::Discriminant { .. } |
mir::Rvalue::NullaryOp(..) |
mir::Rvalue::ThreadLocalRef(_) |
mir::Rvalue::Use(..) => // (*)
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_const_eval/src/const_eval/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ fn const_to_valtree_inner<'tcx>(
bug!("uninhabited types should have errored and never gotten converted to valtree")
}

let variant = ecx.read_discriminant(&place.into()).unwrap().1;
let variant = ecx.read_discriminant(&place.into(), false).unwrap().1;

branches(def.variant(variant).fields.len(), def.is_enum().then_some(variant))
}
Expand Down Expand Up @@ -152,7 +152,7 @@ pub(crate) fn try_destructure_const<'tcx>(
// index).
ty::Adt(def, _) if def.variants().is_empty() => throw_ub!(Unreachable),
ty::Adt(def, _) => {
let variant = ecx.read_discriminant(&op)?.1;
let variant = ecx.read_discriminant(&op, false)?.1;
let down = ecx.operand_downcast(&op, variant)?;
(def.variant(variant).fields.len(), Some(variant), down)
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_const_eval/src/interpret/intrinsics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
}
sym::discriminant_value => {
let place = self.deref_operand(&args[0])?;
let discr_val = self.read_discriminant(&place.into())?.0;
let discr_val = self.read_discriminant(&place.into(), false)?.0;
self.write_scalar(discr_val, dest)?;
}
sym::unchecked_shl
Expand Down
28 changes: 20 additions & 8 deletions compiler/rustc_const_eval/src/interpret/operand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -631,9 +631,12 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {

/// Read discriminant, return the runtime value as well as the variant index.
/// Can also legally be called on non-enums (e.g. through the discriminant_value intrinsic)!
/// If `relative` is true, we instead calculate the *relative* discriminant
/// (See the doc comment on `RValue::Discriminant`).
pub fn read_discriminant(
&self,
op: &OpTy<'tcx, M::PointerTag>,
relative: bool,
) -> InterpResult<'tcx, (Scalar<M::PointerTag>, VariantIdx)> {
trace!("read_discriminant_value {:#?}", op.layout);
// Get type and layout of the discriminant.
Expand Down Expand Up @@ -722,7 +725,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
// discriminant (encoded in niche/tag) and variant index are the same.
let variants_start = niche_variants.start().as_u32();
let variants_end = niche_variants.end().as_u32();
let variant = match tag_val.try_to_int() {
match tag_val.try_to_int() {
Err(dbg_val) => {
// So this is a pointer then, and casting to an int failed.
// Can only happen during CTFE.
Expand All @@ -734,7 +737,10 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
if !ptr_valid {
throw_ub!(InvalidTag(dbg_val))
}
dataful_variant
(
Scalar::from_uint(dataful_variant.as_u32(), discr_layout.size),
dataful_variant,
)
}
Ok(tag_bits) => {
let tag_bits = tag_bits.assert_bits(tag_layout.size);
Expand All @@ -748,7 +754,9 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
.to_scalar()?
.assert_bits(tag_val.layout.size);
// Check if this is in the range that indicates an actual discriminant.
if variant_index_relative <= u128::from(variants_end - variants_start) {
let variant = if variant_index_relative
<= u128::from(variants_end - variants_start)
{
let variant_index_relative = u32::try_from(variant_index_relative)
.expect("we checked that this fits into a u32");
// Then computing the absolute variant idx should not overflow any more.
Expand All @@ -766,13 +774,17 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
VariantIdx::from_u32(variant_index)
} else {
dataful_variant
};

// No need to cast, because the variant index directly serves as
// discriminant and is encoded in the tag.
if relative {
(Scalar::from_uint(variant_index_relative, discr_layout.size), variant)
} else {
(Scalar::from_uint(variant.as_u32(), discr_layout.size), variant)
}
}
};
// Compute the size of the scalar we need to return.
// No need to cast, because the variant index directly serves as discriminant and is
// encoded in the tag.
(Scalar::from_uint(variant.as_u32(), discr_layout.size), variant)
}
}
})
}
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_const_eval/src/interpret/step.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,9 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
self.cast(&src, cast_kind, cast_ty, &dest)?;
}

Discriminant(place) => {
Discriminant { place, relative } => {
let op = self.eval_place_to_op(place, None)?;
let discr_val = self.read_discriminant(&op)?.0;
let discr_val = self.read_discriminant(&op, relative)?.0;
self.write_scalar(discr_val, &dest)?;
}
}
Expand Down
Loading