Skip to content

Initial UnsafePinned implementation [Part 2: Lowering] #139896

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions compiler/rustc_ast_ir/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@ use rustc_macros::{Decodable_NoContext, Encodable_NoContext, HashStable_NoContex
pub mod visit;

/// The movability of a coroutine / closure literal:
/// whether a coroutine contains self-references, causing it to be `!Unpin`.
/// whether a coroutine contains self-references, causing it to be `![Unsafe]Unpin`.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Copy)]
#[cfg_attr(
feature = "nightly",
derive(Encodable_NoContext, Decodable_NoContext, HashStable_NoContext)
)]
pub enum Movability {
/// May contain self-references, `!Unpin`.
/// May contain self-references, `!Unpin + !UnsafeUnpin`.
Static,
/// Must not contain self-references, `Unpin`.
/// Must not contain self-references, `Unpin + UnsafeUnpin`.
Movable,
}

Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3385,7 +3385,7 @@ impl<'infcx, 'tcx> MirBorrowckCtxt<'_, 'infcx, 'tcx> {
Some(3)
} else if string.starts_with("static") {
// `static` is 6 chars long
// This is used for `!Unpin` coroutines
// This is used for immovable (self-referential) coroutines
Some(6)
} else {
None
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_middle/src/mir/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ pub struct CoroutineSavedTy<'tcx> {
pub source_info: SourceInfo,
/// Whether the local should be ignored for trait bound computations.
pub ignore_for_traits: bool,
/// If this local is borrowed across a suspension point and thus is
/// "wrapped" in `UnsafePinned`. Always false for movable coroutines.
pub pinned: bool,
}

/// The layout of coroutine state.
Expand Down
5 changes: 5 additions & 0 deletions compiler/rustc_middle/src/ty/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,10 @@ impl<'tcx> Interner for TyCtxt<'tcx> {
self.coroutine_hidden_types(def_id)
}

fn coroutine_has_pinned_fields(self, def_id: DefId) -> Option<bool> {
self.coroutine_has_pinned_fields(def_id)
}

fn fn_sig(self, def_id: DefId) -> ty::EarlyBinder<'tcx, ty::PolyFnSig<'tcx>> {
self.fn_sig(def_id)
}
Expand Down Expand Up @@ -734,6 +738,7 @@ bidirectional_lang_item_map! {
TransmuteTrait,
Tuple,
Unpin,
UnsafeUnpin,
Unsize,
// tidy-alphabetical-end
}
Expand Down
7 changes: 7 additions & 0 deletions compiler/rustc_middle/src/ty/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,13 @@ impl<'tcx> TyCtxt<'tcx> {
))
}

/// True if the given coroutine has any pinned fields.
/// `None` if the coroutine is tainted by errors.
pub fn coroutine_has_pinned_fields(self, def_id: DefId) -> Option<bool> {
self.mir_coroutine_witnesses(def_id)
.map(|layout| layout.field_tys.iter().any(|ty| ty.pinned))
}

/// Expands the given impl trait type, stopping if the type is recursive.
#[instrument(skip(self), level = "debug", ret)]
pub fn try_expand_impl_trait_type(
Expand Down
146 changes: 146 additions & 0 deletions compiler/rustc_mir_dataflow/src/impls/coro_pinned_locals.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
use rustc_index::bit_set::DenseBitSet;
use rustc_middle::mir::visit::Visitor;
use rustc_middle::mir::*;
use tracing::debug;

use crate::{Analysis, GenKill};

#[derive(Clone)]
pub struct CoroutinePinnedLocals(pub Local);

impl CoroutinePinnedLocals {
fn transfer_function<'a>(&self, domain: &'a mut DenseBitSet<Local>) -> TransferFunction<'a> {
TransferFunction { local: self.0, trans: domain }
}
}

impl<'tcx> Analysis<'tcx> for CoroutinePinnedLocals {
type Domain = DenseBitSet<Local>;
const NAME: &'static str = "coro_pinned_locals";

fn bottom_value(&self, body: &Body<'tcx>) -> Self::Domain {
// bottom = unborrowed
DenseBitSet::new_empty(body.local_decls().len())
}

fn initialize_start_block(&self, _: &Body<'tcx>, _: &mut Self::Domain) {
// No locals are actively borrowing from other locals on function entry
}

fn apply_primary_statement_effect(
&mut self,
state: &mut Self::Domain,
statement: &Statement<'tcx>,
location: Location,
) {
self.transfer_function(state).visit_statement(statement, location);
}

fn apply_primary_terminator_effect<'mir>(
&mut self,
state: &mut Self::Domain,
terminator: &'mir Terminator<'tcx>,
location: Location,
) -> TerminatorEdges<'mir, 'tcx> {
self.transfer_function(state).visit_terminator(terminator, location);

terminator.edges()
}
}

/// A `Visitor` that defines the transfer function for `CoroutinePinnedLocals`.
pub(super) struct TransferFunction<'a> {
local: Local,
trans: &'a mut DenseBitSet<Local>,
}

