Stop generating allocas & memcmp for simple short array equality #85828


Merged · 9 commits · Jul 9, 2021
34 changes: 34 additions & 0 deletions compiler/rustc_codegen_cranelift/src/intrinsics/mod.rs
@@ -1115,6 +1115,40 @@ pub(crate) fn codegen_intrinsic_call<'tcx>(
);
ret.write_cvalue(fx, CValue::by_val(res, ret.layout()));
};

raw_eq, <T>(v lhs_ref, v rhs_ref) {
fn type_by_size(size: Size) -> Option<Type> {
Type::int(size.bits().try_into().ok()?)
}

let size = fx.layout_of(T).layout.size;
let is_eq_value =
if size == Size::ZERO {
// No bytes means they're trivially equal
fx.bcx.ins().iconst(types::I8, 1)
} else if let Some(clty) = type_by_size(size) {
// Can't use `trusted` for these loads; they could be unaligned.
let mut flags = MemFlags::new();
flags.set_notrap();
let lhs_val = fx.bcx.ins().load(clty, flags, lhs_ref, 0);
let rhs_val = fx.bcx.ins().load(clty, flags, rhs_ref, 0);
let eq = fx.bcx.ins().icmp(IntCC::Equal, lhs_val, rhs_val);
fx.bcx.ins().bint(types::I8, eq)
} else {
// Just call `memcmp` (like slices do in core) when the
// size is too large or it's not a power-of-two.
let ptr_ty = pointer_ty(fx.tcx);
let signed_bytes = i64::try_from(size.bytes()).unwrap();
let bytes_val = fx.bcx.ins().iconst(ptr_ty, signed_bytes);
let params = vec![AbiParam::new(ptr_ty); 3];
let returns = vec![AbiParam::new(types::I32)];
let args = &[lhs_ref, rhs_ref, bytes_val];
let cmp = fx.lib_call("memcmp", params, returns, args)[0];
let eq = fx.bcx.ins().icmp_imm(IntCC::Equal, cmp, 0);
fx.bcx.ins().bint(types::I8, eq)
};
ret.write_cvalue(fx, CValue::by_val(is_eq_value, ret.layout()));
};
}

if let Some((_, dest)) = destination {
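To make the intent concrete: for a small power-of-two-sized array, the new lowering turns the comparison into one wide integer load per side plus a single compare, instead of an alloca and a `memcmp` call. A minimal sketch of what the integer path computes for an 8-byte array (illustrative only; the function name is mine, not from the PR):

```rust
// Compare two 8-byte arrays as a single u64 load + compare; this is
// morally what the Cranelift `load`/`icmp` sequence above produces.
fn bytes8_eq(a: &[u8; 8], b: &[u8; 8]) -> bool {
    u64::from_ne_bytes(*a) == u64::from_ne_bytes(*b)
}
```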
4 changes: 4 additions & 0 deletions compiler/rustc_codegen_cranelift/src/value_and_place.rs
@@ -453,6 +453,10 @@ impl<'tcx> CPlace<'tcx> {
ptr.store(fx, data, MemFlags::trusted());
ptr.load(fx, dst_ty, MemFlags::trusted())
}

// `CValue`s should never contain SSA-only types, so if you ended
// up here having seen an error like `B1 -> I8`, then before
// calling `write_cvalue` you need to add a `bint` instruction.
_ => unreachable!("write_cvalue_transmute: {:?} -> {:?}", src_ty, dst_ty),
};
//fx.bcx.set_val_label(data, cranelift_codegen::ir::ValueLabel::new(var.index()));
5 changes: 5 additions & 0 deletions compiler/rustc_codegen_llvm/src/context.rs
@@ -500,6 +500,7 @@ impl CodegenCx<'b, 'tcx> {
let t_i32 = self.type_i32();
let t_i64 = self.type_i64();
let t_i128 = self.type_i128();
let t_isize = self.type_isize();
let t_f32 = self.type_f32();
let t_f64 = self.type_f64();

@@ -712,6 +713,10 @@ impl CodegenCx<'b, 'tcx> {
ifn!("llvm.assume", fn(i1) -> void);
ifn!("llvm.prefetch", fn(i8p, t_i32, t_i32, t_i32) -> void);

// This isn't an "LLVM intrinsic", but LLVM's optimization passes
// recognize it like one, and `core::slice::cmp` already assumes it exists.
ifn!("memcmp", fn(i8p, i8p, t_isize) -> t_i32);
Review comment (Member):
The memcmp in slice::cmp has a FIXME for the return type, should that go here too?


// variadic intrinsics
ifn!("llvm.va_start", fn(i8p) -> void);
ifn!("llvm.va_end", fn(i8p) -> void);
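For reference, the declared type mirrors C's `int memcmp(const void *s1, const void *s2, size_t n)`. A hedged sketch of the equivalent extern declaration, shaped like the one `core::slice::cmp` relies on (parameter names are illustrative):

```rust
// Shape of the `memcmp` symbol assumed above: two byte pointers and a
// length, returning a signed ordering value (0 means equal).
extern "C" {
    fn memcmp(s1: *const u8, s2: *const u8, n: usize) -> i32;
}
```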
38 changes: 38 additions & 0 deletions compiler/rustc_codegen_llvm/src/intrinsic.rs
@@ -296,6 +296,44 @@ impl IntrinsicCallMethods<'tcx> for Builder<'a, 'll, 'tcx> {
}
}

sym::raw_eq => {
use abi::Abi::*;
let tp_ty = substs.type_at(0);
let layout = self.layout_of(tp_ty).layout;
let use_integer_compare = match layout.abi {
Scalar(_) | ScalarPair(_, _) => true,
Uninhabited | Vector { .. } => false,
Aggregate { .. } => {
// For rusty ABIs, small aggregates are actually passed
// as `RegKind::Integer` (see `FnAbi::adjust_for_abi`),
// so we re-use that same threshold here.
layout.size <= self.data_layout().pointer_size * 2
}
};

let a = args[0].immediate();
let b = args[1].immediate();
if layout.size.bytes() == 0 {
self.const_bool(true)
} else if use_integer_compare {
let integer_ty = self.type_ix(layout.size.bits());
let ptr_ty = self.type_ptr_to(integer_ty);
let a_ptr = self.bitcast(a, ptr_ty);
let a_val = self.load(a_ptr, layout.align.abi);
let b_ptr = self.bitcast(b, ptr_ty);
let b_val = self.load(b_ptr, layout.align.abi);
self.icmp(IntPredicate::IntEQ, a_val, b_val)
} else {
let i8p_ty = self.type_i8p();
let a_ptr = self.bitcast(a, i8p_ty);
let b_ptr = self.bitcast(b, i8p_ty);
let n = self.const_usize(layout.size.bytes());
let llfn = self.get_intrinsic("memcmp");
let cmp = self.call(llfn, &[a_ptr, b_ptr, n], None);
self.icmp(IntPredicate::IntEQ, cmp, self.const_i32(0))
}
}

_ if name_str.starts_with("simd_") => {
match generic_simd_intrinsic(self, name, callee_ty, args, ret_ty, llret_ty, span) {
Ok(llval) => llval,
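The dispatch above has three outcomes: zero-sized types are trivially equal, scalar ABIs and small aggregates get a single integer load-and-compare, and everything else falls back to `memcmp`. A simplified Rust model of that decision (a sketch that hard-codes the 16-byte threshold of a 64-bit target and ignores the `Uninhabited`/`Vector` cases):

```rust
// Simplified model of the raw_eq lowering choice (not the codegen API).
enum Lowering {
    ConstTrue,  // zero-sized: trivially equal
    IntCompare, // one wide integer load per side + `icmp eq`
    MemcmpCall, // fall back to the `memcmp` libcall
}

fn choose(size_bytes: u64, is_scalar_abi: bool) -> Lowering {
    match size_bytes {
        0 => Lowering::ConstTrue,
        _ if is_scalar_abi || size_bytes <= 16 => Lowering::IntCompare,
        _ => Lowering::MemcmpCall,
    }
}
```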
19 changes: 19 additions & 0 deletions compiler/rustc_mir/src/interpret/intrinsics.rs
@@ -472,6 +472,10 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
throw_ub_format!("`assume` intrinsic called with `false`");
}
}
sym::raw_eq => {
let result = self.raw_eq_intrinsic(&args[0], &args[1])?;
self.write_scalar(result, dest)?;
}
_ => return Ok(false),
}

@@ -559,4 +563,19 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {

self.memory.copy(src, align, dst, align, size, nonoverlapping)
}

pub(crate) fn raw_eq_intrinsic(
&mut self,
lhs: &OpTy<'tcx, <M as Machine<'mir, 'tcx>>::PointerTag>,
rhs: &OpTy<'tcx, <M as Machine<'mir, 'tcx>>::PointerTag>,
) -> InterpResult<'tcx, Scalar<M::PointerTag>> {
let layout = self.layout_of(lhs.layout.ty.builtin_deref(true).unwrap().ty)?;
assert!(!layout.is_unsized());

let lhs = self.read_scalar(lhs)?.check_init()?;
let rhs = self.read_scalar(rhs)?.check_init()?;
let lhs_bytes = self.memory.read_bytes(lhs, layout.size)?;
let rhs_bytes = self.memory.read_bytes(rhs, layout.size)?;
Ok(Scalar::from_bool(lhs_bytes == rhs_bytes))
}
}
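In other words, the interpreter compares the raw bytes of the two referents. A plain-Rust model of that behavior (a sketch; like the intrinsic itself, it assumes `T` contains no padding or uninitialized bytes):

```rust
use core::{mem, slice};

// Model of the interpreter's byte-wise comparison of `*lhs` and `*rhs`.
// SAFETY: caller must guarantee every byte of `T` is initialized.
unsafe fn raw_eq_model<T>(lhs: &T, rhs: &T) -> bool {
    let n = mem::size_of::<T>();
    let l = slice::from_raw_parts(lhs as *const T as *const u8, n);
    let r = slice::from_raw_parts(rhs as *const T as *const u8, n);
    l == r
}
```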
1 change: 1 addition & 0 deletions compiler/rustc_span/src/symbol.rs
@@ -933,6 +933,7 @@ symbols! {
quote,
range_inclusive_new,
raw_dylib,
raw_eq,
raw_identifiers,
raw_ref_op,
re_rebalance_coherence,
7 changes: 7 additions & 0 deletions compiler/rustc_typeck/src/check/intrinsic.rs
@@ -380,6 +380,13 @@ pub fn check_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>) {

sym::nontemporal_store => (1, vec![tcx.mk_mut_ptr(param(0)), param(0)], tcx.mk_unit()),

sym::raw_eq => {
let br = ty::BoundRegion { var: ty::BoundVar::from_u32(0), kind: ty::BrAnon(0) };
let param_ty =
tcx.mk_imm_ref(tcx.mk_region(ty::ReLateBound(ty::INNERMOST, br)), param(0));
(1, vec![param_ty; 2], tcx.types.bool)
}

other => {
tcx.sess.emit_err(UnrecognizedIntrinsicFunction { span: it.span, name: other });
return;
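The tuple built above (one type parameter, two late-bound-region references, a `bool` result) corresponds to a surface declaration along these lines (a sketch of how the intrinsic would be declared; requires the nightly `intrinsics` feature):

```rust
extern "rust-intrinsic" {
    // Whether the raw bytes of the two referents are equal.
    fn raw_eq<T>(a: &T, b: &T) -> bool;
}
```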
160 changes: 160 additions & 0 deletions library/core/src/array/equality.rs
@@ -0,0 +1,160 @@
#[stable(feature = "rust1", since = "1.0.0")]
impl<A, B, const N: usize> PartialEq<[B; N]> for [A; N]
where
A: PartialEq<B>,
{
#[inline]
fn eq(&self, other: &[B; N]) -> bool {
SpecArrayEq::spec_eq(self, other)
}
#[inline]
fn ne(&self, other: &[B; N]) -> bool {
SpecArrayEq::spec_ne(self, other)
}
}

#[stable(feature = "rust1", since = "1.0.0")]
impl<A, B, const N: usize> PartialEq<[B]> for [A; N]
where
A: PartialEq<B>,
{
#[inline]
fn eq(&self, other: &[B]) -> bool {
self[..] == other[..]
}
#[inline]
fn ne(&self, other: &[B]) -> bool {
self[..] != other[..]
}
}

#[stable(feature = "rust1", since = "1.0.0")]
impl<A, B, const N: usize> PartialEq<[A; N]> for [B]
where
B: PartialEq<A>,
{
#[inline]
fn eq(&self, other: &[A; N]) -> bool {
self[..] == other[..]
}
#[inline]
fn ne(&self, other: &[A; N]) -> bool {
self[..] != other[..]
}
}

#[stable(feature = "rust1", since = "1.0.0")]
impl<A, B, const N: usize> PartialEq<&[B]> for [A; N]
where
A: PartialEq<B>,
{
#[inline]
fn eq(&self, other: &&[B]) -> bool {
self[..] == other[..]
}
#[inline]
fn ne(&self, other: &&[B]) -> bool {
self[..] != other[..]
}
}

#[stable(feature = "rust1", since = "1.0.0")]
impl<A, B, const N: usize> PartialEq<[A; N]> for &[B]
where
B: PartialEq<A>,
{
#[inline]
fn eq(&self, other: &[A; N]) -> bool {
self[..] == other[..]
}
#[inline]
fn ne(&self, other: &[A; N]) -> bool {
self[..] != other[..]
}
}

#[stable(feature = "rust1", since = "1.0.0")]
impl<A, B, const N: usize> PartialEq<&mut [B]> for [A; N]
where
A: PartialEq<B>,
{
#[inline]
fn eq(&self, other: &&mut [B]) -> bool {
self[..] == other[..]
}
#[inline]
fn ne(&self, other: &&mut [B]) -> bool {
self[..] != other[..]
}
}

#[stable(feature = "rust1", since = "1.0.0")]
impl<A, B, const N: usize> PartialEq<[A; N]> for &mut [B]
where
B: PartialEq<A>,
{
#[inline]
fn eq(&self, other: &[A; N]) -> bool {
self[..] == other[..]
}
#[inline]
fn ne(&self, other: &[A; N]) -> bool {
self[..] != other[..]
}
}

// NOTE: some less important impls are omitted to reduce code bloat
// __impl_slice_eq2! { [A; $N], &'b [B; $N] }
// __impl_slice_eq2! { [A; $N], &'b mut [B; $N] }

#[stable(feature = "rust1", since = "1.0.0")]
impl<T: Eq, const N: usize> Eq for [T; N] {}

trait SpecArrayEq<Other, const N: usize>: Sized {
fn spec_eq(a: &[Self; N], b: &[Other; N]) -> bool;
fn spec_ne(a: &[Self; N], b: &[Other; N]) -> bool;
}

impl<T: PartialEq<Other>, Other, const N: usize> SpecArrayEq<Other, N> for T {
default fn spec_eq(a: &[Self; N], b: &[Other; N]) -> bool {
a[..] == b[..]
}
default fn spec_ne(a: &[Self; N], b: &[Other; N]) -> bool {
a[..] != b[..]
}
}

impl<T: PartialEq<U> + IsRawEqComparable<U>, U, const N: usize> SpecArrayEq<U, N> for T {
#[cfg(bootstrap)]
fn spec_eq(a: &[T; N], b: &[U; N]) -> bool {
a[..] == b[..]
}
#[cfg(not(bootstrap))]
fn spec_eq(a: &[T; N], b: &[U; N]) -> bool {
// SAFETY: This is why `IsRawEqComparable` is an `unsafe trait`.
unsafe {
let b = &*b.as_ptr().cast::<[T; N]>();
crate::intrinsics::raw_eq(a, b)
}
}
fn spec_ne(a: &[T; N], b: &[U; N]) -> bool {
!Self::spec_eq(a, b)
}
}

/// `U` exists here mostly because `min_specialization` didn't let me
/// repeat the `T` type parameter in the above specialization, so instead
/// the `T == U` constraint comes from the impls on this.
/// # Safety
/// - Neither `Self` nor `U` has any padding.
/// - `Self` and `U` have the same layout.
/// - `Self: PartialEq<U>` is byte-wise (this means no floats, among other things)
#[rustc_specialization_trait]
unsafe trait IsRawEqComparable<U> {}

macro_rules! is_raw_comparable {
($($t:ty),+) => {$(
unsafe impl IsRawEqComparable<$t> for $t {}
)+};
}
is_raw_comparable!(bool, char, u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128, isize);
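Floats are conspicuously absent from this list because their `PartialEq` is not byte-wise, which would violate the trait's safety contract. A quick illustration:

```rust
fn main() {
    assert!(0.0_f32 == -0.0_f32); // equal values...
    assert_ne!(0.0_f32.to_bits(), (-0.0_f32).to_bits()); // ...different bytes
    let nan = f32::NAN;
    assert!(nan != nan); // unequal values...
    assert_eq!(nan.to_bits(), nan.to_bits()); // ...same bytes
}
```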