Skip to content

Commit e9c9d23

Browse files
Fix drop shim for AsyncFnOnce closure, AsyncFnMut shim for AsyncFn closure
1 parent a48ffb4 commit e9c9d23

35 files changed

+595
-67
lines changed

compiler/rustc_const_eval/src/interpret/terminator.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
542542
| ty::InstanceDef::ReifyShim(..)
543543
| ty::InstanceDef::ClosureOnceShim { .. }
544544
| ty::InstanceDef::ConstructCoroutineInClosureShim { .. }
545-
| ty::InstanceDef::CoroutineByMoveShim { .. }
545+
| ty::InstanceDef::CoroutineKindShim { .. }
546546
| ty::InstanceDef::FnPtrShim(..)
547547
| ty::InstanceDef::DropGlue(..)
548548
| ty::InstanceDef::CloneShim(..)

compiler/rustc_middle/src/mir/mod.rs

+17-1
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ pub struct CoroutineInfo<'tcx> {
263263
/// The body of the coroutine, modified to take its upvars by move rather than by ref.
264264
///
265265
/// This is used by coroutine-closures, which must return a different flavor of coroutine
266-
/// when called using `AsyncFnOnce::call_once`. It is produced by the `ByMoveBody` which
266+
/// when called using `AsyncFnOnce::call_once`. It is produced by the `ByMoveBody` pass which
267267
/// is run right after building the initial MIR, and will only be populated for coroutines
268268
/// which come out of the async closure desugaring.
269269
///
@@ -272,6 +272,13 @@ pub struct CoroutineInfo<'tcx> {
272272
/// using `run_passes`.
273273
pub by_move_body: Option<Body<'tcx>>,
274274

275+
/// The body of the coroutine, modified to take its upvars by mutable ref rather than by
276+
/// immutable ref.
277+
///
278+
/// FIXME(async_closures): This is literally the same body as the parent body. Find a better
279+
/// way to represent the by-mut signature (or cap the closure-kind of the coroutine).
280+
pub by_mut_body: Option<Body<'tcx>>,
281+
275282
/// The layout of a coroutine. This field is populated after the state transform pass.
276283
pub coroutine_layout: Option<CoroutineLayout<'tcx>>,
277284

@@ -292,6 +299,7 @@ impl<'tcx> CoroutineInfo<'tcx> {
292299
yield_ty: Some(yield_ty),
293300
resume_ty: Some(resume_ty),
294301
by_move_body: None,
302+
by_mut_body: None,
295303
coroutine_drop: None,
296304
coroutine_layout: None,
297305
}
@@ -602,6 +610,14 @@ impl<'tcx> Body<'tcx> {
602610
self.coroutine.as_ref().and_then(|coroutine| coroutine.coroutine_drop.as_ref())
603611
}
604612

