Skip to content

Commit 85d9002

Browse files
Couple of random coroutine pass simplifications
1 parent 2fe50cd commit 85d9002

File tree

1 file changed

+15
-29
lines changed

1 file changed

+15
-29
lines changed

compiler/rustc_mir_transform/src/coroutine.rs

Lines changed: 15 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1417,20 +1417,18 @@ fn create_coroutine_resume_function<'tcx>(
14171417
cases.insert(0, (UNRESUMED, START_BLOCK));
14181418

14191419
// Panic when resumed on the returned or poisoned state
1420-
let coroutine_kind = body.coroutine_kind().unwrap();
1421-
14221420
if can_unwind {
14231421
cases.insert(
14241422
1,
1425-
(POISONED, insert_panic_block(tcx, body, ResumedAfterPanic(coroutine_kind))),
1423+
(POISONED, insert_panic_block(tcx, body, ResumedAfterPanic(transform.coroutine_kind))),
14261424
);
14271425
}
14281426

14291427
if can_return {
1430-
let block = match coroutine_kind {
1428+
let block = match transform.coroutine_kind {
14311429
CoroutineKind::Desugared(CoroutineDesugaring::Async, _)
14321430
| CoroutineKind::Coroutine(_) => {
1433-
insert_panic_block(tcx, body, ResumedAfterReturn(coroutine_kind))
1431+
insert_panic_block(tcx, body, ResumedAfterReturn(transform.coroutine_kind))
14341432
}
14351433
CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _)
14361434
| CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
@@ -1444,7 +1442,7 @@ fn create_coroutine_resume_function<'tcx>(
14441442

14451443
make_coroutine_state_argument_indirect(tcx, body);
14461444

1447-
match coroutine_kind {
1445+
match transform.coroutine_kind {
14481446
// Iterator::next doesn't accept a pinned argument,
14491447
// unlike for all other coroutine kinds.
14501448
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {}
@@ -1597,6 +1595,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
15971595

15981596
// The first argument is the coroutine type passed by value
15991597
let coroutine_ty = body.local_decls.raw[1].ty;
1598+
let coroutine_kind = body.coroutine_kind().unwrap();
16001599

16011600
// Get the discriminant type and args which typeck computed
16021601
let (discr_ty, movable) = match *coroutine_ty.kind() {
@@ -1613,19 +1612,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
16131612
}
16141613
};
16151614

1616-
let is_async_kind = matches!(
1617-
body.coroutine_kind(),
1618-
Some(CoroutineKind::Desugared(CoroutineDesugaring::Async, _))
1619-
);
1620-
let is_async_gen_kind = matches!(
1621-
body.coroutine_kind(),
1622-
Some(CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _))
1623-
);
1624-
let is_gen_kind = matches!(
1625-
body.coroutine_kind(),
1626-
Some(CoroutineKind::Desugared(CoroutineDesugaring::Gen, _))
1627-
);
1628-
let new_ret_ty = match body.coroutine_kind().unwrap() {
1615+
let new_ret_ty = match coroutine_kind {
16291616
CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
16301617
// Compute Poll<return_ty>
16311618
let poll_did = tcx.require_lang_item(LangItem::Poll, None);
@@ -1658,7 +1645,10 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
16581645
let old_ret_local = replace_local(RETURN_PLACE, new_ret_ty, body, tcx);
16591646

16601647
// Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
1661-
if is_async_kind || is_async_gen_kind {
1648+
if matches!(
1649+
coroutine_kind,
1650+
CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _)
1651+
) {
16621652
transform_async_context(tcx, body);
16631653
}
16641654

@@ -1667,11 +1657,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
16671657
// case there is no `Assign` to it that the transform can turn into a store to the coroutine
16681658
// state. After the yield the slot in the coroutine state would then be uninitialized.
16691659
let resume_local = Local::new(2);
1670-
let resume_ty = if is_async_kind {
1671-
Ty::new_task_context(tcx)
1672-
} else {
1673-
body.local_decls[resume_local].ty
1674-
};
1660+
let resume_ty = body.local_decls[resume_local].ty;
16751661
let old_resume_local = replace_local(resume_local, resume_ty, body, tcx);
16761662

16771663
// When first entering the coroutine, move the resume argument into its old local
@@ -1714,11 +1700,11 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
17141700
// Run the transformation which converts Places from Local to coroutine struct
17151701
// accesses for locals in `remap`.
17161702
// It also rewrites `return x` and `yield y` as writing a new coroutine state and returning
1717-
// either CoroutineState::Complete(x) and CoroutineState::Yielded(y),
1718-
// or Poll::Ready(x) and Poll::Pending respectively depending on `is_async_kind`.
1703+
// either `CoroutineState::Complete(x)` and `CoroutineState::Yielded(y)`,
1704+
// or `Poll::Ready(x)` and `Poll::Pending` respectively depending on the coroutine kind.
17191705
let mut transform = TransformVisitor {
17201706
tcx,
1721-
coroutine_kind: body.coroutine_kind().unwrap(),
1707+
coroutine_kind,
17221708
remap,
17231709
storage_liveness,
17241710
always_live_locals,
@@ -1735,7 +1721,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
17351721
body.spread_arg = None;
17361722

17371723
// Remove the context argument within generator bodies.
1738-
if is_gen_kind {
1724+
if matches!(coroutine_kind, CoroutineKind::Desugared(CoroutineDesugaring::Gen, _)) {
17391725
transform_gen_context(tcx, body);
17401726
}
17411727

0 commit comments

Comments
 (0)