impl<'tcx> Visitor<'tcx> for TransferFunction<'_> {
fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
self.super_statement(statement, location);

if let StatementKind::StorageDead(local) = statement.kind {
debug!(for_ = ?self.local, KILL = ?local, ?statement, ?location);
self.trans.kill(local);
}
}

fn visit_assign(
&mut self,
assigned_place: &Place<'tcx>,
rvalue: &Rvalue<'tcx>,
location: Location,
) {
self.super_assign(assigned_place, rvalue, location);

match rvalue {
Rvalue::Ref(_, BorrowKind::Mut { .. } | BorrowKind::Shared, place)
| Rvalue::RawPtr(RawPtrKind::Const | RawPtrKind::Mut, place) => {
if (!place.is_indirect() && place.local == self.local)
|| self.trans.contains(place.local)
{
if assigned_place.is_indirect() {
debug!(for_ = ?self.local, GEN_ptr_indirect = ?assigned_place, borrowed_place = ?place, ?rvalue, ?location);
self.trans.gen_(self.local);
} else {
debug!(for_ = ?self.local, GEN_ptr_direct = ?assigned_place, borrowed_place = ?place, ?rvalue, ?location);
self.trans.gen_(assigned_place.local);
}
}
}

// fake pointers don't count
Rvalue::Ref(_, BorrowKind::Fake(_), _)
| Rvalue::RawPtr(RawPtrKind::FakeForPtrMetadata, _) => {}

Rvalue::Use(..)
| Rvalue::Repeat(..)
| Rvalue::ThreadLocalRef(..)
| Rvalue::Len(..)
| Rvalue::Cast(..)
| Rvalue::BinaryOp(..)
| Rvalue::NullaryOp(..)
| Rvalue::UnaryOp(..)
| Rvalue::Discriminant(..)
| Rvalue::Aggregate(..)
| Rvalue::ShallowInitBox(..)
| Rvalue::CopyForDeref(..)
| Rvalue::WrapUnsafeBinder(..) => {}
}
}

fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
self.super_terminator(terminator, location);

match terminator.kind {
TerminatorKind::Drop { place: dropped_place, .. } => {
// Drop terminators may call custom drop glue (`Drop::drop`), which takes `&mut
// self` as a parameter. In the general case, a drop impl could launder that
// reference into the surrounding environment through a raw pointer, thus creating
// a valid `*mut` pointing to the dropped local. We are not yet willing to declare
// this particular case UB, so we must treat all dropped locals as mutably borrowed
// for now. See discussion on [#61069].
//
// [#61069]: https://github.com/rust-lang/rust/pull/61069
if !dropped_place.is_indirect() && dropped_place.local == self.local {
debug!(for_ = ?self.local, GEN_drop = ?dropped_place, ?terminator, ?location);
self.trans.gen_(self.local);
}
}

TerminatorKind::Goto { .. }
| TerminatorKind::SwitchInt { .. }
| TerminatorKind::UnwindResume
| TerminatorKind::UnwindTerminate(_)
| TerminatorKind::Return
| TerminatorKind::Unreachable
| TerminatorKind::Call { .. }
| TerminatorKind::TailCall { .. }
| TerminatorKind::Assert { .. }
| TerminatorKind::Yield { .. }
| TerminatorKind::CoroutineDrop
| TerminatorKind::FalseEdge { .. }
| TerminatorKind::FalseUnwind { .. }
| TerminatorKind::InlineAsm { .. } => {}
}
}
}
2 changes: 2 additions & 0 deletions compiler/rustc_mir_dataflow/src/impls/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
mod borrowed_locals;
mod coro_pinned_locals;
mod initialized;
mod liveness;
mod storage_liveness;

