3
3
//! be a coroutine body that takes all of its upvars by-move, and which we stash
4
4
//! into the `CoroutineInfo` for all coroutines returned by coroutine-closures.
5
5
6
+ use itertools:: Itertools ;
7
+
6
8
use rustc_data_structures:: fx:: FxIndexSet ;
7
9
use rustc_hir as hir;
8
10
use rustc_middle:: mir:: visit:: MutVisitor ;
@@ -26,36 +28,68 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
26
28
if coroutine_ty. references_error ( ) {
27
29
return ;
28
30
}
31
+
29
32
let ty:: Coroutine ( _, args) = * coroutine_ty. kind ( ) else { bug ! ( "{body:#?}" ) } ;
33
+ let args = args. as_coroutine ( ) ;
30
34
31
- let coroutine_kind = args. as_coroutine ( ) . kind_ty ( ) . to_opt_closure_kind ( ) . unwrap ( ) ;
35
+ let coroutine_kind = args. kind_ty ( ) . to_opt_closure_kind ( ) . unwrap ( ) ;
32
36
if coroutine_kind == ty:: ClosureKind :: FnOnce {
33
37
return ;
34
38
}
35
39
36
- let mut by_ref_fields = FxIndexSet :: default ( ) ;
37
- let by_move_upvars = Ty :: new_tup_from_iter (
38
- tcx,
39
- tcx. closure_captures ( coroutine_def_id) . iter ( ) . enumerate ( ) . map ( |( idx, capture) | {
40
- if capture. is_by_ref ( ) {
41
- by_ref_fields. insert ( FieldIdx :: from_usize ( idx) ) ;
42
- }
43
- capture. place . ty ( )
44
- } ) ,
40
+ let parent_def_id = tcx. local_parent ( coroutine_def_id) ;
41
+ let ty:: CoroutineClosure ( _, parent_args) =
42
+ * tcx. type_of ( parent_def_id) . instantiate_identity ( ) . kind ( )
43
+ else {
44
+ bug ! ( ) ;
45
+ } ;
46
+ let parent_args = parent_args. as_coroutine_closure ( ) ;
47
+ let parent_upvars_ty = parent_args. tupled_upvars_ty ( ) ;
48
+ let tupled_inputs_ty = tcx. instantiate_bound_regions_with_erased (
49
+ parent_args. coroutine_closure_sig ( ) . map_bound ( |sig| sig. tupled_inputs_ty ) ,
45
50
) ;
51
+ let num_args = tupled_inputs_ty. tuple_fields ( ) . len ( ) ;
52
+
53
+ let mut by_ref_fields = FxIndexSet :: default ( ) ;
54
+ for ( idx, ( coroutine_capture, parent_capture) ) in tcx
55
+ . closure_captures ( coroutine_def_id)
56
+ . iter ( )
57
+ // By construction we capture all the args first.
58
+ . skip ( num_args)
59
+ . zip_eq ( tcx. closure_captures ( parent_def_id) )
60
+ . enumerate ( )
61
+ {
62
+ // This argument is captured by-move from the parent closure, but by-ref
63
+ // from the inner async block. That means that it's being borrowed from
64
+ // the closure body -- we need to change the coroutine take it by move.
65
+ if coroutine_capture. is_by_ref ( ) && !parent_capture. is_by_ref ( ) {
66
+ by_ref_fields. insert ( FieldIdx :: from_usize ( num_args + idx) ) ;
67
+ }
68
+
69
+ // Make sure we're actually talking about the same capture.
70
+ assert_eq ! ( coroutine_capture. place. ty( ) , parent_capture. place. ty( ) ) ;
71
+ }
72
+
46
73
let by_move_coroutine_ty = Ty :: new_coroutine (
47
74
tcx,
48
75
coroutine_def_id. to_def_id ( ) ,
49
76
ty:: CoroutineArgs :: new (
50
77
tcx,
51
78
ty:: CoroutineArgsParts {
52
- parent_args : args. as_coroutine ( ) . parent_args ( ) ,
79
+ parent_args : args. parent_args ( ) ,
53
80
kind_ty : Ty :: from_closure_kind ( tcx, ty:: ClosureKind :: FnOnce ) ,
54
- resume_ty : args. as_coroutine ( ) . resume_ty ( ) ,
55
- yield_ty : args. as_coroutine ( ) . yield_ty ( ) ,
56
- return_ty : args. as_coroutine ( ) . return_ty ( ) ,
57
- witness : args. as_coroutine ( ) . witness ( ) ,
58
- tupled_upvars_ty : by_move_upvars,
81
+ resume_ty : args. resume_ty ( ) ,
82
+ yield_ty : args. yield_ty ( ) ,
83
+ return_ty : args. return_ty ( ) ,
84
+ witness : args. witness ( ) ,
85
+ // Concatenate the args + closure's captures (since they're all by move).
86
+ tupled_upvars_ty : Ty :: new_tup_from_iter (
87
+ tcx,
88
+ tupled_inputs_ty
89
+ . tuple_fields ( )
90
+ . iter ( )
91
+ . chain ( parent_upvars_ty. tuple_fields ( ) ) ,
92
+ ) ,
59
93
} ,
60
94
)
61
95
. args ,
0 commit comments