613+
pub fn coroutine_by_move_body(&self) -> Option<&Body<'tcx>> {
614+
self.coroutine.as_ref()?.by_move_body.as_ref()
615+
}
616+
617+
pub fn coroutine_by_mut_body(&self) -> Option<&Body<'tcx>> {
618+
self.coroutine.as_ref()?.by_mut_body.as_ref()
619+
}
620+
605621
#[inline]
606622
pub fn coroutine_kind(&self) -> Option<CoroutineKind> {
607623
self.coroutine.as_ref().map(|coroutine| coroutine.coroutine_kind)

compiler/rustc_middle/src/mir/mono.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ impl<'tcx> CodegenUnit<'tcx> {
403403
| InstanceDef::Virtual(..)
404404
| InstanceDef::ClosureOnceShim { .. }
405405
| InstanceDef::ConstructCoroutineInClosureShim { .. }
406-
| InstanceDef::CoroutineByMoveShim { .. }
406+
| InstanceDef::CoroutineKindShim { .. }
407407
| InstanceDef::DropGlue(..)
408408
| InstanceDef::CloneShim(..)
409409
| InstanceDef::ThreadLocalShim(..)

compiler/rustc_middle/src/mir/visit.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ macro_rules! make_mir_visitor {
346346
ty::InstanceDef::ThreadLocalShim(_def_id) |
347347
ty::InstanceDef::ClosureOnceShim { call_once: _def_id, track_caller: _ } |
348348
ty::InstanceDef::ConstructCoroutineInClosureShim { coroutine_closure_def_id: _def_id, target_kind: _ } |
349-
ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id: _def_id } |
349+
ty::InstanceDef::CoroutineKindShim { coroutine_def_id: _def_id, target_kind: _ } |
350350
ty::InstanceDef::DropGlue(_def_id, None) => {}
351351

352352
ty::InstanceDef::FnPtrShim(_def_id, ty) |

compiler/rustc_middle/src/ty/instance.rs

+13-13
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,12 @@ pub enum InstanceDef<'tcx> {
102102
},
103103

104104
/// `<[coroutine] as Future>::poll`, but for coroutines produced when `AsyncFnOnce`
105-
/// is called on a coroutine-closure whose closure kind is not `FnOnce`. This
106-
/// will select the body that is produced by the `ByMoveBody` transform, and thus
105+
/// is called on a coroutine-closure whose closure kind greater than `FnOnce`, or
106+
/// similarly for `AsyncFnMut`.
107+
///
108+
/// This will select the body that is produced by the `ByMoveBody` transform, and thus
107109
/// take and use all of its upvars by-move rather than by-ref.
108-
CoroutineByMoveShim { coroutine_def_id: DefId },
110+
CoroutineKindShim { coroutine_def_id: DefId, target_kind: ty::ClosureKind },
109111

110112
/// Compiler-generated accessor for thread locals which returns a reference to the thread local
111113
/// the `DefId` defines. This is used to export thread locals from dylibs on platforms lacking
@@ -192,7 +194,7 @@ impl<'tcx> InstanceDef<'tcx> {
192194
coroutine_closure_def_id: def_id,
193195
target_kind: _,
194196
}
195-
| ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id: def_id }
197+
| ty::InstanceDef::CoroutineKindShim { coroutine_def_id: def_id, target_kind: _ }
196198
| InstanceDef::DropGlue(def_id, _)
197199
| InstanceDef::CloneShim(def_id, _)
198200
| InstanceDef::FnPtrAddrShim(def_id, _) => def_id,
@@ -213,7 +215,7 @@ impl<'tcx> InstanceDef<'tcx> {
213215
| InstanceDef::Intrinsic(..)
214216
| InstanceDef::ClosureOnceShim { .. }
215217
| ty::InstanceDef::ConstructCoroutineInClosureShim { .. }
216-
| ty::InstanceDef::CoroutineByMoveShim { .. }
218+
| ty::InstanceDef::CoroutineKindShim { .. }
217219
| InstanceDef::DropGlue(..)
218220
| InstanceDef::CloneShim(..)
219221
| InstanceDef::FnPtrAddrShim(..) => None,
@@ -310,7 +312,7 @@ impl<'tcx> InstanceDef<'tcx> {
310312
| InstanceDef::DropGlue(_, Some(_)) => false,
311313
InstanceDef::ClosureOnceShim { .. }
312314
| InstanceDef::ConstructCoroutineInClosureShim { .. }
313-
| InstanceDef::CoroutineByMoveShim { .. }
315+
| InstanceDef::CoroutineKindShim { .. }
314316
| InstanceDef::DropGlue(..)
315317
| InstanceDef::Item(_)
316318
| InstanceDef::Intrinsic(..)
@@ -349,7 +351,7 @@ fn fmt_instance(
349351
InstanceDef::FnPtrShim(_, ty) => write!(f, " - shim({ty})"),
350352
InstanceDef::ClosureOnceShim { .. } => write!(f, " - shim"),
351353
InstanceDef::ConstructCoroutineInClosureShim { .. } => write!(f, " - shim"),
352-
InstanceDef::CoroutineByMoveShim { .. } => write!(f, " - shim"),
354+
InstanceDef::CoroutineKindShim { .. } => write!(f, " - shim"),
353355
InstanceDef::DropGlue(_, None) => write!(f, " - shim(None)"),
354356
InstanceDef::DropGlue(_, Some(ty)) => write!(f, " - shim(Some({ty}))"),
355357
InstanceDef::CloneShim(_, ty) => write!(f, " - shim({ty})"),
@@ -651,13 +653,11 @@ impl<'tcx> Instance<'tcx> {
651653
if args.as_coroutine().kind_ty() == id_args.as_coroutine().kind_ty() {
652654
Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args })
653655
} else {
654-
assert_eq!(
655-
args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap(),
656-
ty::ClosureKind::FnOnce,
657-
"FIXME(async_closures): Generate a by-mut body here."
658-
);
659656
Some(Instance {
660-
def: ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id },
657+
def: ty::InstanceDef::CoroutineKindShim {
658+
coroutine_def_id,
659+
target_kind: args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap(),
660+
},
661661
args,
662662
})
663663
}

compiler/rustc_middle/src/ty/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -2355,7 +2355,7 @@ impl<'tcx> TyCtxt<'tcx> {
23552355
| ty::InstanceDef::Virtual(..)
23562356
| ty::InstanceDef::ClosureOnceShim { .. }
23572357
| ty::InstanceDef::ConstructCoroutineInClosureShim { .. }
2358-
| ty::InstanceDef::CoroutineByMoveShim { .. }
2358+
| ty::InstanceDef::CoroutineKindShim { .. }
23592359
| ty::InstanceDef::DropGlue(..)
23602360
| ty::InstanceDef::CloneShim(..)
23612361
| ty::InstanceDef::ThreadLocalShim(..)

compiler/rustc_middle/src/ty/print/mod.rs

+19-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use crate::ty::{self, Ty, TyCtxt};
33

44
use rustc_data_structures::fx::FxHashSet;
55
use rustc_data_structures::sso::SsoHashSet;
6+
use rustc_hir as hir;
67
use rustc_hir::def_id::{CrateNum, DefId, LocalDefId};
78
use rustc_hir::definitions::{DefPathData, DisambiguatedDefPathData};
89

@@ -130,8 +131,24 @@ pub trait Printer<'tcx>: Sized {
130131
parent_args = &args[..generics.parent_count.min(args.len())];
131132

132133
match key.disambiguated_data.data {
133-
// Closures' own generics are only captures, don't print them.
134-
DefPathData::Closure => {}
134+
DefPathData::Closure => {
135+
// FIXME(async_closures): This is somewhat ugly.
136+
// We need to additionally print the `kind` field of a closure if
137+
// it is desugared from a coroutine-closure.
138+
if let Some(hir::CoroutineKind::Desugared(
139+
_,
140+
hir::CoroutineSource::Closure,
141+
)) = self.tcx().coroutine_kind(def_id)
142+
&& args.len() >= parent_args.len() + 1
143+
{
144+
return self.path_generic_args(
145+
|cx| cx.print_def_path(def_id, parent_args),
146+
&args[..parent_args.len() + 1][..1],
147+
);
148+
} else {
149+
// Closures' own generics are only captures, don't print them.
150+
}
151+
}
135152
// This covers both `DefKind::AnonConst` and `DefKind::InlineConst`.
136153
// Anon consts doesn't have their own generics, and inline consts' own
137154
// generics are their inferred types, so don't print them.

compiler/rustc_mir_transform/src/coroutine/by_move_body.rs

+41-4
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
use rustc_data_structures::fx::FxIndexSet;
77
use rustc_hir as hir;
88
use rustc_middle::mir::visit::MutVisitor;
9-
use rustc_middle::mir::{self, MirPass};
9+
use rustc_middle::mir::{self, dump_mir, MirPass};
1010
use rustc_middle::ty::{self, InstanceDef, Ty, TyCtxt};
1111
use rustc_target::abi::FieldIdx;
1212

@@ -24,7 +24,9 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
2424
};
2525
let coroutine_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty;
2626
let ty::Coroutine(_, args) = *coroutine_ty.kind() else { bug!() };
27-
if args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap() == ty::ClosureKind::FnOnce {
27+
28+
let coroutine_kind = args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap();
29+
if coroutine_kind == ty::ClosureKind::FnOnce {
2830
return;
2931
}
3032

@@ -58,14 +60,49 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
5860

5961
let mut by_move_body = body.clone();
6062
MakeByMoveBody { tcx, by_ref_fields, by_move_coroutine_ty }.visit_body(&mut by_move_body);
63+
dump_mir(tcx, false, "coroutine_by_move", &0, &by_move_body, |_, _| Ok(()));
6164
by_move_body.source = mir::MirSource {
62-
instance: InstanceDef::CoroutineByMoveShim {
65+
instance: InstanceDef::CoroutineKindShim {
6366
coroutine_def_id: coroutine_def_id.to_def_id(),
67+
target_kind: ty::ClosureKind::FnOnce,
6468
},
6569
promoted: None,
6670
};
67-
6871
body.coroutine.as_mut().unwrap().by_move_body = Some(by_move_body);
72+
73+
// If this is coming from an `AsyncFn` coroutine-closure, we must also create a by-mut body.
74+
// This is actually just a copy of the by-ref body, but with a different self type.
75+
// FIXME(async_closures): We could probably unify this with the by-ref body somehow.
76+
if coroutine_kind == ty::ClosureKind::Fn {
77+
let by_mut_coroutine_ty = Ty::new_coroutine(
78+
tcx,
79+
coroutine_def_id.to_def_id(),
80+
ty::CoroutineArgs::new(
81+
tcx,
82+
ty::CoroutineArgsParts {
83+
parent_args: args.as_coroutine().parent_args(),
84+
kind_ty: Ty::from_closure_kind(tcx, ty::ClosureKind::FnMut),
85+
resume_ty: args.as_coroutine().resume_ty(),
86+
yield_ty: args.as_coroutine().yield_ty(),
87+
return_ty: args.as_coroutine().return_ty(),
88+
witness: args.as_coroutine().witness(),
89+
tupled_upvars_ty: args.as_coroutine().tupled_upvars_ty(),
90+
},
91+
)
92+
.args,
93+
);
94+
let mut by_mut_body = body.clone();
95+
by_mut_body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty = by_mut_coroutine_ty;
96+
dump_mir(tcx, false, "coroutine_by_mut", &0, &by_mut_body, |_, _| Ok(()));
97+
by_mut_body.source = mir::MirSource {
98+
instance: InstanceDef::CoroutineKindShim {
99+
coroutine_def_id: coroutine_def_id.to_def_id(),
100+
target_kind: ty::ClosureKind::FnMut,
101+
},
102+
promoted: None,
103+
};
104+
body.coroutine.as_mut().unwrap().by_mut_body = Some(by_mut_body);
105+
}
69106
}
70107
}
71108

compiler/rustc_mir_transform/src/inline.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ impl<'tcx> Inliner<'tcx> {
318318
| InstanceDef::FnPtrShim(..)
319319
| InstanceDef::ClosureOnceShim { .. }
320320
| InstanceDef::ConstructCoroutineInClosureShim { .. }
321-
| InstanceDef::CoroutineByMoveShim { .. }
321+
| InstanceDef::CoroutineKindShim { .. }
322322
| InstanceDef::DropGlue(..)
323323
| InstanceDef::CloneShim(..)
324324
| InstanceDef::ThreadLocalShim(..)

compiler/rustc_mir_transform/src/inline/cycle.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ pub(crate) fn mir_callgraph_reachable<'tcx>(
8888
| InstanceDef::FnPtrShim(..)
8989
| InstanceDef::ClosureOnceShim { .. }
9090
| InstanceDef::ConstructCoroutineInClosureShim { .. }
91-
| InstanceDef::CoroutineByMoveShim { .. }
91+
| InstanceDef::CoroutineKindShim { .. }
9292
| InstanceDef::ThreadLocalShim { .. }
9393
| InstanceDef::CloneShim(..) => {}
9494

compiler/rustc_mir_transform/src/pass_manager.rs

+7-4
Original file line numberDiff line numberDiff line change
@@ -190,10 +190,13 @@ fn run_passes_inner<'tcx>(
190190
body.pass_count = 1;
191191
}
192192

193-
if let Some(coroutine) = body.coroutine.as_mut()
194-
&& let Some(by_move_body) = coroutine.by_move_body.as_mut()
195-
{
196-
run_passes_inner(tcx, by_move_body, passes, phase_change, validate_each);
193+
if let Some(coroutine) = body.coroutine.as_mut() {
194+
if let Some(by_move_body) = coroutine.by_move_body.as_mut() {
195+
run_passes_inner(tcx, by_move_body, passes, phase_change, validate_each);
196+
}
197+
if let Some(by_mut_body) = coroutine.by_mut_body.as_mut() {
198+
run_passes_inner(tcx, by_mut_body, passes, phase_change, validate_each);
199+
}
197200
}
198201
}
199202

0 commit comments

Comments
 (0)