
Don't alloca just to look at a discriminant #138391


Merged (2 commits, Mar 14, 2025)
10 changes: 7 additions & 3 deletions compiler/rustc_codegen_ssa/src/mir/analyze.rs
@@ -205,7 +205,12 @@ impl<'a, 'b, 'tcx, Bx: BuilderMethods<'b, 'tcx>> Visitor<'tcx> for LocalAnalyzer
| PlaceContext::MutatingUse(MutatingUseContext::Retag) => {}

PlaceContext::NonMutatingUse(
NonMutatingUseContext::Copy | NonMutatingUseContext::Move,
NonMutatingUseContext::Copy
| NonMutatingUseContext::Move
// Inspect covers things like `PtrMetadata` and `Discriminant`,
// which we can treat similarly to a `Copy` use for the purpose of
// deciding whether we can use SSA variables.
| NonMutatingUseContext::Inspect,
) => match &mut self.locals[local] {
LocalKind::ZST => {}
LocalKind::Memory => {}
@@ -229,8 +234,7 @@ impl<'a, 'b, 'tcx, Bx: BuilderMethods<'b, 'tcx>> Visitor<'tcx> for LocalAnalyzer
| MutatingUseContext::Projection,
)
| PlaceContext::NonMutatingUse(
NonMutatingUseContext::Inspect
| NonMutatingUseContext::SharedBorrow
NonMutatingUseContext::SharedBorrow
| NonMutatingUseContext::FakeBorrow
| NonMutatingUseContext::RawBorrow
| NonMutatingUseContext::Projection,
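For context, a minimal sketch (illustrative, not part of the PR) of code that exercises the new `Inspect` arm: matching on a by-value enum produces a MIR `discriminant` read of the local, which is an `Inspect`-context use, so after this change the local can stay in SSA form instead of being spilled.

```rust
// Illustrative only: the MIR for `tag_of` contains a
// `discriminant(...)` rvalue reading `e`, an Inspect-context use.
// After this PR, that read alone no longer forces an alloca for `e`.
pub enum Either {
    L(u64),
    R(u64),
}

pub fn tag_of(e: Either) -> u8 {
    match e {
        Either::L(_) => 0,
        Either::R(_) => 1,
    }
}
```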
10 changes: 1 addition & 9 deletions compiler/rustc_codegen_ssa/src/mir/intrinsic.rs
@@ -62,7 +62,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
let callee_ty = instance.ty(bx.tcx(), bx.typing_env());

let ty::FnDef(def_id, fn_args) = *callee_ty.kind() else {
bug!("expected fn item type, found {}", callee_ty);
span_bug!(span, "expected fn item type, found {}", callee_ty);
};

let sig = callee_ty.fn_sig(bx.tcx());
@@ -325,14 +325,6 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
}
}

sym::discriminant_value => {
if ret_ty.is_integral() {
args[0].deref(bx.cx()).codegen_get_discr(bx, ret_ty)
} else {
span_bug!(span, "Invalid discriminant type for `{:?}`", arg_tys[0])
}
}
Comment on lines -328 to -334 (Member, author):

This is dead code, since LowerIntrinsics turns all the calls to this into MIR.

(And by deleting it, there was only the one caller of codegen_get_discr left to update.)
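A quick illustration (a sketch, not from the PR) of why that arm was unreachable: any call to the `discriminant_value` intrinsic, e.g. through `std::mem::discriminant`, is rewritten by the `LowerIntrinsics` MIR pass into a `Discriminant` rvalue before codegen runs.

```rust
use std::mem;

#[allow(dead_code)]
enum E {
    A,
    B(u32),
}

fn main() {
    let e = E::B(7);
    // `mem::discriminant` is built on the `discriminant_value`
    // intrinsic; `LowerIntrinsics` replaces the call with a MIR
    // `Discriminant` rvalue, so codegen never sees the intrinsic.
    assert_ne!(mem::discriminant(&e), mem::discriminant(&E::A));
}
```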


// This requires that atomic intrinsics follow a specific naming pattern:
// "atomic_<operation>[_<ordering>]"
name if let Some(atomic) = name_str.strip_prefix("atomic_") => {
148 changes: 146 additions & 2 deletions compiler/rustc_codegen_ssa/src/mir/operand.rs
@@ -3,16 +3,17 @@ use std::fmt;
use arrayvec::ArrayVec;
use either::Either;
use rustc_abi as abi;
use rustc_abi::{Align, BackendRepr, Size};
use rustc_abi::{Align, BackendRepr, FIRST_VARIANT, Primitive, Size, TagEncoding, Variants};
use rustc_middle::mir::interpret::{Pointer, Scalar, alloc_range};
use rustc_middle::mir::{self, ConstValue};
use rustc_middle::ty::Ty;
use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
use rustc_middle::{bug, span_bug};
use tracing::debug;
use tracing::{debug, instrument};

use super::place::{PlaceRef, PlaceValue};
use super::{FunctionCx, LocalRef};
use crate::common::IntPredicate;
use crate::traits::*;
use crate::{MemFlags, size_of_val};

@@ -415,6 +416,149 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> {

OperandRef { val, layout: field }
}

/// Obtain the actual discriminant of a value.
#[instrument(level = "trace", skip(fx, bx))]
pub fn codegen_get_discr<Bx: BuilderMethods<'a, 'tcx, Value = V>>(
self,
fx: &mut FunctionCx<'a, 'tcx, Bx>,
bx: &mut Bx,
cast_to: Ty<'tcx>,
) -> V {
let dl = &bx.tcx().data_layout;
let cast_to_layout = bx.cx().layout_of(cast_to);
let cast_to = bx.cx().immediate_backend_type(cast_to_layout);

// We check uninhabitedness separately because a type like
// `enum Foo { Bar(i32, !) }` is still reported as `Variants::Single`,
// *not* as `Variants::Empty`.
if self.layout.is_uninhabited() {
return bx.cx().const_poison(cast_to);
}

let (tag_scalar, tag_encoding, tag_field) = match self.layout.variants {
Variants::Empty => unreachable!("we already handled uninhabited types"),
Comment (Member):
question: Can we just move the poison return here, instead of doing the if? Or is there a case where is_uninhabited is true, but .variants is not Empty?

Reply (Member, author):
I was wondering that myself, TBH, but wasn't sure so left it as it was.

Exploring a bit more now (and thinking of the whole "is (i32, !) a ZST?" issue), it looks like yes it's possible: https://rust.godbolt.org/z/sjq1hEs8r

```rust
pub enum Foo<T> {
    Hmm(i32, T)
}
pub type Bar = Foo<std::convert::Infallible>;
```

That Bar has

```text
uninhabited: true,
variants: Single {
    index: 0,
},
```

so it's uninhabited but still single-variant.

Reply (Contributor):

```rust
// A variant is absent if it's uninhabited and only has ZST fields.
// Present uninhabited variants only require space for their fields,
// but *not* an encoding of the discriminant (e.g., a tag value).
// See issue #49298 for more details on the need to leave space
// for non-ZST uninhabited data (mostly partial initialization).
```

Variants::Single { index } => {
let discr_val =
if let Some(discr) = self.layout.ty.discriminant_for_variant(bx.tcx(), index) {
discr.val
} else {
// This arm is for types which are neither enums nor coroutines,
// and thus for which the only possible "variant" should be the first one.
assert_eq!(index, FIRST_VARIANT);
// There's thus no actual discriminant to return, so we return
// what it would have been if this was a single-variant enum.
0
};
return bx.cx().const_uint_big(cast_to, discr_val);
}
Variants::Multiple { tag, ref tag_encoding, tag_field, .. } => {
(tag, tag_encoding, tag_field)
}
};

// Read the tag/niche-encoded discriminant from memory.
let tag_op = match self.val {
OperandValue::ZeroSized => bug!(),
OperandValue::Immediate(_) | OperandValue::Pair(_, _) => {
self.extract_field(fx, bx, tag_field)
}
OperandValue::Ref(place) => {
let tag = place.with_type(self.layout).project_field(bx, tag_field);
bx.load_operand(tag)
}
};
Comment on lines +461 to +470 (Member, author):

This codegen_get_discr method is moved nearly unchanged from PlaceRef; the new part is the extract_field case when the value is already immediate(s). If it's a Ref, though, we do the same project_field as before.
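As an illustration (the exact layout is an assumption, but representative): a type like `Option<u32>` has a ScalarPair ABI, so passed by value it reaches codegen as `OperandValue::Pair(tag, payload)`, and the tag read becomes a plain `extract_field` with no memory traffic.

```rust
// Illustrative: Option<u32> is (assumed) a ScalarPair, so a by-value
// `x` arrives as OperandValue::Pair(tag, payload); reading its
// discriminant just extracts the tag half of the pair, no alloca.
pub fn is_some(x: Option<u32>) -> bool {
    matches!(x, Some(_))
}
```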

let tag_imm = tag_op.immediate();

// Decode the discriminant (specifically if it's niche-encoded).
match *tag_encoding {
TagEncoding::Direct => {
let signed = match tag_scalar.primitive() {
// We use `i1` for bytes that are always `0` or `1`,
// e.g., `#[repr(i8)] enum E { A, B }`, but we can't
// let LLVM interpret the `i1` as signed, because
// then `i1 1` (i.e., `E::B`) is effectively `i8 -1`.
Primitive::Int(_, signed) => !tag_scalar.is_bool() && signed,
_ => false,
};
bx.intcast(tag_imm, cast_to, signed)
}
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start } => {
// Cast to an integer so we don't have to treat a pointer as a
// special case.
let (tag, tag_llty) = match tag_scalar.primitive() {
// FIXME(erikdesjardins): handle non-default addrspace ptr sizes
Primitive::Pointer(_) => {
let t = bx.type_from_integer(dl.ptr_sized_integer());
let tag = bx.ptrtoint(tag_imm, t);
(tag, t)
}
_ => (tag_imm, bx.cx().immediate_backend_type(tag_op.layout)),
};

let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();

// We have a subrange `niche_start..=niche_end` inside `range`.
// If the value of the tag is inside this subrange, it's a
// "niche value", an increment of the discriminant. Otherwise it
// indicates the untagged variant.
// A general algorithm to extract the discriminant from the tag
// is:
// relative_tag = tag - niche_start
// is_niche = relative_tag <= (ule) relative_max
// discr = if is_niche {
// cast(relative_tag) + niche_variants.start()
// } else {
// untagged_variant
// }
// However, we will likely be able to emit simpler code.
let (is_niche, tagged_discr, delta) = if relative_max == 0 {
// Best case scenario: only one tagged variant. This will
// likely become just a comparison and a jump.
// The algorithm is:
// is_niche = tag == niche_start
// discr = if is_niche {
// niche_start
// } else {
// untagged_variant
// }
let niche_start = bx.cx().const_uint_big(tag_llty, niche_start);
let is_niche = bx.icmp(IntPredicate::IntEQ, tag, niche_start);
let tagged_discr =
bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64);
(is_niche, tagged_discr, 0)
} else {
// The special cases don't apply, so we'll have to go with
// the general algorithm.
let relative_discr = bx.sub(tag, bx.cx().const_uint_big(tag_llty, niche_start));
let cast_tag = bx.intcast(relative_discr, cast_to, false);
let is_niche = bx.icmp(
IntPredicate::IntULE,
relative_discr,
bx.cx().const_uint(tag_llty, relative_max as u64),
);
(is_niche, cast_tag, niche_variants.start().as_u32() as u128)
};

let tagged_discr = if delta == 0 {
tagged_discr
} else {
bx.add(tagged_discr, bx.cx().const_uint_big(cast_to, delta))
};

let discr = bx.select(
is_niche,
tagged_discr,
bx.cx().const_uint(cast_to, untagged_variant.as_u32() as u64),
);

// In principle we could insert assumes on the possible range of `discr`, but
// currently in LLVM this seems to be a pessimization.

discr
}
}
}
}

impl<'a, 'tcx, V: CodegenObject> OperandValue<V> {
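To make the niche decoding concrete, here is a plain-Rust sketch of the general algorithm from the comment in `codegen_get_discr` (parameter names mirror the codegen variables; the `Option<&u32>` numbers in the usage are illustrative assumptions):

```rust
// Sketch of the niche-decoding math mirrored from codegen_get_discr.
fn decode_niche(
    tag: u128,                 // raw tag value read from the operand
    niche_start: u128,         // first tag value encoding a niche variant
    relative_max: u32,         // niche_variants.end() - niche_variants.start()
    niche_variants_start: u32, // variant index of the first niche variant
    untagged_variant: u32,     // variant index of the untagged variant
) -> u32 {
    let relative_tag = tag.wrapping_sub(niche_start);
    let is_niche = relative_tag <= relative_max as u128;
    if is_niche {
        niche_variants_start + relative_tag as u32
    } else {
        untagged_variant
    }
}

fn main() {
    // E.g. for Option<&u32> (illustrative numbers): the tag is the
    // pointer bits, niche_start = 0, relative_max = 0, None is niche
    // variant 0, and Some is the untagged variant 1.
    assert_eq!(decode_niche(0, 0, 0, 0, 1), 0); // null => None
    assert_eq!(decode_niche(0x1000, 0, 0, 0, 1), 1); // non-null => Some
}
```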
124 changes: 0 additions & 124 deletions compiler/rustc_codegen_ssa/src/mir/place.rs
@@ -1,4 +1,3 @@
use rustc_abi::Primitive::{Int, Pointer};
use rustc_abi::{Align, BackendRepr, FieldsShape, Size, TagEncoding, VariantIdx, Variants};
use rustc_middle::mir::PlaceTy;
use rustc_middle::mir::interpret::Scalar;
@@ -233,129 +232,6 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
val.with_type(field)
}

/// Obtain the actual discriminant of a value.
#[instrument(level = "trace", skip(bx))]
pub fn codegen_get_discr<Bx: BuilderMethods<'a, 'tcx, Value = V>>(
self,
bx: &mut Bx,
cast_to: Ty<'tcx>,
) -> V {
let dl = &bx.tcx().data_layout;
let cast_to_layout = bx.cx().layout_of(cast_to);
let cast_to = bx.cx().immediate_backend_type(cast_to_layout);
if self.layout.is_uninhabited() {
return bx.cx().const_poison(cast_to);
}
let (tag_scalar, tag_encoding, tag_field) = match self.layout.variants {
Variants::Empty => unreachable!("we already handled uninhabited types"),
Variants::Single { index } => {
let discr_val = self
.layout
.ty
.discriminant_for_variant(bx.cx().tcx(), index)
.map_or(index.as_u32() as u128, |discr| discr.val);
return bx.cx().const_uint_big(cast_to, discr_val);
}
Variants::Multiple { tag, ref tag_encoding, tag_field, .. } => {
(tag, tag_encoding, tag_field)
}
};

// Read the tag/niche-encoded discriminant from memory.
let tag = self.project_field(bx, tag_field);
let tag_op = bx.load_operand(tag);
let tag_imm = tag_op.immediate();

// Decode the discriminant (specifically if it's niche-encoded).
match *tag_encoding {
TagEncoding::Direct => {
let signed = match tag_scalar.primitive() {
// We use `i1` for bytes that are always `0` or `1`,
// e.g., `#[repr(i8)] enum E { A, B }`, but we can't
// let LLVM interpret the `i1` as signed, because
// then `i1 1` (i.e., `E::B`) is effectively `i8 -1`.
Int(_, signed) => !tag_scalar.is_bool() && signed,
_ => false,
};
bx.intcast(tag_imm, cast_to, signed)
}
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start } => {
// Cast to an integer so we don't have to treat a pointer as a
// special case.
let (tag, tag_llty) = match tag_scalar.primitive() {
// FIXME(erikdesjardins): handle non-default addrspace ptr sizes
Pointer(_) => {
let t = bx.type_from_integer(dl.ptr_sized_integer());
let tag = bx.ptrtoint(tag_imm, t);
(tag, t)
}
_ => (tag_imm, bx.cx().immediate_backend_type(tag_op.layout)),
};

let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();

// We have a subrange `niche_start..=niche_end` inside `range`.
// If the value of the tag is inside this subrange, it's a
// "niche value", an increment of the discriminant. Otherwise it
// indicates the untagged variant.
// A general algorithm to extract the discriminant from the tag
// is:
// relative_tag = tag - niche_start
// is_niche = relative_tag <= (ule) relative_max
// discr = if is_niche {
// cast(relative_tag) + niche_variants.start()
// } else {
// untagged_variant
// }
// However, we will likely be able to emit simpler code.
let (is_niche, tagged_discr, delta) = if relative_max == 0 {
// Best case scenario: only one tagged variant. This will
// likely become just a comparison and a jump.
// The algorithm is:
// is_niche = tag == niche_start
// discr = if is_niche {
// niche_start
// } else {
// untagged_variant
// }
let niche_start = bx.cx().const_uint_big(tag_llty, niche_start);
let is_niche = bx.icmp(IntPredicate::IntEQ, tag, niche_start);
let tagged_discr =
bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64);
(is_niche, tagged_discr, 0)
} else {
// The special cases don't apply, so we'll have to go with
// the general algorithm.
let relative_discr = bx.sub(tag, bx.cx().const_uint_big(tag_llty, niche_start));
let cast_tag = bx.intcast(relative_discr, cast_to, false);
let is_niche = bx.icmp(
IntPredicate::IntULE,
relative_discr,
bx.cx().const_uint(tag_llty, relative_max as u64),
);
(is_niche, cast_tag, niche_variants.start().as_u32() as u128)
};

let tagged_discr = if delta == 0 {
tagged_discr
} else {
bx.add(tagged_discr, bx.cx().const_uint_big(cast_to, delta))
};

let discr = bx.select(
is_niche,
tagged_discr,
bx.cx().const_uint(cast_to, untagged_variant.as_u32() as u64),
);

// In principle we could insert assumes on the possible range of `discr`, but
// currently in LLVM this seems to be a pessimization.

discr
}
}
}

/// Sets the discriminant for a new value of the given case of the given
/// representation.
pub fn codegen_set_discr<Bx: BuilderMethods<'a, 'tcx, Value = V>>(
3 changes: 2 additions & 1 deletion compiler/rustc_codegen_ssa/src/mir/rvalue.rs
@@ -706,7 +706,8 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
mir::Rvalue::Discriminant(ref place) => {
let discr_ty = rvalue.ty(self.mir, bx.tcx());
let discr_ty = self.monomorphize(discr_ty);
let discr = self.codegen_place(bx, place.as_ref()).codegen_get_discr(bx, discr_ty);
let operand = self.codegen_consume(bx, place.as_ref());
let discr = operand.codegen_get_discr(self, bx, discr_ty);
OperandRef {
val: OperandValue::Immediate(discr),
layout: self.cx.layout_of(discr_ty),
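This call-site change is what the PR title describes: the place is consumed as an `OperandRef` (possibly plain SSA immediates) rather than forced into a memory place just so its tag can be loaded. A hedged sketch of the user-visible effect (exact codegen depends on target and layout):

```rust
// Illustrative: before this PR, lowering Rvalue::Discriminant went
// through codegen_place, which could materialize `s` in a stack slot
// just to load its tag. With codegen_consume, a register-sized enum
// like this can stay in SSA form throughout.
pub enum Small {
    A(u32),
    B(u32),
}

pub fn discr(s: Small) -> u32 {
    match s {
        Small::A(_) => 0,
        Small::B(_) => 1,
    }
}
```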