Skip to content

Enable GVN for AggregateKind::RawPtr #125041

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 87 additions & 12 deletions compiler/rustc_mir_transform/src/gvn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@
//! that contain `AllocId`s.

use rustc_const_eval::const_eval::DummyMachine;
use rustc_const_eval::interpret::{intern_const_alloc_for_constprop, MemoryKind};
use rustc_const_eval::interpret::{ImmTy, InterpCx, OpTy, Projectable, Scalar};
use rustc_const_eval::interpret::{intern_const_alloc_for_constprop, MemPlaceMeta, MemoryKind};
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable, Scalar};
use rustc_data_structures::fx::FxIndexSet;
use rustc_data_structures::graph::dominators::Dominators;
use rustc_hir::def::DefKind;
Expand All @@ -99,7 +99,7 @@ use rustc_middle::ty::layout::LayoutOf;
use rustc_middle::ty::{self, Ty, TyCtxt};
use rustc_span::def_id::DefId;
use rustc_span::DUMMY_SP;
use rustc_target::abi::{self, Abi, Size, VariantIdx, FIRST_VARIANT};
use rustc_target::abi::{self, Abi, FieldIdx, Size, VariantIdx, FIRST_VARIANT};
use smallvec::SmallVec;
use std::borrow::Cow;

Expand Down Expand Up @@ -177,6 +177,12 @@ enum AggregateTy<'tcx> {
Array,
Tuple,
Def(DefId, ty::GenericArgsRef<'tcx>),
RawPtr {
/// Needed for cast propagation.
data_pointer_ty: Ty<'tcx>,
/// The data pointer can be anything thin, so doesn't determine the output.
output_pointer_ty: Ty<'tcx>,
},
}

