diff --git a/compiler/rustc_index/src/bit_set.rs b/compiler/rustc_index/src/bit_set.rs
index 02e5feb6c5f21..261df23584f53 100644
--- a/compiler/rustc_index/src/bit_set.rs
+++ b/compiler/rustc_index/src/bit_set.rs
@@ -324,6 +324,15 @@ impl<T: Idx> DenseBitSet<T> {
         // out-of-domain bits, so we need to clear them.
         self.clear_excess_bits();
     }
+
+    /// Sets `self = self & (a | b)` without allocating a temporary for `a | b`.
+    ///
+    /// Returns `true` if `self` changed.
+    pub fn intersect_with_union(&mut self, a: &DenseBitSet<T>, b: &DenseBitSet<T>) -> bool {
+        assert_eq!(self.domain_size, a.domain_size);
+        assert_eq!(self.domain_size, b.domain_size);
+        bitwise3(&mut self.words, &a.words, &b.words, |s, a, b| s & (a | b))
+    }
 }

 // dense REL dense
@@ -1084,6 +1093,23 @@ where
     changed != 0
 }

+#[inline]
+fn bitwise3<Op>(out_vec: &mut [Word], in_vec1: &[Word], in_vec2: &[Word], op: Op) -> bool
+where
+    Op: Fn(Word, Word, Word) -> Word,
+{
+    assert_eq!(out_vec.len(), in_vec1.len());
+    assert_eq!(out_vec.len(), in_vec2.len());
+    let mut changed = 0;
+    for ((out_elem, in_elem1), in_elem2) in iter::zip(iter::zip(out_vec, in_vec1), in_vec2) {
+        let old_val = *out_elem;
+        let new_val = op(old_val, *in_elem1, *in_elem2);
+        *out_elem = new_val;
+        changed |= old_val ^ new_val;
+    }
+    changed != 0
+}
+
 /// Does this bitwise operation change `out_vec`?
 #[inline]
 fn bitwise_changes<Op>(out_vec: &[Word], in_vec: &[Word], op: Op) -> bool
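A quick way to sanity-check the semantics of the new `intersect_with_union` outside the compiler: the word-level operation is `out = out & (a | b)` with change detection. A standalone sketch (not part of the patch), with plain `u64` words standing in for `Word`:

```rust
// Standalone sketch of the word-level operation behind `intersect_with_union`:
// out = out & (a | b), reporting whether any bit changed.
fn intersect_with_union_words(out: &mut [u64], a: &[u64], b: &[u64]) -> bool {
    assert_eq!(out.len(), a.len());
    assert_eq!(out.len(), b.len());
    let mut changed = 0;
    for ((o, &x), &y) in out.iter_mut().zip(a).zip(b) {
        let old = *o;
        *o = old & (x | y);
        changed |= old ^ *o;
    }
    changed != 0
}

fn main() {
    let mut out = vec![0b1111u64];
    assert!(intersect_with_union_words(&mut out, &[0b0011], &[0b0100]));
    assert_eq!(out[0], 0b0111); // bit 3 is in neither input, so it is cleared
}
```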
diff --git a/compiler/rustc_index/src/interval.rs b/compiler/rustc_index/src/interval.rs
index dda5253e7c547..b4be0534deb1d 100644
--- a/compiler/rustc_index/src/interval.rs
+++ b/compiler/rustc_index/src/interval.rs
@@ -1,6 +1,7 @@
 use std::iter::Step;
 use std::marker::PhantomData;
-use std::ops::{Bound, Range, RangeBounds};
+use std::ops::{Bound, RangeBounds};
+use std::range::RangeInclusive;

 use smallvec::SmallVec;

@@ -59,11 +60,14 @@ impl<I: Idx> IntervalSet<I> {
     }

     /// Iterates through intervals stored in the set, in order.
-    pub fn iter_intervals(&self) -> impl Iterator<Item = Range<I>>
+    pub fn iter_intervals(&self) -> impl Iterator<Item = RangeInclusive<I>>
     where
         I: Step,
     {
-        self.map.iter().map(|&(start, end)| I::new(start as usize)..I::new(end as usize + 1))
+        self.map.iter().map(|&(start, end)| RangeInclusive {
+            start: I::new(start as usize),
+            last: I::new(end as usize),
+        })
     }

     /// Returns true if we increased the number of elements present.
@@ -164,6 +168,35 @@ impl<I: Idx> IntervalSet<I> {
         );
     }

+    /// Specialized version of `insert` for when we know that the inserted range is
+    /// *after* any range already contained in the set.
+    pub fn append_range(&mut self, range: impl RangeBounds<I> + Clone) {
+        let start = inclusive_start(range.clone());
+        let Some(end) = inclusive_end(self.domain, range) else {
+            // empty range
+            return;
+        };
+        if start > end {
+            return;
+        }
+
+        if let Some((_, last_end)) = self.map.last_mut() {
+            assert!(*last_end < start);
+            if start == *last_end + 1 {
+                *last_end = end;
+            } else {
+                self.map.push((start, end));
+            }
+        } else {
+            self.map.push((start, end));
+        }
+
+        debug_assert!(
+            self.check_invariants(),
+            "wrong intervals after append {start:?}..={end:?} to {self:?}"
+        );
+    }
+
     pub fn contains(&self, needle: I) -> bool {
         let needle = needle.index() as u32;
         let Some(last) = self.map.partition_point(|r| r.0 <= needle).checked_sub(1) else {
@@ -174,17 +207,38 @@ impl<I: Idx> IntervalSet<I> {
         needle <= *prev_end
     }

+    /// Returns whether any point in `range` is contained in the set.
+    pub fn intersects_range(&self, range: impl RangeBounds<I> + Clone) -> bool {
+        let start = inclusive_start(range.clone());
+        let Some(end) = inclusive_end(self.domain, range) else {
+            // empty range
+            return false;
+        };
+        if start > end {
+            return false;
+        }
+
+        // Find the last interval whose start is <= end.
+        let Some(last) = self.map.partition_point(|r| r.0 <= end).checked_sub(1) else {
+            // All ranges in the map start after the new range's end.
+            return false;
+        };
+        let (_, prev_end) = &self.map[last];
+        start <= *prev_end
+    }
+
     pub fn superset(&self, other: &IntervalSet<I>) -> bool
     where
         I: Step,
     {
         let mut sup_iter = self.iter_intervals();
         let mut current = None;
-        let contains = |sup: Range<I>, sub: Range<I>, current: &mut Option<Range<I>>| {
-            if sup.end < sub.start {
-                // if `sup.end == sub.start`, the next sup doesn't contain `sub.start`
+        let contains = |sup: RangeInclusive<I>,
+                        sub: RangeInclusive<I>,
+                        current: &mut Option<RangeInclusive<I>>| {
+            if sup.last < sub.start {
                 None // continue to the next sup
-            } else if sup.end >= sub.end && sup.start <= sub.start {
+            } else if sup.last >= sub.last && sup.start <= sub.start {
                 *current = Some(sup); // save the current sup
                 Some(true)
             } else {
@@ -194,8 +248,8 @@ impl<I: Idx> IntervalSet<I> {
         other.iter_intervals().all(|sub| {
             current
                 .take()
-                .and_then(|sup| contains(sup, sub.clone(), &mut current))
-                .or_else(|| sup_iter.find_map(|sup| contains(sup, sub.clone(), &mut current)))
+                .and_then(|sup| contains(sup, sub, &mut current))
+                .or_else(|| sup_iter.find_map(|sup| contains(sup, sub, &mut current)))
                 .unwrap_or(false)
         })
     }
@@ -212,11 +266,11 @@ impl<I: Idx> IntervalSet<I> {
         let mut other_current = other_iter.next()?;

         loop {
-            if self_current.end <= other_current.start {
+            if self_current.last < other_current.start {
                 self_current = self_iter.next()?;
                 continue;
             }
-            if other_current.end <= self_current.start {
+            if other_current.last < self_current.start {
                 other_current = other_iter.next()?;
                 continue;
             }
@@ -340,6 +394,12 @@ impl<R: Idx, C: Idx> SparseIntervalMatrix<R, C> {
         self.rows.get(row)
     }

+    pub fn clear_row(&mut self, row: R) {
+        if let Some(row) = self.rows.get_mut(row) {
+            row.clear();
+        }
+    }
+
     fn ensure_row(&mut self, row: R) -> &mut IntervalSet<C> {
         self.rows.ensure_contains_elem(row, || IntervalSet::new(self.column_size))
     }
@@ -363,6 +423,16 @@ impl<R: Idx, C: Idx> SparseIntervalMatrix<R, C> {
         write_row.union(read_row)
     }

+    pub fn disjoint_rows(&self, a: R, b: R) -> bool
+    where
+        C: Step,
+    {
+        let (Some(a), Some(b)) = (self.rows.get(a), self.rows.get(b)) else {
+            return true;
+        };
+        a.disjoint(b)
+    }
+
     pub fn insert_all_into_row(&mut self, row: R) {
         self.ensure_row(row).insert_all();
     }
@@ -379,6 +449,10 @@ impl<R: Idx, C: Idx> SparseIntervalMatrix<R, C> {
         self.ensure_row(row).append(point)
     }

+    pub fn append_range(&mut self, row: R, range: impl RangeBounds<C> + Clone) {
+        self.ensure_row(row).append_range(range)
+    }
+
     pub fn contains(&self, row: R, point: C) -> bool {
         self.row(row).is_some_and(|r| r.contains(point))
     }
diff --git a/compiler/rustc_index/src/interval/tests.rs b/compiler/rustc_index/src/interval/tests.rs
index 375af60f66207..cf3222e6c6572 100644
--- a/compiler/rustc_index/src/interval/tests.rs
+++ b/compiler/rustc_index/src/interval/tests.rs
@@ -5,7 +5,7 @@ fn insert_collapses() {
     let mut set = IntervalSet::<usize>::new(10000);
     set.insert_range(9831..=9837);
     set.insert_range(43..=9830);
-    assert_eq!(set.iter_intervals().collect::<Vec<_>>(), [43..9838]);
+    assert_eq!(set.iter_intervals().collect::<Vec<_>>(), [(43..=9837).into()]);
 }

 #[test]
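For readers unfamiliar with the `(start, end)` pairs backing `IntervalSet`, here is a minimal standalone model (not the real type) of `append_range`'s coalescing rule; the `assert!` mirrors the append-only precondition in the real code:

```rust
// Standalone model of `append_range` over an inclusive (start, end) interval
// list: a range beginning exactly one past the last interval extends it in
// place; anything further away starts a new interval.
fn append_range(map: &mut Vec<(u32, u32)>, start: u32, end: u32) {
    if start > end {
        return; // empty range
    }
    if let Some((_, last_end)) = map.last_mut() {
        assert!(*last_end < start); // append-only precondition
        if start == *last_end + 1 {
            *last_end = end;
            return;
        }
    }
    map.push((start, end));
}

fn main() {
    let mut map = vec![];
    append_range(&mut map, 2, 4);
    append_range(&mut map, 5, 9); // adjacent: coalesces into 2..=9
    append_range(&mut map, 20, 20); // gap: pushes a new interval
    assert_eq!(map, [(2, 9), (20, 20)]);
}
```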
diff --git a/compiler/rustc_middle/src/mir/syntax.rs b/compiler/rustc_middle/src/mir/syntax.rs
index 3d320d4cf8383..afe35de3b165e 100644
--- a/compiler/rustc_middle/src/mir/syntax.rs
+++ b/compiler/rustc_middle/src/mir/syntax.rs
@@ -124,7 +124,6 @@ pub enum RuntimePhase {
     /// disallowed:
     /// * [`TerminatorKind::Yield`]
     /// * [`TerminatorKind::CoroutineDrop`]
-    /// * [`Rvalue::Aggregate`] for any `AggregateKind` except `Array`
     /// * [`Rvalue::CopyForDeref`]
     /// * [`PlaceElem::OpaqueCast`]
     /// * [`LocalInfo::DerefTemp`](super::LocalInfo::DerefTemp)
@@ -1442,9 +1441,6 @@ pub enum Rvalue<'tcx> {
     /// This is needed because dataflow analysis needs to distinguish
     /// `dest = Foo { x: ..., y: ... }` from `dest.x = ...; dest.y = ...;` in the case that `Foo`
     /// has a destructor.
-    ///
-    /// Disallowed after deaggregation for all aggregate kinds except `Array` and `Coroutine`. After
-    /// coroutine lowering, `Coroutine` aggregate kinds are disallowed too.
     Aggregate(Box<AggregateKind<'tcx>>, IndexVec<FieldIdx, Operand<'tcx>>),

     /// A CopyForDeref is equivalent to a read from a place at the
diff --git a/compiler/rustc_mir_dataflow/src/framework/direction.rs b/compiler/rustc_mir_dataflow/src/framework/direction.rs
index b15b5c07ce382..dbfbe461e4686 100644
--- a/compiler/rustc_mir_dataflow/src/framework/direction.rs
+++ b/compiler/rustc_mir_dataflow/src/framework/direction.rs
@@ -214,7 +214,7 @@ impl Direction for Backward {
     ) where
         A: Analysis<'tcx>,
     {
-        vis.visit_block_end(state);
+        vis.visit_block_end(state, block);

         let loc = Location { block, statement_index: block_data.statements.len() };
         let term = block_data.terminator();
@@ -231,7 +231,7 @@ impl Direction for Backward {
             vis.visit_after_primary_statement_effect(analysis, state, stmt, loc);
         }

-        vis.visit_block_start(state);
+        vis.visit_block_start(state, block);
     }
 }

@@ -394,7 +394,7 @@ impl Direction for Forward {
     ) where
         A: Analysis<'tcx>,
     {
-        vis.visit_block_start(state);
+        vis.visit_block_start(state, block);

         for (statement_index, stmt) in block_data.statements.iter().enumerate() {
             let loc = Location { block, statement_index };
@@ -411,6 +411,6 @@ impl Direction for Forward {
         analysis.apply_primary_terminator_effect(state, term, loc);
         vis.visit_after_primary_terminator_effect(analysis, state, term, loc);

-        vis.visit_block_end(state);
+        vis.visit_block_end(state, block);
     }
 }
diff --git a/compiler/rustc_mir_dataflow/src/framework/graphviz.rs b/compiler/rustc_mir_dataflow/src/framework/graphviz.rs
index 6c0f2e8d73058..28320c29ce2dd 100644
--- a/compiler/rustc_mir_dataflow/src/framework/graphviz.rs
+++ b/compiler/rustc_mir_dataflow/src/framework/graphviz.rs
@@ -660,13 +660,13 @@ where
     A: Analysis<'tcx>,
     A::Domain: DebugWithContext<A>,
 {
-    fn visit_block_start(&mut self, state: &A::Domain) {
+    fn visit_block_start(&mut self, state: &A::Domain, _block: BasicBlock) {
         if A::Direction::IS_FORWARD {
             self.prev_state.clone_from(state);
         }
     }

-    fn visit_block_end(&mut self, state: &A::Domain) {
+    fn visit_block_end(&mut self, state: &A::Domain, _block: BasicBlock) {
         if A::Direction::IS_BACKWARD {
             self.prev_state.clone_from(state);
         }
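The point of threading `block` through `visit_block_start`/`visit_block_end` is that visitors can index per-block storage directly instead of tracking the current block themselves. A toy standalone version of the pattern (assumed shapes, not the real trait), mirroring how `KillPointsVisitor` below snapshots `live_on_entry`:

```rust
// Toy version of what the new `block` parameter enables: per-block snapshots
// without extra bookkeeping in the visitor.
trait BlockVisitor {
    fn visit_block_start(&mut self, state: &u64, block: usize);
}

struct EntryStates {
    live_on_entry: Vec<u64>,
}

impl BlockVisitor for EntryStates {
    fn visit_block_start(&mut self, state: &u64, block: usize) {
        self.live_on_entry[block] = *state;
    }
}

fn main() {
    let mut v = EntryStates { live_on_entry: vec![0; 2] };
    v.visit_block_start(&0b101, 1);
    assert_eq!(v.live_on_entry, [0, 0b101]);
}
```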
diff --git a/compiler/rustc_mir_dataflow/src/framework/visitor.rs b/compiler/rustc_mir_dataflow/src/framework/visitor.rs
index 46940c6ab62fc..f5693bcffd891 100644
--- a/compiler/rustc_mir_dataflow/src/framework/visitor.rs
+++ b/compiler/rustc_mir_dataflow/src/framework/visitor.rs
@@ -46,7 +46,7 @@ pub trait ResultsVisitor<'tcx, A>
 where
     A: Analysis<'tcx>,
 {
-    fn visit_block_start(&mut self, _state: &A::Domain) {}
+    fn visit_block_start(&mut self, _state: &A::Domain, _block: BasicBlock) {}

     /// Called after the "early" effect of the given statement is applied to `state`.
     fn visit_after_early_statement_effect(
@@ -90,5 +90,5 @@ where
     ) {
     }

-    fn visit_block_end(&mut self, _state: &A::Domain) {}
+    fn visit_block_end(&mut self, _state: &A::Domain, _block: BasicBlock) {}
 }
diff --git a/compiler/rustc_mir_dataflow/src/impls/mod.rs b/compiler/rustc_mir_dataflow/src/impls/mod.rs
index 6d573e1c00e1c..507932c79ffaa 100644
--- a/compiler/rustc_mir_dataflow/src/impls/mod.rs
+++ b/compiler/rustc_mir_dataflow/src/impls/mod.rs
@@ -1,6 +1,7 @@
 mod borrowed_locals;
 mod initialized;
 mod liveness;
+mod precise_liveness;
 mod storage_liveness;

 pub use self::borrowed_locals::{MaybeBorrowedLocals, borrowed_locals};
@@ -12,6 +13,9 @@ pub use self::liveness::{
     DefUse, MaybeLiveLocals, MaybeTransitiveLiveLocals,
     TransferFunction as LivenessTransferFunction,
 };
+pub use self::precise_liveness::{
+    SplitPointEffect, SplitPointIndex, dump_liveness_matrix, liveness_matrix,
+};
 pub use self::storage_liveness::{
     MaybeRequiresStorage, MaybeStorageDead, MaybeStorageLive, always_storage_live_locals,
 };
diff --git a/compiler/rustc_mir_dataflow/src/impls/precise_liveness.rs b/compiler/rustc_mir_dataflow/src/impls/precise_liveness.rs
new file mode 100644
index 0000000000000..f2d91dea742b4
--- /dev/null
+++ b/compiler/rustc_mir_dataflow/src/impls/precise_liveness.rs
@@ -0,0 +1,560 @@
+use std::fmt;
+
+use rustc_index::IndexVec;
+use rustc_index::bit_set::DenseBitSet;
+use rustc_index::interval::SparseIntervalMatrix;
+use rustc_middle::mir::visit::{
+    MutatingUseContext, NonMutatingUseContext, PlaceContext, VisitPlacesWith, Visitor,
+};
+use rustc_middle::mir::{self, BasicBlock, Local, Location, MirDumper, PassWhere, Place};
+use rustc_middle::ty::TyCtxt;
+use tracing::trace;
+
+use crate::fmt::DebugWithContext;
+use crate::impls::{DefUse, MaybeBorrowedLocals, MaybeLiveLocals};
+use crate::points::{DenseLocationMap, PointIndex};
+use crate::{Analysis, GenKill, JoinSemiLattice, ResultsVisitor, visit_reachable_results};
+
+struct KillPointsVisitor<'a> {
+    kill_points: &'a mut Vec<(Local, Location)>,
+    live_on_entry: &'a mut IndexVec<BasicBlock, DenseBitSet<Local>>,
+}
+
+impl<'tcx> ResultsVisitor<'tcx, MaybeLiveLocals> for KillPointsVisitor<'_> {
+    fn visit_block_start(&mut self, state: &DenseBitSet<Local>, block: BasicBlock) {
+        self.live_on_entry[block].clone_from(state);
+    }
+
+    fn visit_after_early_statement_effect(
+        &mut self,
+        _analysis: &MaybeLiveLocals,
+        state: &DenseBitSet<Local>,
+        statement: &mir::Statement<'tcx>,
+        location: Location,
+    ) {
+        VisitPlacesWith(|place: Place<'tcx>, ctxt| {
+            // Ignore non-uses.
+            match ctxt {
+                PlaceContext::NonMutatingUse(_) | PlaceContext::MutatingUse(_) => {}
+                PlaceContext::NonUse(_) => return,
+            }
+
+            // If a local is used in a statement but is dead after it, then this
+            // location is a kill point.
+            if !state.contains(place.local) {
+                self.kill_points.push((place.local, location));
+            }
+        })
+        .visit_statement(statement, location);
+    }
+
+    fn visit_after_early_terminator_effect(
+        &mut self,
+        _analysis: &MaybeLiveLocals,
+        state: &DenseBitSet<Local>,
+        terminator: &mir::Terminator<'tcx>,
+        location: Location,
+    ) {
+        VisitPlacesWith(|place: Place<'tcx>, ctxt| {
+            // Ignore non-uses (they don't do anything) and edge uses (kill
+            // points for those go at the start of the corresponding successor).
+            match ctxt {
+                PlaceContext::MutatingUse(
+                    MutatingUseContext::AsmOutput
+                    | MutatingUseContext::Call
+                    | MutatingUseContext::Yield,
+                )
+                | PlaceContext::NonUse(_) => return,
+                PlaceContext::NonMutatingUse(_) | PlaceContext::MutatingUse(_) => {}
+            }
+
+            // If a local is used in a terminator but is dead after it, then this
+            // location is a kill point.
+            if !state.contains(place.local) {
+                self.kill_points.push((place.local, location));
+            }
+        })
+        .visit_terminator(terminator, location);
+    }
+}
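The kill-point rule above boils down to one bit operation per location: a local that is used at a statement but not live afterwards dies there. A standalone bitmask sketch (bit positions standing in for locals):

```rust
fn main() {
    // Locals _1 and _2 are used at this statement...
    let used_here: u64 = 0b110;
    // ...but only _1 is still live after it.
    let live_after: u64 = 0b010;
    // So this statement is _2's kill point.
    let killed = used_here & !live_after;
    assert_eq!(killed, 0b100);
}
```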
+#[derive(Debug, PartialEq, Eq)]
+struct Domain {
+    maybe_live: DenseBitSet<Local>,
+    maybe_borrowed: DenseBitSet<Local>,
+}
+
+impl Clone for Domain {
+    fn clone(&self) -> Self {
+        Domain { maybe_live: self.maybe_live.clone(), maybe_borrowed: self.maybe_borrowed.clone() }
+    }
+
+    // The dataflow engine uses `clone_from` for domain values when possible.
+    // Providing an implementation avoids some intermediate memory allocations.
+    fn clone_from(&mut self, other: &Self) {
+        self.maybe_live.clone_from(&other.maybe_live);
+        self.maybe_borrowed.clone_from(&other.maybe_borrowed);
+    }
+}
+
+impl JoinSemiLattice for Domain {
+    fn join(&mut self, other: &Self) -> bool {
+        self.maybe_live.join(&other.maybe_live) | self.maybe_borrowed.join(&other.maybe_borrowed)
+    }
+}
+
+impl<C> DebugWithContext<C> for Domain {
+    fn fmt_with(&self, ctxt: &C, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.write_str("maybe_live: ")?;
+        self.maybe_live.fmt_with(ctxt, f)?;
+        f.write_str("maybe_borrowed: ")?;
+        self.maybe_borrowed.fmt_with(ctxt, f)?;
+        Ok(())
+    }
+
+    fn fmt_diff_with(&self, old: &Self, ctxt: &C, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        if self == old {
+            return Ok(());
+        }
+
+        if self.maybe_live != old.maybe_live {
+            f.write_str("maybe_live: ")?;
+            self.maybe_live.fmt_diff_with(&old.maybe_live, ctxt, f)?;
+            f.write_str("\n")?;
+        }
+
+        if self.maybe_borrowed != old.maybe_borrowed {
+            f.write_str("maybe_borrowed: ")?;
+            self.maybe_borrowed.fmt_diff_with(&old.maybe_borrowed, ctxt, f)?;
+            f.write_str("\n")?;
+        }
+
+        Ok(())
+    }
+}
+
+struct PreciseLiveness<'a> {
+    kill_point_map: &'a IndexVec<PointIndex, &'a [(Local, Location)]>,
+    live_on_entry: &'a IndexVec<BasicBlock, DenseBitSet<Local>>,
+    points: &'a DenseLocationMap,
+}
+
+impl PreciseLiveness<'_> {
+    fn apply_block_start_effect(&self, state: &mut Domain, block: BasicBlock) {
+        // Only keep locals that are either live or borrowed.
+        //
+        // Notably this kills any dead results produced by a predecessor's
+        // terminator.
+        state.maybe_live.intersect_with_union(&self.live_on_entry[block], &state.maybe_borrowed);
+    }
+}
+
+impl<'tcx> Analysis<'tcx> for PreciseLiveness<'_> {
+    type Domain = Domain;
+
+    const NAME: &'static str = "precise_liveness";
+
+    fn bottom_value(&self, body: &mir::Body<'tcx>) -> Domain {
+        Domain {
+            maybe_live: DenseBitSet::new_empty(body.local_decls.len()),
+            maybe_borrowed: DenseBitSet::new_empty(body.local_decls.len()),
+        }
+    }
+
+    fn initialize_start_block(&self, body: &mir::Body<'tcx>, state: &mut Domain) {
+        // Function arguments start out as live.
+        for arg in body.args_iter() {
+            state.maybe_live.gen_(arg);
+        }
+    }
+
+    fn apply_primary_statement_effect(
+        &self,
+        state: &mut Domain,
+        statement: &mir::Statement<'tcx>,
+        location: Location,
+    ) {
+        if location.statement_index == 0 {
+            self.apply_block_start_effect(state, location.block);
+        }
+
+        // StorageDead always kills a local, even if it has been borrowed.
+        if let mir::StatementKind::StorageDead(local) = statement.kind {
+            state.maybe_live.kill(local);
+            state.maybe_borrowed.kill(local);
+            return;
+        }
+
+        MaybeBorrowedLocals::transfer_function(&mut state.maybe_borrowed)
+            .visit_statement(statement, location);
+
+        // Kill moved operands if the whole local was moved.
+        VisitPlacesWith(|place: Place<'tcx>, ctxt| {
+            if ctxt == PlaceContext::NonMutatingUse(NonMutatingUseContext::Move) {
+                if let Some(local) = place.as_local() {
+                    state.maybe_live.kill(local);
+                    state.maybe_borrowed.kill(local);
+                }
+            }
+        })
+        .visit_statement(statement, location);
+
+        // Gen destination places.
+        VisitPlacesWith(|place: Place<'tcx>, ctxt| match DefUse::for_place(place, ctxt) {
+            DefUse::Def | DefUse::PartialWrite => state.maybe_live.gen_(place.local),
+            DefUse::Use | DefUse::NonUse => {}
+        })
+        .visit_statement(statement, location);
+
+        // Apply kill points at this statement: a local whose last use is here
+        // dies, *except* if its address has been taken.
+        let point = self.points.point_from_location(location);
+        for &(local, _) in self.kill_point_map[point] {
+            if !state.maybe_borrowed.contains(local) {
+                state.maybe_live.kill(local);
+            }
+        }
+    }
+
+    fn apply_primary_terminator_effect<'mir>(
+        &self,
+        state: &mut Domain,
+        terminator: &'mir mir::Terminator<'tcx>,
+        location: Location,
+    ) -> mir::TerminatorEdges<'mir, 'tcx> {
+        if location.statement_index == 0 {
+            self.apply_block_start_effect(state, location.block);
+        }
+
+        MaybeBorrowedLocals::transfer_function(&mut state.maybe_borrowed)
+            .visit_terminator(terminator, location);
+
+        // Kill moved operands if the whole local was moved. Also kill dropped
+        // places if the entire local was dropped.
+        VisitPlacesWith(|place: Place<'tcx>, ctxt| {
+            if let PlaceContext::NonMutatingUse(NonMutatingUseContext::Move)
+            | PlaceContext::MutatingUse(MutatingUseContext::Drop) = ctxt
+            {
+                if let Some(local) = place.as_local() {
+                    state.maybe_live.kill(local);
+                    state.maybe_borrowed.kill(local);
+                }
+            }
+        })
+        .visit_terminator(terminator, location);
+
+        // Gen destination places.
+        VisitPlacesWith(|place: Place<'tcx>, ctxt| {
+            // These are handled through `apply_call_return_effect`.
+            if let PlaceContext::MutatingUse(
+                MutatingUseContext::AsmOutput
+                | MutatingUseContext::Call
+                | MutatingUseContext::Yield,
+            ) = ctxt
+            {
+                return;
+            }
+
+            match DefUse::for_place(place, ctxt) {
+                DefUse::Def | DefUse::PartialWrite => state.maybe_live.gen_(place.local),
+                DefUse::Use | DefUse::NonUse => {}
+            }
+        })
+        .visit_terminator(terminator, location);
+
+        terminator.edges()
+    }
+
+    fn apply_call_return_effect(
+        &self,
+        state: &mut Domain,
+        _block: BasicBlock,
+        return_places: mir::CallReturnPlaces<'_, 'tcx>,
+    ) {
+        return_places.for_each(|place| state.maybe_live.gen_(place.local));
+    }
+}
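The `maybe_borrowed` component of the domain exists to veto kills: a kill point only ends a local's live range if its address has not escaped. A bitmask sketch of that rule, standalone:

```rust
fn main() {
    let mut maybe_live: u64 = 0b11; // _0 and _1 are live
    let kill_points: u64 = 0b11;    // both reach their last direct use here
    let maybe_borrowed: u64 = 0b01; // but _0 has been borrowed
    // Only unborrowed locals are actually killed.
    maybe_live &= !(kill_points & !maybe_borrowed);
    assert_eq!(maybe_live, 0b01);   // _0 stays live, _1 dies
}
```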
+/// Different "phases" of a single MIR statement, used to describe how
+/// overlapping operands are handled.
+///
+/// As a general rule, source operands are read in the `Early` phase and
+/// destination places are written in the `Late` phase.
+#[derive(Copy, Clone, Debug)]
+pub enum SplitPointEffect {
+    Early = 0,
+    Late = 1,
+}
+
+rustc_index::newtype_index! {
+    /// A `PointIndex` with the lower bit encoding early/late inside a statement.
+    ///
+    /// This is used to model overlap constraints within a MIR statement: if a
+    /// source and a destination are allowed to overlap then the source is read
+    /// in `SplitPointEffect::Early` and the write is done in
+    /// `SplitPointEffect::Late`.
+    #[orderable]
+    #[debug_format = "SplitPointIndex({})"]
+    pub struct SplitPointIndex {}
+}
+
+impl SplitPointIndex {
+    pub fn new(point: PointIndex, effect: SplitPointEffect) -> SplitPointIndex {
+        let index = (point.as_u32() << 1) | (effect as u32);
+        SplitPointIndex::from_u32(index)
+    }
+
+    pub fn point(self) -> PointIndex {
+        PointIndex::from_u32(self.as_u32() >> 1)
+    }
+
+    pub fn effect(self) -> SplitPointEffect {
+        match self.as_u32() & 1 {
+            0 => SplitPointEffect::Early,
+            1 => SplitPointEffect::Late,
+            _ => unreachable!(),
+        }
+    }
+}
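The early/late packing is a plain shift-and-mask scheme, so its ordering properties are easy to verify in isolation (standalone sketch, not the real newtype):

```rust
// `SplitPointIndex` encoding: point << 1 | phase. Early always sorts before
// Late at the same point, and both sort before the next point's Early, which
// is what the interval queries rely on.
fn split(point: u32, late: bool) -> u32 {
    (point << 1) | late as u32
}

fn main() {
    assert_eq!(split(5, false), 10); // point 5, early
    assert_eq!(split(5, true), 11);  // point 5, late
    assert!(split(5, true) < split(6, false));
    assert_eq!(split(5, true) >> 1, 5); // decoding recovers the point
}
```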
+fn compute_kill_points<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    body: &mir::Body<'tcx>,
+    pass_name: Option<&'static str>,
+) -> (Vec<(Local, Location)>, IndexVec<BasicBlock, DenseBitSet<Local>>) {
+    let maybe_live_locals = MaybeLiveLocals.iterate_to_fixpoint(tcx, body, pass_name);
+    let mut kill_points = vec![];
+    let mut live_on_entry = IndexVec::from_elem_n(
+        DenseBitSet::new_empty(body.local_decls.len()),
+        body.basic_blocks.len(),
+    );
+    let mut visitor =
+        KillPointsVisitor { kill_points: &mut kill_points, live_on_entry: &mut live_on_entry };
+    visit_reachable_results(body, &maybe_live_locals, &mut visitor);
+    trace!(?kill_points);
+    trace!(?live_on_entry);
+    (kill_points, live_on_entry)
+}
+
+fn kill_point_map<'a>(
+    kill_points: &'a [(Local, Location)],
+    points: &DenseLocationMap,
+) -> IndexVec<PointIndex, &'a [(Local, Location)]> {
+    let mut out = IndexVec::from_elem_n(&[][..], points.num_points());
+    for chunk in kill_points.chunk_by(|a, b| a.1 == b.1) {
+        let point = points.point_from_location(chunk[0].1);
+        trace!("Kill points at {:?}: {:?}", chunk[0].1, chunk);
+        out[point] = chunk;
+    }
+    out
+}
+
+/// Helper type to construct a `SparseIntervalMatrix`.
+struct MatrixBuilder {
+    matrix: SparseIntervalMatrix<Local, SplitPointIndex>,
+    range_start: IndexVec<Local, Option<SplitPointIndex>>,
+}
+
+impl MatrixBuilder {
+    fn gen_(&mut self, local: Local, point: PointIndex, effect: SplitPointEffect) {
+        let split_point = SplitPointIndex::new(point, effect);
+
+        // No-op if the local is already live.
+        self.range_start[local].get_or_insert(split_point);
+    }
+
+    fn kill(&mut self, local: Local, point: PointIndex, effect: SplitPointEffect) {
+        let end = SplitPointIndex::new(point, effect);
+
+        // No-op if the local is already dead.
+        if let Some(start) = self.range_start[local].take() {
+            debug_assert!(end >= start);
+            self.matrix.append_range(local, start..=end);
+        }
+    }
+}
+
+pub fn liveness_matrix<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    body: &mir::Body<'tcx>,
+    points: &DenseLocationMap,
+    pass_name: Option<&'static str>,
+) -> SparseIntervalMatrix<Local, SplitPointIndex> {
+    let (kill_points, live_on_entry) = compute_kill_points(tcx, body, pass_name);
+    let kill_point_map = &kill_point_map(&kill_points, points);
+    let mut results = PreciseLiveness { kill_point_map, live_on_entry: &live_on_entry, points }
+        .iterate_to_fixpoint(tcx, body, pass_name);
+
+    let mut builder = MatrixBuilder {
+        matrix: SparseIntervalMatrix::new(points.num_points() * 2),
+        range_start: IndexVec::from_elem_n(None, body.local_decls.len()),
+    };
+    for (block, block_data) in body.basic_blocks.iter_enumerated() {
+        // We can mutate the state in-place since we're not using it any more
+        // after this point.
+        let state = &mut results.entry_states[block];
+
+        // Only keep locals that are either live or borrowed.
+        //
+        // Notably this kills any dead results produced by a predecessor's
+        // terminator.
+        state.maybe_live.intersect_with_union(&live_on_entry[block], &state.maybe_borrowed);
+
+        for local in state.maybe_live.iter() {
+            builder.gen_(local, points.entry_point(block), SplitPointEffect::Early);
+        }
+
+        for (statement_index, statement) in block_data.statements.iter().enumerate() {
+            let location = Location { block, statement_index };
+            let point = points.point_from_location(location);
+
+            // StorageDead always kills a local, even if it has been borrowed.
+            if let mir::StatementKind::StorageDead(local) = statement.kind {
+                builder.kill(local, point, SplitPointEffect::Late);
+                state.maybe_borrowed.kill(local);
+                continue;
+            }
+
+            MaybeBorrowedLocals::transfer_function(&mut state.maybe_borrowed)
+                .visit_statement(statement, location);
+
+            // Kill moved operands if the whole local was moved.
+            VisitPlacesWith(|place: Place<'tcx>, ctxt| {
+                if ctxt == PlaceContext::NonMutatingUse(NonMutatingUseContext::Move) {
+                    if let Some(local) = place.as_local() {
+                        builder.kill(local, point, SplitPointEffect::Early);
+                        state.maybe_borrowed.kill(local);
+                    }
+                }
+            })
+            .visit_statement(statement, location);
+
+            // Kill any locals which are no longer used after this statement,
+            // but only if they have not been borrowed.
+            for &(local, _) in kill_point_map[point] {
+                if !state.maybe_borrowed.contains(local) {
+                    builder.kill(local, point, SplitPointEffect::Early);
+                }
+            }
+
+            // Gen destination places.
+            VisitPlacesWith(|place: Place<'tcx>, ctxt| match DefUse::for_place(place, ctxt) {
+                DefUse::Def | DefUse::PartialWrite => {
+                    builder.gen_(place.local, point, SplitPointEffect::Late)
+                }
+                DefUse::Use | DefUse::NonUse => {}
+            })
+            .visit_statement(statement, location);
+
+            // Kill any dead destination places: they will only appear at
+            // the late point of the statement they are generated in, which is
+            // sufficient for determining overlap.
+            for &(local, _) in kill_point_map[point] {
+                if !state.maybe_borrowed.contains(local) {
+                    builder.kill(local, point, SplitPointEffect::Late);
+                }
+            }
+        }
+
+        let location = Location { block, statement_index: block_data.statements.len() };
+        let point = points.point_from_location(location);
+        let terminator = block_data.terminator();
+
+        MaybeBorrowedLocals::transfer_function(&mut state.maybe_borrowed)
+            .visit_terminator(terminator, location);
+
+        // Kill moved operands if the whole local was moved. Also kill dropped
+        // places if the entire local was dropped.
+        VisitPlacesWith(|place: Place<'tcx>, ctxt| {
+            if let PlaceContext::NonMutatingUse(NonMutatingUseContext::Move)
+            | PlaceContext::MutatingUse(MutatingUseContext::Drop) = ctxt
+            {
+                if let Some(local) = place.as_local() {
+                    builder.kill(local, point, SplitPointEffect::Early);
+                    state.maybe_borrowed.kill(local);
+                }
+            }
+        })
+        .visit_terminator(terminator, location);
+
+        // Kill any locals which are no longer used after this terminator,
+        // but only if they have not been borrowed.
+        for &(local, _) in kill_point_map[point] {
+            if !state.maybe_borrowed.contains(local) {
+                builder.kill(local, point, SplitPointEffect::Early);
+            }
+        }
+
+        // Gen destination places.
+        VisitPlacesWith(|place: Place<'tcx>, ctxt| match DefUse::for_place(place, ctxt) {
+            DefUse::Def | DefUse::PartialWrite => {
+                builder.gen_(place.local, point, SplitPointEffect::Late)
+            }
+            DefUse::Use | DefUse::NonUse => {}
+        })
+        .visit_terminator(terminator, location);
+
+        // `Move` arguments to a call are treated specially: the place that they
+        // represent is passed directly to the callee, which means that they are
+        // not allowed to alias any other move operand or the destination place.
+        // This is represented here by extending their live range to the late
+        // part, making it overlap with that of the destination place.
+        //
+        // Notably, this *doesn't* apply to TailCall.
+        if let mir::TerminatorKind::Call {
+            func: _,
+            args,
+            destination: _,
+            target: _,
+            unwind: _,
+            call_source: _,
+            fn_span: _,
+        } = &terminator.kind
+        {
+            for arg in args {
+                if let mir::Operand::Move(place) = arg.node {
+                    builder.gen_(place.local, point, SplitPointEffect::Late);
+                    builder.kill(place.local, point, SplitPointEffect::Late);
+                }
+            }
+        }
+
+        // End the lifetimes of all locals at the end of the block. Successor
+        // blocks (which may not be contiguous in the index space!) will
+        // initialize the lifetimes again from their entry state.
+        for local in builder.range_start.indices() {
+            builder.kill(local, point, SplitPointEffect::Late);
+        }
+    }
+
+    builder.matrix
+}
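The call-argument special case is easiest to see as an interval-overlap question: extending a moved argument's range to the terminator's late point forces it to conflict with the destination, which first becomes live at that same point. A sketch with `(start, end)` pairs over split-point indices:

```rust
// Two inclusive ranges over split-point indices overlap iff neither ends
// before the other starts.
fn overlaps(a: (u32, u32), b: (u32, u32)) -> bool {
    a.0 <= b.1 && b.0 <= a.1
}

fn main() {
    let late = 11; // the call terminator's late split point
    let dest = (late, 20); // the destination becomes live at the late point
    assert!(!overlaps((4, late - 1), dest)); // ends early: could be merged
    assert!(overlaps((4, late), dest)); // extended: forced into a distinct local
}
```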
+pub fn dump_liveness_matrix<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    body: &mir::Body<'tcx>,
+    pass_name: &'static str,
+    points: &DenseLocationMap,
+    matrix: &SparseIntervalMatrix<Local, SplitPointIndex>,
+) {
+    let locals_live_at = |split_point| {
+        matrix.rows().filter(|&r| matrix.contains(r, split_point)).collect::<Vec<_>>()
+    };
+
+    if let Some(dumper) = MirDumper::new(tcx, pass_name, body) {
+        let extra_data = &|pass_where, w: &mut dyn std::io::Write| {
+            if let PassWhere::BeforeLocation(loc) = pass_where {
+                let point = points.point_from_location(loc);
+                let split_point = SplitPointIndex::new(point, SplitPointEffect::Early);
+                let live = locals_live_at(split_point);
+                writeln!(w, " // {loc:?}-early => {live:?}")?;
+                let split_point = SplitPointIndex::new(point, SplitPointEffect::Late);
+                let live = locals_live_at(split_point);
+                writeln!(w, " // {loc:?}-late => {live:?}")?;
+            }
+            Ok(())
+        };
+
+        dumper.set_extra_data(extra_data).dump_mir(body)
+    }
+}
diff --git a/compiler/rustc_mir_dataflow/src/points.rs b/compiler/rustc_mir_dataflow/src/points.rs
index e3d1e04a319ba..92513b552410e 100644
--- a/compiler/rustc_mir_dataflow/src/points.rs
+++ b/compiler/rustc_mir_dataflow/src/points.rs
@@ -56,6 +56,15 @@ impl DenseLocationMap {
         PointIndex::new(start_index)
     }

+    /// Returns the `PointIndex` for the terminator in the given `BasicBlock`. O(1).
+    #[inline]
+    pub fn terminator(&self, block: BasicBlock) -> PointIndex {
+        let next_block = BasicBlock::new(block.index() + 1);
+        let next_start_index =
+            *self.statements_before_block.get(next_block).unwrap_or(&self.num_points);
+        PointIndex::new(next_start_index - 1)
+    }
+
     /// Return the PointIndex for the block start of this index.
     #[inline]
     pub fn to_block_start(&self, index: PointIndex) -> PointIndex {
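The new `terminator` helper is pure offset arithmetic over the per-block starting indices; a standalone version (with `usize` standing in for the index types) shows the boundary handling for the last block:

```rust
// starts[b] is the point index of block b's first statement; the terminator of
// block b is one before the start of block b + 1, or one before the total
// point count for the final block.
fn terminator_point(starts: &[usize], num_points: usize, block: usize) -> usize {
    let next_start = starts.get(block + 1).copied().unwrap_or(num_points);
    next_start - 1
}

fn main() {
    // Three blocks holding 3, 1 and 2 points respectively.
    let starts = [0, 3, 4];
    assert_eq!(terminator_point(&starts, 6, 0), 2);
    assert_eq!(terminator_point(&starts, 6, 1), 3);
    assert_eq!(terminator_point(&starts, 6, 2), 5);
}
```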
diff --git a/compiler/rustc_mir_transform/src/dest_prop.rs b/compiler/rustc_mir_transform/src/dest_prop.rs
index 3be3c19ab198e..40081aef17f41 100644
--- a/compiler/rustc_mir_transform/src/dest_prop.rs
+++ b/compiler/rustc_mir_transform/src/dest_prop.rs
@@ -153,7 +153,7 @@ pub(super) struct DestinationPropagation;

 impl<'tcx> crate::MirPass<'tcx> for DestinationPropagation {
     fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
-        sess.mir_opt_level() >= 2
+        false && sess.mir_opt_level() >= 2
     }

     #[tracing::instrument(level = "trace", skip(self, tcx, body))]
diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs
index 91dfffcf1a6a5..e68ea29d173b3 100644
--- a/compiler/rustc_mir_transform/src/lib.rs
+++ b/compiler/rustc_mir_transform/src/lib.rs
@@ -164,6 +164,7 @@ declare_passes! {
     mod lower_slice_len : LowerSliceLenCalls;
     mod match_branches : MatchBranchSimplification;
     mod mentioned_items : MentionedItems;
+    mod move_elimination : MoveElimination;
     mod multiple_return_terminators : MultipleReturnTerminators;
     mod post_drop_elaboration : CheckLiveDrops;
     mod prettify : ReorderBasicBlocks, ReorderLocals;
@@ -763,6 +764,7 @@ pub(crate) fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'
             &copy_prop::CopyProp,
             &dead_store_elimination::DeadStoreElimination::Final,
             &dest_prop::DestinationPropagation,
+            &move_elimination::MoveElimination,
             &simplify::SimplifyLocals::Final,
             &multiple_return_terminators::MultipleReturnTerminators,
             &large_enums::EnumSizeOpt { discrepancy: 128 },
diff --git a/compiler/rustc_mir_transform/src/move_elimination.rs b/compiler/rustc_mir_transform/src/move_elimination.rs
new file mode 100644
index 0000000000000..0aad46d70178a
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/move_elimination.rs
@@ -0,0 +1,909 @@
+use rustc_abi::{FieldIdx, VariantIdx};
+use rustc_const_eval::util::most_packed_projection;
+use rustc_data_structures::fx::FxHashMap;
+use rustc_index::IndexVec;
+use rustc_index::bit_set::DenseBitSet;
+use rustc_index::interval::SparseIntervalMatrix;
+use rustc_middle::mir::visit::{MutVisitor, NonUseContext, PlaceContext, VisitPlacesWith, Visitor};
+use rustc_middle::mir::*;
+use rustc_middle::ty::{Ty, TyCtxt};
+use rustc_mir_dataflow::impls::{
+    SplitPointEffect, SplitPointIndex, dump_liveness_matrix, liveness_matrix,
+};
+use rustc_mir_dataflow::points::DenseLocationMap;
+use tracing::{debug, trace};
+
+use crate::patch::MirPatch;
+
+pub(super) struct MoveElimination;
+
+impl<'tcx> crate::MirPass<'tcx> for MoveElimination {
+    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
+        true && sess.mir_opt_level() >= 2
+    }
+
+    #[tracing::instrument(level = "trace", skip(self, tcx, body))]
+    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+        let def_id = body.source.def_id();
+        trace!(?def_id);
+
+        let points = DenseLocationMap::new(body);
+        let mut liveness_matrix =
+            liveness_matrix(tcx, body, &points, Some("MoveElimination.liveness"));
+
+        dump_liveness_matrix(tcx, body, "MoveElimination.pre-liveness", &points, &liveness_matrix);
+
+        let mut unprojectable_locals = UnprojectableLocals::find(body);
+        trace!(?unprojectable_locals);
+
+        let remapped_locals =
+            PlaceUnification::run(tcx, body, &mut liveness_matrix, &mut unprojectable_locals);
+
+        apply_mappings(tcx, body, &remapped_locals);
+
+        dump_liveness_matrix(tcx, body, "MoveElimination.post-liveness", &points, &liveness_matrix);
+
+        if tcx.sess.emit_lifetime_markers() {
+            reconstruct_storage(body, &points, &liveness_matrix);
+        }
+
+        apply_alias_fixup(tcx, body);
+    }
+
+    fn is_required(&self) -> bool {
+        false
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Unprojectable locals
+
+/// Set of locals which can only be replaced with another local, instead of
+/// an arbitrary place. This is usually because the local is used directly as a
+/// `Local` outside of a place (e.g. in `Index` projections).
+#[derive(Debug)]
+struct UnprojectableLocals {
+    locals: DenseBitSet<Local>,
+}
+
+impl UnprojectableLocals {
+    fn find(body: &Body<'_>) -> DenseBitSet<Local> {
+        let mut out = Self { locals: DenseBitSet::new_empty(body.local_decls.len()) };
+
+        // Arguments and return places have fixed roles and cannot be replaced
+        // with projected locals.
+        out.locals.insert(RETURN_PLACE);
+        for arg in body.args_iter() {
+            out.locals.insert(arg);
+        }
+
+        out.visit_body(body);
+        out.locals
+    }
+}
+
+impl<'tcx> Visitor<'tcx> for UnprojectableLocals {
+    fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
+        // We can't add more projections before a first-position `Deref` projection.
+        if place.is_indirect() {
+            trace!(
+                "unprojectable local {:?} due to use as deref base at {location:?}",
+                place.local
+            );
+            self.locals.insert(place.local);
+        }
+
+        // Only call visit_local for projections, not the base local.
+        self.visit_projection(place.as_ref(), context, location);
+    }
+
+    fn visit_local(&mut self, local: Local, context: PlaceContext, location: Location) {
+        // Ignore uses in storage statements, we're going to remove all of those
+        // anyway.
+        if let PlaceContext::NonUse(NonUseContext::StorageLive | NonUseContext::StorageDead) =
+            context
+        {
+            return;
+        }
+
+        // If this is reached, it means that this is a bare local used outside
+        // of a place, which means it cannot be replaced with a projection of
+        // another local.
+        trace!("unprojectable local {local:?} at {location:?} ({context:?})");
+        self.locals.insert(local);
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Local unification
+
+struct PlaceUnification<'a, 'tcx> {
+    tcx: TyCtxt<'tcx>,
+    body: &'a Body<'tcx>,
+    liveness_matrix: &'a mut SparseIntervalMatrix<Local, SplitPointIndex>,
+    unprojectable_locals: &'a mut DenseBitSet<Local>,
+    remapped_locals: IndexVec<Local, Option<Place<'tcx>>>,
+}
+
+impl<'tcx> PlaceUnification<'_, 'tcx> {
+    fn run(
+        tcx: TyCtxt<'tcx>,
+        body: &Body<'tcx>,
+        liveness_matrix: &mut SparseIntervalMatrix<Local, SplitPointIndex>,
+        unprojectable_locals: &mut DenseBitSet<Local>,
+    ) -> IndexVec<Local, Option<Place<'tcx>>> {
+        let mut visitor = PlaceUnification {
+            tcx,
+            body,
+            liveness_matrix,
+            unprojectable_locals,
+            remapped_locals: IndexVec::from_elem_n(None, body.local_decls.len()),
+        };
+        visitor.visit_body(body);
+
+        // Finalize the mappings by transitively resolving all locals to their
+        // new final place.
+        for local in visitor.remapped_locals.indices() {
+            if let Some(place) = visitor.remapped_locals[local] {
+                let place = visitor.resolve_place(place);
+                visitor.remapped_locals[local] = Some(place);
+                trace!("Remapped {local:?} to {place:?}");
+            }
+        }
+
+        visitor.remapped_locals
+    }
+
+    #[tracing::instrument(ret, level = "trace", skip(self))]
+    fn resolve_place(&self, mut place: Place<'tcx>) -> Place<'tcx> {
+        while let Some(new_place) = self.remapped_locals[place.local] {
+            place = new_place.project_deeper(place.projection, self.tcx);
+        }
+        place
+    }
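`resolve_place` chases remappings until it reaches an unmapped local, so chains built in any order collapse to the same final target. A standalone sketch with plain indices standing in for locals (the real code also re-applies projections at each step):

```rust
fn resolve(remapped: &[Option<usize>], mut local: usize) -> usize {
    while let Some(next) = remapped[local] {
        local = next;
    }
    local
}

fn main() {
    // _1 -> _2 and _2 -> _4 collapse to _1 -> _4.
    let remapped = [None, Some(2), Some(4), None, None];
    assert_eq!(resolve(&remapped, 1), 4);
    assert_eq!(resolve(&remapped, 3), 3); // unmapped locals resolve to themselves
}
```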
+    #[tracing::instrument(ret, level = "trace", skip(self))]
+    fn can_unify_places(&self, a: Place<'tcx>, b: Place<'tcx>) -> Option<(Local, Place<'tcx>)> {
+        let a = self.resolve_place(a);
+        let b = self.resolve_place(b);
+
+        if a.local == b.local {
+            if a.projection != b.projection {
+                trace!("cannot unify same local with different projections");
+            }
+            return None;
+        }
+
+        let (local, place) = match (a.as_local(), b.as_local()) {
+            (None, None) => {
+                trace!("cannot unify 2 places that both have projections");
+                return None;
+            }
+            (None, Some(b)) => {
+                if self.unprojectable_locals.contains(b) {
+                    trace!("cannot unify {b:?} which cannot be projected");
+                    return None;
+                }
+                (b, a)
+            }
+            (Some(a), None) => {
+                if self.unprojectable_locals.contains(a) {
+                    trace!("cannot unify {a:?} which cannot be projected");
+                    return None;
+                }
+                (a, b)
+            }
+            (Some(a), Some(b)) => match (self.body.local_kind(a), self.body.local_kind(b)) {
+                (
+                    LocalKind::Arg | LocalKind::ReturnPointer,
+                    LocalKind::Arg | LocalKind::ReturnPointer,
+                ) => {
+                    trace!("cannot unify {a:?} and {b:?} which are both arguments or return place");
+                    return None;
+                }
+                (LocalKind::Arg | LocalKind::ReturnPointer, LocalKind::Temp) => (b, a.into()),
+                (LocalKind::Temp, _) => (a, b.into()),
+            },
+        };
+
+        if most_packed_projection(self.tcx, &self.body.local_decls, place).is_some() {
+            trace!("cannot unify {place:?} which has packed field projections");
+            return None;
+        }
+
+        if !self.liveness_matrix.disjoint_rows(local, place.local) {
+            trace!("cannot unify {a:?} and {b:?} which have overlapping live ranges");
+            return None;
+        }
+
+        // FIXME(#112651): This check can be removed once that issue is fixed.
+        let local_ty = self.body.local_decls[local].ty;
+        let place_ty = place.ty(&self.body.local_decls, self.tcx).ty;
+        if local_ty != place_ty {
+            trace!(
+                "cannot unify {a:?} and {b:?} which have different types due to subtyping ({local_ty:?} vs {place_ty:?})"
+            );
+            return None;
+        }
+
+        Some((local, place))
+    }
+
+    #[tracing::instrument(level = "trace", skip(self))]
+    fn remap_local(&mut self, local: Local, place: Place<'tcx>) {
+        self.remapped_locals[local] = Some(place);
+
+        self.liveness_matrix.union_rows(local, place.local);
+        self.liveness_matrix.clear_row(local);
+
+        // If the original local was unprojectable then this now also applies to
+        // the mapped local.
+        if self.unprojectable_locals.contains(local) {
+            debug_assert!(place.projection.is_empty());
+            self.unprojectable_locals.insert(place.local);
+        }
+    }
+
+    fn visit_aggregate_assign(
+        &mut self,
+        dest: Place<'tcx>,
+        project_field: impl Fn(TyCtxt<'tcx>, Place<'tcx>, FieldIdx, Ty<'tcx>) -> Place<'tcx>,
+        operands: &IndexVec<FieldIdx, Operand<'tcx>>,
+        location: Location,
+    ) {
+        // Attempt to unify each field operand with the corresponding field in
+        // the destination place.
+        let mut candidates = vec![];
+        for (idx, operand) in operands.iter_enumerated() {
+            let (Operand::Copy(src) | Operand::Move(src)) = *operand else {
+                continue;
+            };
+            let Some(src) = src.as_local() else {
+                continue;
+            };
+            let dest = project_field(self.tcx, dest, idx, self.body.local_decls[src].ty);
+            trace!("Attempting to unify {dest:?} and {src:?} at {location:?}");
+            if let Some((local, place)) = self.can_unify_places(dest, src.into()) {
+                candidates.push((local, place));
+            }
+        }
+
+        // Do the actual remapping *after* checking for live range overlaps.
+        // This is necessary because the input operands necessarily have
+        // overlapping live ranges.
+        for (local, place) in candidates {
+            self.remap_local(local, place);
+        }
+    }
+}
+
+/// Since we are replacing all uses of a local with another place, we need to
+/// ensure that the projections on that place are stable no matter where it is
+/// used in the body. Additionally, this local may be used in debuginfo, so ensure
+/// that the projections are compatible with usage in debuginfo.
+fn check_projections(place: Place<'_>) -> bool {
+    place.projection.iter().all(|elem| elem.is_stable_offset() && elem.can_use_in_debuginfo())
+}
+
+impl<'tcx> Visitor<'tcx> for PlaceUnification<'_, 'tcx> {
+    fn visit_assign(&mut self, dest: &Place<'tcx>, rvalue: &Rvalue<'tcx>, location: Location) {
+        if !check_projections(*dest) {
+            return;
+        }
+        match rvalue {
+            Rvalue::Use(Operand::Copy(src) | Operand::Move(src), _) => {
+                if !check_projections(*src) {
+                    return;
+                }
+
+                trace!("Attempting to unify {dest:?} and {src:?} at {location:?}");
+                if let Some((local, place)) = self.can_unify_places(*src, *dest) {
+                    self.remap_local(local, place);
+                }
+            }
+            Rvalue::Aggregate(box aggregate_kind, operands) => match *aggregate_kind {
+                AggregateKind::Array(_) => self.visit_aggregate_assign(
+                    *dest,
+                    |tcx, place, field_idx, _field_ty| {
+                        place.project_deeper(
+                            &[PlaceElem::ConstantIndex {
+                                offset: field_idx.as_u32().into(),
+                                min_length: field_idx.as_u32() as u64 + 1,
+                                from_end: false,
+                            }],
+                            tcx,
+                        )
+                    },
+                    operands,
+                    location,
+                ),
+                AggregateKind::Tuple => self.visit_aggregate_assign(
+                    *dest,
+                    |tcx, place, field_idx, field_ty| {
+                        place.project_deeper(&[PlaceElem::Field(field_idx, field_ty)], tcx)
+                    },
+                    operands,
+                    location,
+                ),
+                AggregateKind::Adt(_, _, _, _, Some(union_field_idx)) => {
+                    debug_assert_eq!(operands.len(), 1);
+                    self.visit_aggregate_assign(
+                        *dest,
+                        |tcx, place, _, field_ty| {
+                            place
+                                .project_deeper(&[PlaceElem::Field(union_field_idx, field_ty)], tcx)
+                        },
+                        operands,
+                        location,
+                    )
+                }
+                AggregateKind::Adt(adt_did, var_idx, _, _, None) => {
+                    let def = self.tcx.adt_def(adt_did);
+                    if def.repr().simd() {
+                        // MCP#838 banned projections into SIMD types.
+                        return;
+                    }
+                    self.visit_aggregate_assign(
+                        *dest,
+                        |tcx, place, field_idx, field_ty| {
+                            if def.is_enum() {
+                                place.project_deeper(
+                                    &[
+                                        PlaceElem::Downcast(None, var_idx),
+                                        PlaceElem::Field(field_idx, field_ty),
+                                    ],
+                                    tcx,
+                                )
+                            } else {
+                                place.project_deeper(&[PlaceElem::Field(field_idx, field_ty)], tcx)
+                            }
+                        },
+                        operands,
+                        location,
+                    )
+                }
+                _ => {}
+            },
+            _ => {}
+        };
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Apply place mappings to the MIR body.
+
+/// Rewrites all uses of the remapped locals to their replacement places.
+fn apply_mappings<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    body: &mut Body<'tcx>,
+    remapped_locals: &IndexVec<Local, Option<Place<'tcx>>>,
+) {
+    let mut rewriter = PlaceUpdater { tcx, remapped_locals };
+    rewriter.visit_body_preserves_cfg(body);
+}
+
+struct PlaceUpdater<'a, 'tcx> {
+    tcx: TyCtxt<'tcx>,
+    remapped_locals: &'a IndexVec<Local, Option<Place<'tcx>>>,
+}
+
+impl<'tcx> MutVisitor<'tcx> for PlaceUpdater<'_, 'tcx> {
+    fn tcx(&self) -> TyCtxt<'tcx> {
+        self.tcx
+    }
+
+    fn visit_local(&mut self, local: &mut Local, context: PlaceContext, location: Location) {
+        if let Some(new_place) = self.remapped_locals[*local] {
+            trace!("replacing {local:?} with {new_place:?} at {location:?} ({context:?})");
+            *local = new_place.as_local().expect("mapped place shouldn't have projections");
+        }
+    }
+
+    fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) {
+        if let Some(new_place) = self.remapped_locals[place.local] {
+            trace!("replacing {place:?} with {new_place:?} at {location:?} ({context:?})");
+            *place = new_place.project_deeper(place.projection, self.tcx)
+        }
+
+        // Only call visit_local for projections, not the base local.
+        if let Some(new_projection) = self.process_projection(&place.projection, location) {
+            place.projection = self.tcx().mk_place_elems(&new_projection);
+        }
+    }
+
+    fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
+        match statement.kind {
+            // Remove *all* storage statements. These are rebuilt from liveness
+            // information later. Also, since we've preserved StorageDead in
+            // unwind paths until now, we will want to remove those since they
+            // hurt LLVM's codegen.
+            StatementKind::StorageDead(_) | StatementKind::StorageLive(_) => {
+                statement.make_nop(true);
+                return;
+            }
+            _ => {}
+        }
+
+        self.super_statement(statement, location);
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Storage reconstruction
+
+/// Helper function to split a critical edge if necessary.
+fn get_or_split_edge<'tcx>(
+    patcher: &mut MirPatch<'tcx>,
+    body: &Body<'tcx>,
+    split_edges: &mut FxHashMap<(BasicBlock, BasicBlock), BasicBlock>,
+    pred: BasicBlock,
+    succ: BasicBlock,
+) -> BasicBlock {
+    if let Some(&split_bb) = split_edges.get(&(pred, succ)) {
+        return split_bb;
+    }
+    let source_info = body.basic_blocks[pred].terminator().source_info;
+    let split_bb = patcher.new_block(BasicBlockData::new(
+        Some(Terminator { source_info, kind: TerminatorKind::Goto { target: succ } }),
+        body.basic_blocks[succ].is_cleanup,
+    ));
+    patcher.mutate_terminator(body, pred, |kind| {
+        kind.successors_mut(|t| {
+            if *t == succ {
+                *t = split_bb;
+            }
+        });
+    });
+    split_edges.insert((pred, succ), split_bb);
+    split_bb
+}
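`get_or_split_edge` is a memoized get-or-create over `(pred, succ)` pairs, so each critical edge is split at most once no matter how many locals need a `StorageLive` on it. The shape of that caching, sketched standalone with `u32` block ids:

```rust
use std::collections::HashMap;

fn split_edge(
    splits: &mut HashMap<(u32, u32), u32>,
    next_block: &mut u32,
    pred: u32,
    succ: u32,
) -> u32 {
    *splits.entry((pred, succ)).or_insert_with(|| {
        // In the real code this allocates a fresh block that just jumps to
        // `succ` and rewrites `pred`'s terminator to target it.
        let bb = *next_block;
        *next_block += 1;
        bb
    })
}

fn main() {
    let (mut splits, mut next) = (HashMap::new(), 5);
    let first = split_edge(&mut splits, &mut next, 0, 3);
    let second = split_edge(&mut splits, &mut next, 0, 3);
    assert_eq!(first, second); // the second request reuses the split block
}
```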
+/// Don't insert storage statements in cleanup blocks and in unreachable blocks.
+fn should_insert_storage<'tcx>(block_data: &BasicBlockData<'tcx>) -> bool {
+    !block_data.is_cleanup && !matches!(block_data.terminator().kind, TerminatorKind::Unreachable)
+}
+
+/// Reconstructs storage statements for all locals.
+fn reconstruct_storage<'tcx>(
+    body: &mut Body<'tcx>,
+    points: &DenseLocationMap,
+    liveness_matrix: &SparseIntervalMatrix<Local, SplitPointIndex>,
+) {
+    let mut patcher = MirPatch::new(body);
+    let mut split_edges: FxHashMap<(BasicBlock, BasicBlock), BasicBlock> = Default::default();
+
+    for local in body.local_decls.indices() {
+        // Arguments and return values don't use storage statements.
+        match body.local_kind(local) {
+            LocalKind::Arg | LocalKind::ReturnPointer => continue,
+            LocalKind::Temp => {}
+        }
+
+        // Ignore dead locals.
+        let Some(row) = liveness_matrix.row(local) else { continue };
+        if row.is_empty() {
+            continue;
+        }
+
+        // Helper functions to emit storage statements in block predecessors and
+        // successors.
+        let mut emit_storage_live_in_preds =
+            |body: &mut Body<'tcx>,
+             patcher: &mut MirPatch<'tcx>,
+             local: Local,
+             block: BasicBlock| {
+                for &pred in &body.basic_blocks.predecessors()[block].clone() {
+                    // If the local is live at any point in the predecessor's
+                    // terminator then no StorageLive is needed.
+                    let term_early =
+                        SplitPointIndex::new(points.terminator(pred), SplitPointEffect::Early);
+                    let term_late =
+                        SplitPointIndex::new(points.terminator(pred), SplitPointEffect::Late);
+                    if !row.intersects_range(term_early..=term_late) {
+                        // The local must be live on at least one predecessor,
+                        // so if this is the only one then there is nothing to
+                        // do.
+                        debug_assert!(body.basic_blocks.predecessors()[block].len() > 1);
+
+                        // If the predecessor block has multiple successors then
+                        // we need to split the critical edge before inserting
+                        // StorageLive, otherwise the local would end up live on
+                        // paths where it is supposed to be dead.
+                        let loc = if body.basic_blocks[pred].terminator().successors().count() > 1 {
+                            get_or_split_edge(patcher, body, &mut split_edges, pred, block)
+                                .start_location()
+                        } else {
+                            body.terminator_loc(pred)
+                        };
+                        patcher.add_statement(loc, StatementKind::StorageLive(local));
+                    }
+                }
+            };
+        let emit_storage_dead_in_succs =
+            |body: &mut Body<'tcx>,
+             patcher: &mut MirPatch<'tcx>,
+             local: Local,
+             block: BasicBlock| {
+                for succ in body.basic_blocks[block].terminator().successors() {
+                    if !should_insert_storage(&body.basic_blocks[succ]) {
+                        return;
+                    }
+
+                    if !row.contains(SplitPointIndex::new(
+                        points.entry_point(succ),
+                        SplitPointEffect::Early,
+                    )) {
+                        // We don't care about critical edges here: if the local
+                        // is already dead in the successor then it doesn't
+                        // matter if we emit a redundant StorageDead.
+
+                        patcher.add_statement(
+                            succ.start_location(),
+                            StatementKind::StorageDead(local),
+                        );
+                    }
+                }
+            };
+
+        // Iterate through the live range of the local and insert `StorageLive`
+        // and `StorageDead` at the points where it transitions from dead to
+        // live and vice versa.
+        //
+        // Note that the range here is an *inclusive range*.
+        for range in row.iter_intervals() {
+            let start_block = points.to_location(range.start.point()).block;
+            let end_block = points.to_location(range.last.point()).block;
+
+            // If the live range starts at the `Early` point then it means that
+            // the value came from a predecessor block. A write from the first
+            // statement would happen at the `Late` point instead.
+            if should_insert_storage(&body.basic_blocks[start_block]) {
+                if range.start
+                    == SplitPointIndex::new(
+                        points.entry_point(start_block),
+                        SplitPointEffect::Early,
+                    )
+                {
+                    // If the local is dead at the end of any predecessor block then
+                    // emit a `StorageLive` before the terminator.
+                    emit_storage_live_in_preds(body, &mut patcher, local, start_block);
+                } else {
+                    // Otherwise just add `StorageLive` before the statement that
+                    // starts the live range.
+                    patcher.add_statement(
+                        points.to_location(range.start.point()),
+                        StatementKind::StorageLive(local),
+                    );
+                }
+            }
+
+            // The live range may span multiple blocks because
+            // `SparseIntervalMatrix` will coalesce adjacent ranges. If this
+            // happens then we need to repeat the start of block logic (see
+            // above) and end of block logic (see below) at each block boundary.
+            let mut current_block = start_block;
+            debug_assert!(start_block <= end_block);
+            while current_block != end_block {
+                if should_insert_storage(&body.basic_blocks[current_block]) {
+                    emit_storage_dead_in_succs(body, &mut patcher, local, current_block);
+                }
+                current_block = BasicBlock::from_usize(current_block.index() + 1);
+                if should_insert_storage(&body.basic_blocks[current_block]) {
+                    emit_storage_live_in_preds(body, &mut patcher, local, current_block);
+                }
+            }
+
+            // We need to insert `StorageDead` after the last statement that
+            // uses a local. If this is a terminator then we need to instead
+            // insert it at the start of every successor block where the local
+            // is dead on entry.
+            if should_insert_storage(&body.basic_blocks[end_block]) {
+                if range.last.point() == points.terminator(end_block) {
+                    emit_storage_dead_in_succs(body, &mut patcher, local, current_block);
+                } else {
+                    // Don't emit StorageDead in cleanup blocks.
+                    if !body.basic_blocks[end_block].is_cleanup {
+                        patcher.add_statement(
+                            points.to_location(range.last.point()).successor_within_block(),
+                            StatementKind::StorageDead(local),
+                        );
+                    }
+                }
+            }
+        }
+    }
+
+    patcher.apply(body);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Aliasing assignment fixup
+//
+// MIR assignments currently do not allow source and destination to alias, so
+// fix this in post-processing.
+
+fn apply_alias_fixup<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+    let mut patcher = MirPatch::new(body);
+    let mut fixup = AliasFixup { tcx, local_decls: &body.local_decls, patcher: &mut patcher };
+    for (block, data) in body.basic_blocks.as_mut_preserves_cfg().iter_enumerated_mut() {
+        fixup.visit_basic_block_data(block, data);
+    }
+    patcher.apply(body);
+}
+
+fn places_alias<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    local_decls: &IndexVec<Local, LocalDecl<'tcx>>,
+    a: Place<'tcx>,
+    b: Place<'tcx>,
+) -> bool {
+    // Indirect places don't overlap because we assume they didn't overlap in
+    // the input MIR.
+    if a.local != b.local || a.is_indirect_first_projection() || b.is_indirect_first_projection() {
+        return false;
+    }
+
+    for ((prefix, elem_a), (_, elem_b)) in a.iter_projections().zip(b.iter_projections()) {
+        // Continue until we find the first mismatching projection.
+        if elem_a == elem_b {
+            continue;
+        }
+
+        match (elem_a, elem_b) {
+            // Disjoint fields don't alias except if they are union fields.
+            (PlaceElem::Field(_, _), PlaceElem::Field(_, _)) => {
+                let ty = prefix.ty(local_decls, tcx).ty;
+                return ty.is_union();
+            }
+
+            // Disjoint slice elements don't alias.
+            (
+                PlaceElem::ConstantIndex { offset: offset_a, min_length: _, from_end: from_end_a },
+                PlaceElem::ConstantIndex { offset: offset_b, min_length: _, from_end: from_end_b },
+            ) if from_end_a == from_end_b && offset_a != offset_b => {
+                return false;
+            }
+
+            // Conservatively assume the places may alias.
+            _ => return true,
+        }
+    }
+
+    // If the projections are identical *or* one is a prefix of the other then
+    // the places alias.
+ true +} + +struct AliasFixup<'a, 'tcx> { + tcx: TyCtxt<'tcx>, + local_decls: &'a IndexVec>, + patcher: &'a mut MirPatch<'tcx>, +} + +impl<'tcx> AliasFixup<'_, 'tcx> { + fn isolate_rvalue_to_local( + &mut self, + rvalue: Rvalue<'tcx>, + source_info: SourceInfo, + location: Location, + ) -> Place<'tcx> { + let ty = rvalue.ty(self.local_decls, self.tcx); + let temp = Place::from(self.patcher.new_temp(ty, source_info.span)); + trace!("isolating {rvalue:?} to {temp:?} due to conflict"); + self.patcher.add_statement(location, StatementKind::StorageLive(temp.local)); + self.patcher.add_assign(location, Place::from(temp), rvalue); + self.patcher.add_statement( + location.successor_within_block(), + StatementKind::StorageDead(temp.local), + ); + temp + } + + fn visit_aggregate_assign( + &mut self, + dest: Place<'tcx>, + enum_variant: Option, + project_field: impl Fn(TyCtxt<'tcx>, Place<'tcx>, FieldIdx, Ty<'tcx>) -> Place<'tcx>, + operands: &IndexVec>, + source_info: SourceInfo, + location: Location, + ) { + // Fast path: if no operand alias the destination, we're done. + let has_any_alias = operands.iter().any(|op| match op { + Operand::Copy(src) | Operand::Move(src) => { + places_alias(self.tcx, self.local_decls, dest, *src) + } + Operand::Constant(_) | Operand::RuntimeChecks(_) => false, + }); + if !has_any_alias { + return; + } + + debug!("splitting aggregate assignment at {location:?}"); + + // Split into per-field assignments. + let mut assignments = vec![]; + for (idx, op) in operands.iter_enumerated() { + let field_ty = op.ty(self.local_decls, self.tcx); + let dest_field = project_field(self.tcx, dest, idx, field_ty); + + let emit_op = match op { + Operand::Copy(src) | Operand::Move(src) => { + if *src == dest_field { + // Skip identity assignments. + continue; + } else if places_alias(self.tcx, self.local_decls, dest, *src) { + // Partial alias: hoist the source to a temp first so + // the per-field write no longer overlaps the dest. + Operand::Move(self.isolate_rvalue_to_local( + Rvalue::Use(op.clone(), WithRetag::No), + source_info, + location, + )) + } else { + op.clone() + } + } + Operand::Constant(_) | Operand::RuntimeChecks(_) => op.clone(), + }; + assignments.push((dest_field, emit_op)); + } + + // Perform assignments *after* all aliasing fields have been read into + // temporary locals. + for (dest_field, emit_op) in assignments { + self.patcher.add_assign(location, dest_field, Rvalue::Use(emit_op, WithRetag::No)); + } + + // Delete the original aggregate assignment. + self.patcher.nop_statement(location); + + // For enum variants, set the discriminant after all field writes. + if let Some(variant_index) = enum_variant { + self.patcher.add_statement( + location, + StatementKind::SetDiscriminant { place: Box::new(dest), variant_index }, + ); + } + } +} + +impl<'tcx> MutVisitor<'tcx> for AliasFixup<'_, 'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) { + // Fixup the MIR to remove aliasing assignments. 
+struct AliasFixup<'a, 'tcx> {
+    tcx: TyCtxt<'tcx>,
+    local_decls: &'a IndexVec<Local, LocalDecl<'tcx>>,
+    patcher: &'a mut MirPatch<'tcx>,
+}
+
+impl<'tcx> AliasFixup<'_, 'tcx> {
+    fn isolate_rvalue_to_local(
+        &mut self,
+        rvalue: Rvalue<'tcx>,
+        source_info: SourceInfo,
+        location: Location,
+    ) -> Place<'tcx> {
+        let ty = rvalue.ty(self.local_decls, self.tcx);
+        let temp = Place::from(self.patcher.new_temp(ty, source_info.span));
+        trace!("isolating {rvalue:?} to {temp:?} due to conflict");
+        self.patcher.add_statement(location, StatementKind::StorageLive(temp.local));
+        self.patcher.add_assign(location, Place::from(temp), rvalue);
+        self.patcher.add_statement(
+            location.successor_within_block(),
+            StatementKind::StorageDead(temp.local),
+        );
+        temp
+    }
+
+    fn visit_aggregate_assign(
+        &mut self,
+        dest: Place<'tcx>,
+        enum_variant: Option<VariantIdx>,
+        project_field: impl Fn(TyCtxt<'tcx>, Place<'tcx>, FieldIdx, Ty<'tcx>) -> Place<'tcx>,
+        operands: &IndexVec<FieldIdx, Operand<'tcx>>,
+        source_info: SourceInfo,
+        location: Location,
+    ) {
+        // Fast path: if no operand aliases the destination, we're done.
+        let has_any_alias = operands.iter().any(|op| match op {
+            Operand::Copy(src) | Operand::Move(src) => {
+                places_alias(self.tcx, self.local_decls, dest, *src)
+            }
+            Operand::Constant(_) | Operand::RuntimeChecks(_) => false,
+        });
+        if !has_any_alias {
+            return;
+        }
+
+        debug!("splitting aggregate assignment at {location:?}");
+
+        // Split into per-field assignments.
+        let mut assignments = vec![];
+        for (idx, op) in operands.iter_enumerated() {
+            let field_ty = op.ty(self.local_decls, self.tcx);
+            let dest_field = project_field(self.tcx, dest, idx, field_ty);
+
+            let emit_op = match op {
+                Operand::Copy(src) | Operand::Move(src) => {
+                    if *src == dest_field {
+                        // Skip identity assignments.
+                        continue;
+                    } else if places_alias(self.tcx, self.local_decls, dest, *src) {
+                        // Partial alias: hoist the source to a temp first so
+                        // the per-field write no longer overlaps the dest.
+                        Operand::Move(self.isolate_rvalue_to_local(
+                            Rvalue::Use(op.clone(), WithRetag::No),
+                            source_info,
+                            location,
+                        ))
+                    } else {
+                        op.clone()
+                    }
+                }
+                Operand::Constant(_) | Operand::RuntimeChecks(_) => op.clone(),
+            };
+            assignments.push((dest_field, emit_op));
+        }
+
+        // Perform assignments *after* all aliasing fields have been read into
+        // temporary locals.
+        for (dest_field, emit_op) in assignments {
+            self.patcher.add_assign(location, dest_field, Rvalue::Use(emit_op, WithRetag::No));
+        }
+
+        // Delete the original aggregate assignment.
+        self.patcher.nop_statement(location);
+
+        // For enum variants, set the discriminant after all field writes.
+        if let Some(variant_index) = enum_variant {
+            self.patcher.add_statement(
+                location,
+                StatementKind::SetDiscriminant { place: Box::new(dest), variant_index },
+            );
+        }
+    }
+}
+
+impl<'tcx> MutVisitor<'tcx> for AliasFixup<'_, 'tcx> {
+    fn tcx(&self) -> TyCtxt<'tcx> {
+        self.tcx
+    }
+
+    fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
+        // Fixup the MIR to remove aliasing assignments.
+        if let StatementKind::Assign(box (dest, rvalue)) = &mut statement.kind {
+            match *rvalue {
+                Rvalue::Use(Operand::Copy(src) | Operand::Move(src), with_retag) => {
+                    if places_alias(self.tcx, self.local_decls, *dest, src) {
+                        if src == *dest {
+                            debug!("{:?} turned into self-assignment, deleting", location);
+                            statement.make_nop(true);
+                        } else {
+                            let temp = self.isolate_rvalue_to_local(
+                                rvalue.clone(),
+                                statement.source_info,
+                                location,
+                            );
+                            *rvalue = Rvalue::Use(Operand::Move(temp), with_retag);
+                        }
+                    }
+                }
+                Rvalue::Aggregate(box AggregateKind::Array(_), ref mut operands) => self
+                    .visit_aggregate_assign(
+                        *dest,
+                        None,
+                        |tcx, place, field_idx, _field_ty| {
+                            place.project_deeper(
+                                &[PlaceElem::ConstantIndex {
+                                    offset: field_idx.as_u32().into(),
+                                    min_length: field_idx.as_u32() as u64 + 1,
+                                    from_end: false,
+                                }],
+                                tcx,
+                            )
+                        },
+                        operands,
+                        statement.source_info,
+                        location,
+                    ),
+                Rvalue::Aggregate(box AggregateKind::Tuple, ref mut operands) => self
+                    .visit_aggregate_assign(
+                        *dest,
+                        None,
+                        |tcx, place, field_idx, field_ty| {
+                            place.project_deeper(&[PlaceElem::Field(field_idx, field_ty)], tcx)
+                        },
+                        operands,
+                        statement.source_info,
+                        location,
+                    ),
+                Rvalue::Aggregate(
+                    box AggregateKind::Adt(_, _, _, _, Some(union_field_idx)),
+                    ref mut operands,
+                ) => {
+                    debug_assert_eq!(operands.len(), 1);
+                    self.visit_aggregate_assign(
+                        *dest,
+                        None,
+                        |tcx, place, _, field_ty| {
+                            place
+                                .project_deeper(&[PlaceElem::Field(union_field_idx, field_ty)], tcx)
+                        },
+                        operands,
+                        statement.source_info,
+                        location,
+                    )
+                }
+                Rvalue::Aggregate(
+                    box AggregateKind::Adt(adt_did, var_idx, _, _, None),
+                    ref mut operands,
+                ) => {
+                    let def = self.tcx.adt_def(adt_did);
+                    if def.repr().simd() {
+                        // MCP#838 banned projections into SIMD types.
+                        return;
+                    }
+                    self.visit_aggregate_assign(
+                        *dest,
+                        def.is_enum().then_some(var_idx),
+                        |tcx, place, field_idx, field_ty| {
+                            if def.is_enum() {
+                                place.project_deeper(
+                                    &[
+                                        PlaceElem::Downcast(None, var_idx),
+                                        PlaceElem::Field(field_idx, field_ty),
+                                    ],
+                                    tcx,
+                                )
+                            } else {
+                                place.project_deeper(&[PlaceElem::Field(field_idx, field_ty)], tcx)
+                            }
+                        },
+                        operands,
+                        statement.source_info,
+                        location,
+                    )
+                }
+
+                // For other rvalues, don't try to split them into components
+                // and instead just introduce a temporary if there is any
+                // aliasing.
+                Rvalue::Aggregate(..)
+                | Rvalue::Repeat(..)
+                | Rvalue::Cast(..)
+                | Rvalue::CopyForDeref(..)
+                | Rvalue::WrapUnsafeBinder(..) => {
+                    let mut overlaps_dest = false;
+                    VisitPlacesWith(|place, _ctxt| {
+                        if places_alias(self.tcx, self.local_decls, *dest, place) {
+                            overlaps_dest = true;
+                        }
+                    })
+                    .visit_rvalue(rvalue, location);
+                    if overlaps_dest {
+                        let temp = self.isolate_rvalue_to_local(
+                            rvalue.clone(),
+                            statement.source_info,
+                            location,
+                        );
+                        *rvalue = Rvalue::Use(Operand::Move(temp), WithRetag::No);
+                    }
+                }
+
+                // These either cannot have aliasing, or permit it because
+                // they only operate on scalar backend types.
+                Rvalue::Use(Operand::Constant(..) | Operand::RuntimeChecks(..), _)
+                | Rvalue::Ref(..)
+                | Rvalue::ThreadLocalRef(..)
+                | Rvalue::BinaryOp(..)
+                | Rvalue::UnaryOp(..)
+                | Rvalue::Discriminant(..)
+                | Rvalue::RawPtr(..) => {}
+            }
+        }
+    }
+}
diff --git a/compiler/rustc_mir_transform/src/patch.rs b/compiler/rustc_mir_transform/src/patch.rs
index 015bae56cf57e..b8d25ab01ae14 100644
--- a/compiler/rustc_mir_transform/src/patch.rs
+++ b/compiler/rustc_mir_transform/src/patch.rs
@@ -215,6 +215,21 @@ impl<'tcx> MirPatch<'tcx> {
         self.term_patch_map.insert(block, new);
     }

+    /// Modifies the terminator of a block, reusing the pending patch if one exists or
+    /// cloning from the body otherwise.
+    pub(crate) fn mutate_terminator(
+        &mut self,
+        body: &Body<'tcx>,
+        bb: BasicBlock,
+        f: impl FnOnce(&mut TerminatorKind<'tcx>),
+    ) {
+        let kind = self
+            .term_patch_map
+            .entry(bb)
+            .or_insert_with(|| body.basic_blocks[bb].terminator().kind.clone());
+        f(kind);
+    }
+
     /// Mark given statement to be replaced by a `Nop`.
     ///
     /// This method only works on statements from the initial body, and cannot be used to remove
diff --git a/compiler/rustc_public/src/mir/body.rs b/compiler/rustc_public/src/mir/body.rs
index 6aeed20b1f481..43f4148a64afd 100644
--- a/compiler/rustc_public/src/mir/body.rs
+++ b/compiler/rustc_public/src/mir/body.rs
@@ -499,9 +499,6 @@ pub enum Rvalue {
     /// This is needed because dataflow analysis needs to distinguish
     /// `dest = Foo { x: ..., y: ... }` from `dest.x = ...; dest.y = ...;` in the case that `Foo`
     /// has a destructor.
-    ///
-    /// Disallowed after deaggregation for all aggregate kinds except `Array` and `Coroutine`. After
-    /// coroutine lowering, `Coroutine` aggregate kinds are disallowed too.
     Aggregate(AggregateKind, Vec<Operand>),

     /// * `Offset` has the same semantics as `<*const T>::offset`, except that the second
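Closing note on the `MirPatch` addition: `mutate_terminator` is the classic edit-a-cached-copy pattern, cloning from the body only on first touch so that successive edits to the same block compose. A standalone sketch of the pattern, with strings standing in for terminators:

```rust
use std::collections::HashMap;

fn mutate(
    cache: &mut HashMap<usize, String>,
    source: &[String],
    block: usize,
    f: impl FnOnce(&mut String),
) {
    // Reuse the pending patch if one exists, otherwise clone from the body.
    let kind = cache.entry(block).or_insert_with(|| source[block].clone());
    f(kind);
}

fn main() {
    let source = vec!["goto bb1".to_string()];
    let mut cache = HashMap::new();
    mutate(&mut cache, &source, 0, |t| *t = t.replace("bb1", "bb2"));
    mutate(&mut cache, &source, 0, |t| t.push_str(" // patched"));
    assert_eq!(cache[&0], "goto bb2 // patched"); // edits compose
}
```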