@@ -224,7 +224,7 @@ struct SuspensionPoint<'tcx> {
224
224
225
225
struct TransformVisitor < ' tcx > {
226
226
tcx : TyCtxt < ' tcx > ,
227
- is_async_kind : bool ,
227
+ coroutine_kind : hir :: CoroutineKind ,
228
228
state_adt_ref : AdtDef < ' tcx > ,
229
229
state_args : GenericArgsRef < ' tcx > ,
230
230
@@ -261,31 +261,53 @@ impl<'tcx> TransformVisitor<'tcx> {
261
261
is_return : bool ,
262
262
statements : & mut Vec < Statement < ' tcx > > ,
263
263
) {
264
- let idx = VariantIdx :: new ( match ( is_return, self . is_async_kind ) {
265
- ( true , false ) => 1 , // CoroutineState::Complete
266
- ( false , false ) => 0 , // CoroutineState::Yielded
267
- ( true , true ) => 0 , // Poll::Ready
268
- ( false , true ) => 1 , // Poll::Pending
264
+ let idx = VariantIdx :: new ( match ( is_return, self . coroutine_kind ) {
265
+ ( true , hir:: CoroutineKind :: Coroutine ) => 1 , // CoroutineState::Complete
266
+ ( false , hir:: CoroutineKind :: Coroutine ) => 0 , // CoroutineState::Yielded
267
+ ( true , hir:: CoroutineKind :: Async ( _) ) => 0 , // Poll::Ready
268
+ ( false , hir:: CoroutineKind :: Async ( _) ) => 1 , // Poll::Pending
269
+ ( true , hir:: CoroutineKind :: Gen ( _) ) => 0 , // Option::None
270
+ ( false , hir:: CoroutineKind :: Gen ( _) ) => 1 , // Option::Some
269
271
} ) ;
270
272
271
273
let kind = AggregateKind :: Adt ( self . state_adt_ref . did ( ) , idx, self . state_args , None , None ) ;
272
274
273
- // `Poll::Pending`
274
- if self . is_async_kind && idx == VariantIdx :: new ( 1 ) {
275
- assert_eq ! ( self . state_adt_ref. variant( idx) . fields. len( ) , 0 ) ;
275
+ match self . coroutine_kind {
276
+ // `Poll::Pending`
277
+ CoroutineKind :: Async ( _) => {
278
+ if !is_return {
279
+ assert_eq ! ( self . state_adt_ref. variant( idx) . fields. len( ) , 0 ) ;
276
280
277
- // FIXME(swatinem): assert that `val` is indeed unit?
278
- statements. push ( Statement {
279
- kind : StatementKind :: Assign ( Box :: new ( (
280
- Place :: return_place ( ) ,
281
- Rvalue :: Aggregate ( Box :: new ( kind) , IndexVec :: new ( ) ) ,
282
- ) ) ) ,
283
- source_info,
284
- } ) ;
285
- return ;
281
+ // FIXME(swatinem): assert that `val` is indeed unit?
282
+ statements. push ( Statement {
283
+ kind : StatementKind :: Assign ( Box :: new ( (
284
+ Place :: return_place ( ) ,
285
+ Rvalue :: Aggregate ( Box :: new ( kind) , IndexVec :: new ( ) ) ,
286
+ ) ) ) ,
287
+ source_info,
288
+ } ) ;
289
+ return ;
290
+ }
291
+ }
292
+ // `Option::None`
293
+ CoroutineKind :: Gen ( _) => {
294
+ if is_return {
295
+ assert_eq ! ( self . state_adt_ref. variant( idx) . fields. len( ) , 0 ) ;
296
+
297
+ statements. push ( Statement {
298
+ kind : StatementKind :: Assign ( Box :: new ( (
299
+ Place :: return_place ( ) ,
300
+ Rvalue :: Aggregate ( Box :: new ( kind) , IndexVec :: new ( ) ) ,
301
+ ) ) ) ,
302
+ source_info,
303
+ } ) ;
304
+ return ;
305
+ }
306
+ }
307
+ CoroutineKind :: Coroutine => { }
286
308
}
287
309
288
- // else: `Poll::Ready(x)`, `CoroutineState::Yielded(x)` or `CoroutineState::Complete(x)`
310
+ // else: `Poll::Ready(x)`, `CoroutineState::Yielded(x)`, `CoroutineState::Complete(x)`, or `Option::Some (x)`
289
311
assert_eq ! ( self . state_adt_ref. variant( idx) . fields. len( ) , 1 ) ;
290
312
291
313
statements. push ( Statement {
@@ -1439,18 +1461,28 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
1439
1461
} ;
1440
1462
1441
1463
let is_async_kind = matches ! ( body. coroutine_kind( ) , Some ( CoroutineKind :: Async ( _) ) ) ;
1442
- let ( state_adt_ref, state_args) = if is_async_kind {
1443
- // Compute Poll<return_ty>
1444
- let poll_did = tcx. require_lang_item ( LangItem :: Poll , None ) ;
1445
- let poll_adt_ref = tcx. adt_def ( poll_did) ;
1446
- let poll_args = tcx. mk_args ( & [ body. return_ty ( ) . into ( ) ] ) ;
1447
- ( poll_adt_ref, poll_args)
1448
- } else {
1449
- // Compute CoroutineState<yield_ty, return_ty>
1450
- let state_did = tcx. require_lang_item ( LangItem :: CoroutineState , None ) ;
1451
- let state_adt_ref = tcx. adt_def ( state_did) ;
1452
- let state_args = tcx. mk_args ( & [ yield_ty. into ( ) , body. return_ty ( ) . into ( ) ] ) ;
1453
- ( state_adt_ref, state_args)
1464
+ let ( state_adt_ref, state_args) = match body. coroutine_kind ( ) . unwrap ( ) {
1465
+ CoroutineKind :: Async ( _) => {
1466
+ // Compute Poll<return_ty>
1467
+ let poll_did = tcx. require_lang_item ( LangItem :: Poll , None ) ;
1468
+ let poll_adt_ref = tcx. adt_def ( poll_did) ;
1469
+ let poll_args = tcx. mk_args ( & [ body. return_ty ( ) . into ( ) ] ) ;
1470
+ ( poll_adt_ref, poll_args)
1471
+ }
1472
+ CoroutineKind :: Gen ( _) => {
1473
+ // Compute Option<yield_ty>
1474
+ let option_did = tcx. require_lang_item ( LangItem :: Option , None ) ;
1475
+ let option_adt_ref = tcx. adt_def ( option_did) ;
1476
+ let option_args = tcx. mk_args ( & [ body. yield_ty ( ) . unwrap ( ) . into ( ) ] ) ;
1477
+ ( option_adt_ref, option_args)
1478
+ }
1479
+ CoroutineKind :: Coroutine => {
1480
+ // Compute CoroutineState<yield_ty, return_ty>
1481
+ let state_did = tcx. require_lang_item ( LangItem :: CoroutineState , None ) ;
1482
+ let state_adt_ref = tcx. adt_def ( state_did) ;
1483
+ let state_args = tcx. mk_args ( & [ yield_ty. into ( ) , body. return_ty ( ) . into ( ) ] ) ;
1484
+ ( state_adt_ref, state_args)
1485
+ }
1454
1486
} ;
1455
1487
let ret_ty = Ty :: new_adt ( tcx, state_adt_ref, state_args) ;
1456
1488
@@ -1518,7 +1550,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
1518
1550
// or Poll::Ready(x) and Poll::Pending respectively depending on `is_async_kind`.
1519
1551
let mut transform = TransformVisitor {
1520
1552
tcx,
1521
- is_async_kind ,
1553
+ coroutine_kind : body . coroutine_kind ( ) . unwrap ( ) ,
1522
1554
state_adt_ref,
1523
1555
state_args,
1524
1556
remap,
0 commit comments