#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -385,11 +391,22 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
AggregateTy::Def(def_id, args) => {
self.tcx.type_of(def_id).instantiate(self.tcx, args)
}
AggregateTy::RawPtr { output_pointer_ty, .. } => output_pointer_ty,
};
let variant = if ty.is_enum() { Some(variant) } else { None };
let ty = self.ecx.layout_of(ty).ok()?;
if ty.is_zst() {
ImmTy::uninit(ty).into()
} else if matches!(kind, AggregateTy::RawPtr { .. }) {
// Pointers don't have fields, so don't `project_field` them.
let data = self.ecx.read_pointer(fields[0]).ok()?;
let meta = if fields[1].layout.is_zst() {
MemPlaceMeta::None
} else {
MemPlaceMeta::Meta(self.ecx.read_scalar(fields[1]).ok()?)
};
let ptr_imm = Immediate::new_pointer_with_meta(data, meta, &self.ecx);
ImmTy::from_immediate(ptr_imm, ty).into()
} else if matches!(ty.abi, Abi::Scalar(..) | Abi::ScalarPair(..)) {
let dest = self.ecx.allocate(ty, MemoryKind::Stack).ok()?;
let variant_dest = if let Some(variant) = variant {
Expand Down Expand Up @@ -862,10 +879,10 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
rvalue: &mut Rvalue<'tcx>,
location: Location,
) -> Option<VnIndex> {
let Rvalue::Aggregate(box ref kind, ref mut fields) = *rvalue else { bug!() };
let Rvalue::Aggregate(box ref kind, ref mut field_ops) = *rvalue else { bug!() };

let tcx = self.tcx;
if fields.is_empty() {
if field_ops.is_empty() {
let is_zst = match *kind {
AggregateKind::Array(..)
| AggregateKind::Tuple
Expand All @@ -884,13 +901,13 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
}
}

let (ty, variant_index) = match *kind {
let (mut ty, variant_index) = match *kind {
AggregateKind::Array(..) => {
assert!(!fields.is_empty());
assert!(!field_ops.is_empty());
(AggregateTy::Array, FIRST_VARIANT)
}
AggregateKind::Tuple => {
assert!(!fields.is_empty());
assert!(!field_ops.is_empty());
(AggregateTy::Tuple, FIRST_VARIANT)
}
AggregateKind::Closure(did, args)
Expand All @@ -901,15 +918,49 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
}
// Do not track unions.
AggregateKind::Adt(_, _, _, _, Some(_)) => return None,
// FIXME: Do the extra work to GVN `from_raw_parts`
AggregateKind::RawPtr(..) => return None,
AggregateKind::RawPtr(pointee_ty, mtbl) => {
assert_eq!(field_ops.len(), 2);
let data_pointer_ty = field_ops[FieldIdx::ZERO].ty(self.local_decls, self.tcx);
let output_pointer_ty = Ty::new_ptr(self.tcx, pointee_ty, mtbl);
(AggregateTy::RawPtr { data_pointer_ty, output_pointer_ty }, FIRST_VARIANT)
}
};

let fields: Option<Vec<_>> = fields
let fields: Option<Vec<_>> = field_ops
.iter_mut()
.map(|op| self.simplify_operand(op, location).or_else(|| self.new_opaque()))
.collect();
let fields = fields?;
let mut fields = fields?;

if let AggregateTy::RawPtr { data_pointer_ty, output_pointer_ty } = &mut ty {
let mut was_updated = false;

// Any thin pointer of matching mutability is fine as the data pointer.
while let Value::Cast {
kind: CastKind::PtrToPtr,
value: cast_value,
from: cast_from,
to: _,
} = self.get(fields[0])
&& let ty::RawPtr(from_pointee_ty, from_mtbl) = cast_from.kind()
&& let ty::RawPtr(_, output_mtbl) = output_pointer_ty.kind()
&& from_mtbl == output_mtbl
&& from_pointee_ty.is_sized(self.tcx, self.param_env)
{
fields[0] = *cast_value;
*data_pointer_ty = *cast_from;
was_updated = true;
}

if was_updated {
if let Some(const_) = self.try_as_constant(fields[0]) {
field_ops[FieldIdx::ZERO] = Operand::Constant(Box::new(const_));
} else if let Some(local) = self.try_as_local(fields[0], location) {
field_ops[FieldIdx::ZERO] = Operand::Copy(Place::from(local));
self.reused_locals.insert(local);
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I usually check both constants and locals, to prefer keeping a constant when one is available. Or is it too unlikely here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess I might as well for consistency. It's probably quite uncommon, though, because this is the data pointer, and those usually aren't constants. But I suppose for ZSTs it's possible?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

...ok, done.

}

if let AggregateTy::Array = ty
&& fields.len() > 4
Expand Down Expand Up @@ -941,6 +992,9 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
(UnOp::Not, Value::BinaryOp(BinOp::Ne, lhs, rhs)) => {
Value::BinaryOp(BinOp::Eq, *lhs, *rhs)
}
(UnOp::PtrMetadata, Value::Aggregate(AggregateTy::RawPtr { .. }, _, fields)) => {
return Some(fields[1]);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add such a line to Rvalue::Len too?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit more complex for Len, because Len has a deref as well.

I think my preference be would to work towards removing Len, and not add a bunch more cases which we'd just delete later. Not having this for Len isn't a regression, after all, so I think it's fine without it.

_ => return None,
};

Expand Down Expand Up @@ -1092,6 +1146,23 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
return self.new_opaque();
}

let mut was_updated = false;

// If that cast just casts away the metadata again,
if let PtrToPtr = kind
&& let Value::Aggregate(AggregateTy::RawPtr { data_pointer_ty, .. }, _, fields) =
self.get(value)
&& let ty::RawPtr(to_pointee, _) = to.kind()
&& to_pointee.is_sized(self.tcx, self.param_env)
{
from = *data_pointer_ty;
value = fields[0];
was_updated = true;
if *data_pointer_ty == to {
return Some(fields[0]);
}
}

if let PtrToPtr | PointerCoercion(MutToConstPointer) = kind
&& let Value::Cast { kind: inner_kind, value: inner_value, from: inner_from, to: _ } =
*self.get(value)
Expand All @@ -1100,9 +1171,13 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
from = inner_from;
value = inner_value;
*kind = PtrToPtr;
was_updated = true;
if inner_from == to {
return Some(inner_value);
}
}

if was_updated {
if let Some(const_) = self.try_as_constant(value) {
*operand = Operand::Constant(Box::new(const_));
} else if let Some(local) = self.try_as_local(value, location) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
- // MIR for `casts_before_aggregate_raw_ptr` before GVN
+ // MIR for `casts_before_aggregate_raw_ptr` after GVN

fn casts_before_aggregate_raw_ptr(_1: *const u32) -> *const [u8] {
debug x => _1;
let mut _0: *const [u8];
let _2: *const [u8; 4];
let mut _3: *const u32;
let mut _5: *const [u8; 4];
let mut _7: *const u8;
let mut _8: *const ();
scope 1 {
debug x => _2;
let _4: *const u8;
scope 2 {
debug x => _4;
let _6: *const ();
scope 3 {
debug x => _6;
}
}
}

bb0: {
- StorageLive(_2);
+ nop;
StorageLive(_3);
_3 = _1;
- _2 = move _3 as *const [u8; 4] (PtrToPtr);
+ _2 = _1 as *const [u8; 4] (PtrToPtr);
StorageDead(_3);
- StorageLive(_4);
+ nop;
StorageLive(_5);
_5 = _2;
- _4 = move _5 as *const u8 (PtrToPtr);
+ _4 = _1 as *const u8 (PtrToPtr);
StorageDead(_5);
- StorageLive(_6);
+ nop;
StorageLive(_7);
_7 = _4;
- _6 = move _7 as *const () (PtrToPtr);
+ _6 = _1 as *const () (PtrToPtr);
StorageDead(_7);
StorageLive(_8);
_8 = _6;
- _0 = *const [u8] from (move _8, const 4_usize);
+ _0 = *const [u8] from (_1, const 4_usize);
StorageDead(_8);
- StorageDead(_6);
- StorageDead(_4);
- StorageDead(_2);
+ nop;
+ nop;
+ nop;
return;
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
- // MIR for `casts_before_aggregate_raw_ptr` before GVN
+ // MIR for `casts_before_aggregate_raw_ptr` after GVN

fn casts_before_aggregate_raw_ptr(_1: *const u32) -> *const [u8] {
debug x => _1;
let mut _0: *const [u8];
let _2: *const [u8; 4];
let mut _3: *const u32;
let mut _5: *const [u8; 4];
let mut _7: *const u8;
let mut _8: *const ();
scope 1 {
debug x => _2;
let _4: *const u8;
scope 2 {
debug x => _4;
let _6: *const ();
scope 3 {
debug x => _6;
}
}
}

bb0: {
- StorageLive(_2);
+ nop;
StorageLive(_3);
_3 = _1;
- _2 = move _3 as *const [u8; 4] (PtrToPtr);
+ _2 = _1 as *const [u8; 4] (PtrToPtr);
StorageDead(_3);
- StorageLive(_4);
+ nop;
StorageLive(_5);
_5 = _2;
- _4 = move _5 as *const u8 (PtrToPtr);
+ _4 = _1 as *const u8 (PtrToPtr);
StorageDead(_5);
- StorageLive(_6);
+ nop;
StorageLive(_7);
_7 = _4;
- _6 = move _7 as *const () (PtrToPtr);
+ _6 = _1 as *const () (PtrToPtr);
StorageDead(_7);
StorageLive(_8);
_8 = _6;
- _0 = *const [u8] from (move _8, const 4_usize);
+ _0 = *const [u8] from (_1, const 4_usize);
StorageDead(_8);
- StorageDead(_6);
- StorageDead(_4);
- StorageDead(_2);
+ nop;
+ nop;
+ nop;
return;
}
}

32 changes: 32 additions & 0 deletions tests/mir-opt/gvn.meta_of_ref_to_slice.GVN.panic-abort.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
- // MIR for `meta_of_ref_to_slice` before GVN
+ // MIR for `meta_of_ref_to_slice` after GVN

fn meta_of_ref_to_slice(_1: *const i32) -> usize {
debug x => _1;
let mut _0: usize;
let _2: *const [i32];
let mut _3: *const i32;
let mut _4: *const [i32];
scope 1 {
debug ptr => _2;
}

bb0: {
- StorageLive(_2);
+ nop;
StorageLive(_3);
_3 = _1;
- _2 = *const [i32] from (move _3, const 1_usize);
+ _2 = *const [i32] from (_1, const 1_usize);
StorageDead(_3);
StorageLive(_4);
_4 = _2;
- _0 = PtrMetadata(move _4);
+ _0 = const 1_usize;
StorageDead(_4);
- StorageDead(_2);
+ nop;
return;
}
}

32 changes: 32 additions & 0 deletions tests/mir-opt/gvn.meta_of_ref_to_slice.GVN.panic-unwind.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
- // MIR for `meta_of_ref_to_slice` before GVN
+ // MIR for `meta_of_ref_to_slice` after GVN

fn meta_of_ref_to_slice(_1: *const i32) -> usize {
debug x => _1;
let mut _0: usize;
let _2: *const [i32];
let mut _3: *const i32;
let mut _4: *const [i32];
scope 1 {
debug ptr => _2;
}

bb0: {
- StorageLive(_2);
+ nop;
StorageLive(_3);
_3 = _1;
- _2 = *const [i32] from (move _3, const 1_usize);
+ _2 = *const [i32] from (_1, const 1_usize);
StorageDead(_3);
StorageLive(_4);
_4 = _2;
- _0 = PtrMetadata(move _4);
+ _0 = const 1_usize;
StorageDead(_4);
- StorageDead(_2);
+ nop;
return;
}
}

Loading
Loading