pub use self::borrowed_locals::{MaybeBorrowedLocals, borrowed_locals};
pub use self::coro_pinned_locals::CoroutinePinnedLocals;
pub use self::initialized::{
EverInitializedPlaces, EverInitializedPlacesDomain, MaybeInitializedPlaces,
MaybeUninitializedPlaces, MaybeUninitializedPlacesDomain,
Expand Down
57 changes: 52 additions & 5 deletions compiler/rustc_mir_transform/src/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ use rustc_middle::ty::{
};
use rustc_middle::{bug, span_bug};
use rustc_mir_dataflow::impls::{
MaybeBorrowedLocals, MaybeLiveLocals, MaybeRequiresStorage, MaybeStorageLive,
always_storage_live_locals,
CoroutinePinnedLocals, MaybeBorrowedLocals, MaybeLiveLocals, MaybeRequiresStorage,
MaybeStorageLive, always_storage_live_locals,
};
use rustc_mir_dataflow::{Analysis, Results, ResultsVisitor};
use rustc_span::def_id::{DefId, LocalDefId};
Expand Down Expand Up @@ -639,6 +639,15 @@ struct LivenessInfo {
/// Parallel vec to the above with SourceInfo for each yield terminator.
source_info_at_suspension_points: Vec<SourceInfo>,

/// Coroutine saved locals that are borrowed across a suspension point.
/// This corresponds to locals that are "wrapped" with `UnsafePinned`.
///
/// Note that movable coroutines do not allow borrowing locals across
/// suspension points and thus will always have this set empty.
///
/// For more information, see [RFC 3467](https://rust-lang.github.io/rfcs/3467-unsafe-pinned.html).
saved_locals_borrowed_across_suspension_points: DenseBitSet<CoroutineSavedLocal>,

/// For every saved local, the set of other saved locals that are
/// storage-live at the same time as this local. We cannot overlap locals in
/// the layout which have conflicting storage.
Expand All @@ -657,6 +666,9 @@ struct LivenessInfo {
/// case none exist, the local is considered to be always live.
/// - a local has to be stored if it is either directly used after the
/// the suspend point, or if it is live and has been previously borrowed.
///
/// We also compute locals which are "pinned" (borrowed across a suspension point).
/// These are "wrapped" in `UnsafePinned` and have their niche opts disabled.
fn locals_live_across_suspend_points<'tcx>(
tcx: TyCtxt<'tcx>,
body: &Body<'tcx>,
Expand Down Expand Up @@ -686,10 +698,12 @@ fn locals_live_across_suspend_points<'tcx>(
let mut liveness =
MaybeLiveLocals.iterate_to_fixpoint(tcx, body, Some("coroutine")).into_results_cursor(body);

let mut pinned_locals_cache = IndexVec::from_fn_n(|_| None, body.local_decls.len());
let mut storage_liveness_map = IndexVec::from_elem(None, &body.basic_blocks);
let mut live_locals_at_suspension_points = Vec::new();
let mut source_info_at_suspension_points = Vec::new();
let mut live_locals_at_any_suspension_point = DenseBitSet::new_empty(body.local_decls.len());
let mut pinned_locals = DenseBitSet::new_empty(body.local_decls.len());

for (block, data) in body.basic_blocks.iter_enumerated() {
if let TerminatorKind::Yield { .. } = data.terminator().kind {
Expand Down Expand Up @@ -729,6 +743,27 @@ fn locals_live_across_suspend_points<'tcx>(

debug!("loc = {:?}, live_locals = {:?}", loc, live_locals);

for live_local in live_locals.iter() {
let pinned_cursor = pinned_locals_cache[live_local].get_or_insert_with(|| {
CoroutinePinnedLocals(live_local)
.iterate_to_fixpoint(tcx, body, None)
.into_results_cursor(body)
});
pinned_cursor.seek_to_block_end(block);
let mut pinned_by = pinned_cursor.get().clone();
pinned_by.intersect(&live_locals);

if !pinned_by.is_empty() {
assert!(
!movable,
"local {live_local:?} of movable coro shouldn't be pinned, yet it is pinned by {pinned_by:?}"
);

debug!("{live_local:?} pinned by {pinned_by:?} in {block:?}");
pinned_locals.insert(live_local);
}
}

// Add the locals live at this suspension point to the set of locals which live across
// any suspension points
live_locals_at_any_suspension_point.union(&live_locals);
Expand All @@ -738,7 +773,8 @@ fn locals_live_across_suspend_points<'tcx>(
}
}

debug!("live_locals_anywhere = {:?}", live_locals_at_any_suspension_point);
debug!(?pinned_locals);
debug!(live_locals_anywhere = ?live_locals_at_any_suspension_point);
let saved_locals = CoroutineSavedLocals(live_locals_at_any_suspension_point);

// Renumber our liveness_map bitsets to include only the locals we are
Expand All @@ -748,6 +784,9 @@ fn locals_live_across_suspend_points<'tcx>(
.map(|live_here| saved_locals.renumber_bitset(live_here))
.collect();

let saved_locals_borrowed_across_suspension_points =
saved_locals.renumber_bitset(&pinned_locals);

let storage_conflicts = compute_storage_conflicts(
body,
&saved_locals,
Expand All @@ -759,6 +798,7 @@ fn locals_live_across_suspend_points<'tcx>(
saved_locals,
live_locals_at_suspension_points,
source_info_at_suspension_points,
saved_locals_borrowed_across_suspension_points,
storage_conflicts,
storage_liveness: storage_liveness_map,
}
Expand Down Expand Up @@ -931,6 +971,7 @@ fn compute_layout<'tcx>(
saved_locals,
live_locals_at_suspension_points,
source_info_at_suspension_points,
saved_locals_borrowed_across_suspension_points,
storage_conflicts,
storage_liveness,
} = liveness;
Expand Down Expand Up @@ -960,8 +1001,14 @@ fn compute_layout<'tcx>(
ClearCrossCrate::Set(box LocalInfo::FakeBorrow) => true,
_ => false,
};
let decl =
CoroutineSavedTy { ty: decl.ty, source_info: decl.source_info, ignore_for_traits };
let pinned = saved_locals_borrowed_across_suspension_points.contains(saved_local);

let decl = CoroutineSavedTy {
ty: decl.ty,
source_info: decl.source_info,
ignore_for_traits,
pinned,
};
debug!(?decl);

tys.push(decl);
Expand Down
Loading
Loading