Skip to content

Commit 724d63a

Browse files
Fix capture analysis for by-move closure bodies
1 parent dd5e502 commit 724d63a

File tree

5 files changed

+239
-16
lines changed

5 files changed

+239
-16
lines changed

compiler/rustc_mir_transform/src/coroutine/by_move_body.rs

+50-16
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
//! be a coroutine body that takes all of its upvars by-move, and which we stash
44
//! into the `CoroutineInfo` for all coroutines returned by coroutine-closures.
55
6+
use itertools::Itertools;
7+
68
use rustc_data_structures::fx::FxIndexSet;
79
use rustc_hir as hir;
810
use rustc_middle::mir::visit::MutVisitor;
@@ -26,36 +28,68 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
2628
if coroutine_ty.references_error() {
2729
return;
2830
}
31+
2932
let ty::Coroutine(_, args) = *coroutine_ty.kind() else { bug!("{body:#?}") };
33+
let args = args.as_coroutine();
3034

31-
let coroutine_kind = args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap();
35+
let coroutine_kind = args.kind_ty().to_opt_closure_kind().unwrap();
3236
if coroutine_kind == ty::ClosureKind::FnOnce {
3337
return;
3438
}
3539

36-
let mut by_ref_fields = FxIndexSet::default();
37-
let by_move_upvars = Ty::new_tup_from_iter(
38-
tcx,
39-
tcx.closure_captures(coroutine_def_id).iter().enumerate().map(|(idx, capture)| {
40-
if capture.is_by_ref() {
41-
by_ref_fields.insert(FieldIdx::from_usize(idx));
42-
}
43-
capture.place.ty()
44-
}),
40+
let parent_def_id = tcx.local_parent(coroutine_def_id);
41+
let ty::CoroutineClosure(_, parent_args) =
42+
*tcx.type_of(parent_def_id).instantiate_identity().kind()
43+
else {
44+
bug!();
45+
};
46+
let parent_args = parent_args.as_coroutine_closure();
47+
let parent_upvars_ty = parent_args.tupled_upvars_ty();
48+
let tupled_inputs_ty = tcx.instantiate_bound_regions_with_erased(
49+
parent_args.coroutine_closure_sig().map_bound(|sig| sig.tupled_inputs_ty),
4550
);
51+
let num_args = tupled_inputs_ty.tuple_fields().len();
52+
53+
let mut by_ref_fields = FxIndexSet::default();
54+
for (idx, (coroutine_capture, parent_capture)) in tcx
55+
.closure_captures(coroutine_def_id)
56+
.iter()
57+
// By construction we capture all the args first.
58+
.skip(num_args)
59+
.zip_eq(tcx.closure_captures(parent_def_id))
60+
.enumerate()
61+
{
62+
// This argument is captured by-move from the parent closure, but by-ref
63+
// from the inner async block. That means that it's being borrowed from
64+
// the closure body -- we need to change the coroutine take it by move.
65+
if coroutine_capture.is_by_ref() && !parent_capture.is_by_ref() {
66+
by_ref_fields.insert(FieldIdx::from_usize(num_args + idx));
67+
}
68+
69+
// Make sure we're actually talking about the same capture.
70+
assert_eq!(coroutine_capture.place.ty(), parent_capture.place.ty());
71+
}
72+
4673
let by_move_coroutine_ty = Ty::new_coroutine(
4774
tcx,
4875
coroutine_def_id.to_def_id(),
4976
ty::CoroutineArgs::new(
5077
tcx,
5178
ty::CoroutineArgsParts {
52-
parent_args: args.as_coroutine().parent_args(),
79+
parent_args: args.parent_args(),
5380
kind_ty: Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce),
54-
resume_ty: args.as_coroutine().resume_ty(),
55-
yield_ty: args.as_coroutine().yield_ty(),
56-
return_ty: args.as_coroutine().return_ty(),
57-
witness: args.as_coroutine().witness(),
58-
tupled_upvars_ty: by_move_upvars,
81+
resume_ty: args.resume_ty(),
82+
yield_ty: args.yield_ty(),
83+
return_ty: args.return_ty(),
84+
witness: args.witness(),
85+
// Concatenate the args + closure's captures (since they're all by move).
86+
tupled_upvars_ty: Ty::new_tup_from_iter(
87+
tcx,
88+
tupled_inputs_ty
89+
.tuple_fields()
90+
.iter()
91+
.chain(parent_upvars_ty.tuple_fields()),
92+
),
5993
},
6094
)
6195
.args,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
#![feature(async_closure, noop_waker)]
2+
3+
use std::future::Future;
4+
use std::pin::pin;
5+
use std::task::*;
6+
7+
pub fn block_on<T>(fut: impl Future<Output = T>) -> T {
8+
let mut fut = pin!(fut);
9+
let ctx = &mut Context::from_waker(Waker::noop());
10+
11+
loop {
12+
match fut.as_mut().poll(ctx) {
13+
Poll::Pending => {}
14+
Poll::Ready(t) => break t,
15+
}
16+
}
17+
}
18+
19+
fn main() {
20+
block_on(async_main());
21+
}
22+
23+
async fn call<T>(f: &impl async Fn() -> T) -> T {
24+
f().await
25+
}
26+
27+
async fn call_once<T>(f: impl async FnOnce() -> T) -> T {
28+
f().await
29+
}
30+
31+
#[derive(Debug)]
32+
#[allow(unused)]
33+
struct Hello(i32);
34+
35+
async fn async_main() {
36+
// Capture something by-ref
37+
{
38+
let x = Hello(0);
39+
let c = async || {
40+
println!("{x:?}");
41+
};
42+
call(&c).await;
43+
call_once(c).await;
44+
45+
let x = &Hello(1);
46+
let c = async || {
47+
println!("{x:?}");
48+
};
49+
call(&c).await;
50+
call_once(c).await;
51+
}
52+
53+
// Capture something and consume it (force to `AsyncFnOnce`)
54+
{
55+
let x = Hello(2);
56+
let c = async || {
57+
println!("{x:?}");
58+
drop(x);
59+
};
60+
call_once(c).await;
61+
}
62+
63+
// Capture something with `move`, don't consume it
64+
{
65+
let x = Hello(3);
66+
let c = async move || {
67+
println!("{x:?}");
68+
};
69+
call(&c).await;
70+
call_once(c).await;
71+
72+
let x = &Hello(4);
73+
let c = async move || {
74+
println!("{x:?}");
75+
};
76+
call(&c).await;
77+
call_once(c).await;
78+
}
79+
80+
// Capture something with `move`, also consume it (so `AsyncFnOnce`)
81+
{
82+
let x = Hello(5);
83+
let c = async move || {
84+
println!("{x:?}");
85+
drop(x);
86+
};
87+
call_once(c).await;
88+
}
89+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
Hello(0)
2+
Hello(0)
3+
Hello(1)
4+
Hello(1)
5+
Hello(2)
6+
Hello(3)
7+
Hello(3)
8+
Hello(4)
9+
Hello(4)
10+
Hello(5)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
//@ aux-build:block-on.rs
2+
//@ edition:2021
3+
//@ run-pass
4+
//@ check-run-results
5+
6+
#![feature(async_closure)]
7+
8+
extern crate block_on;
9+
10+
fn main() {
11+
block_on::block_on(async_main());
12+
}
13+
14+
async fn call<T>(f: &impl async Fn() -> T) -> T {
15+
f().await
16+
}
17+
18+
async fn call_once<T>(f: impl async FnOnce() -> T) -> T {
19+
f().await
20+
}
21+
22+
#[derive(Debug)]
23+
#[allow(unused)]
24+
struct Hello(i32);
25+
26+
async fn async_main() {
27+
// Capture something by-ref
28+
{
29+
let x = Hello(0);
30+
let c = async || {
31+
println!("{x:?}");
32+
};
33+
call(&c).await;
34+
call_once(c).await;
35+
36+
let x = &Hello(1);
37+
let c = async || {
38+
println!("{x:?}");
39+
};
40+
call(&c).await;
41+
call_once(c).await;
42+
}
43+
44+
// Capture something and consume it (force to `AsyncFnOnce`)
45+
{
46+
let x = Hello(2);
47+
let c = async || {
48+
println!("{x:?}");
49+
drop(x);
50+
};
51+
call_once(c).await;
52+
}
53+
54+
// Capture something with `move`, don't consume it
55+
{
56+
let x = Hello(3);
57+
let c = async move || {
58+
println!("{x:?}");
59+
};
60+
call(&c).await;
61+
call_once(c).await;
62+
63+
let x = &Hello(4);
64+
let c = async move || {
65+
println!("{x:?}");
66+
};
67+
call(&c).await;
68+
call_once(c).await;
69+
}
70+
71+
// Capture something with `move`, also consume it (so `AsyncFnOnce`)
72+
{
73+
let x = Hello(5);
74+
let c = async move || {
75+
println!("{x:?}");
76+
drop(x);
77+
};
78+
call_once(c).await;
79+
}
80+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
Hello(0)
2+
Hello(0)
3+
Hello(1)
4+
Hello(1)
5+
Hello(2)
6+
Hello(3)
7+
Hello(3)
8+
Hello(4)
9+
Hello(4)
10+
Hello(5)

0 commit comments

Comments
 (0)