Skip to content

Commit 0b7448f

Browse files
committed
Basic generators work
1 parent ba499c7 commit 0b7448f

File tree

4 files changed

+98
-34
lines changed

4 files changed

+98
-34
lines changed

compiler/rustc_middle/src/mir/terminator.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -142,13 +142,14 @@ impl<O> AssertKind<O> {
142142
ResumedAfterReturn(CoroutineKind::Coroutine) => "coroutine resumed after completion",
143143
ResumedAfterReturn(CoroutineKind::Async(_)) => "`async fn` resumed after completion",
144144
ResumedAfterReturn(CoroutineKind::Gen(_)) => {
145-
bug!("`gen fn` should just keep returning `None` after the first time")
145+
"`gen fn` should just keep returning `None` after completion"
146146
}
147147
ResumedAfterPanic(CoroutineKind::Coroutine) => "coroutine resumed after panicking",
148148
ResumedAfterPanic(CoroutineKind::Async(_)) => "`async fn` resumed after panicking",
149149
ResumedAfterPanic(CoroutineKind::Gen(_)) => {
150-
bug!("`gen fn` should just keep returning `None` after panicking")
150+
"`gen fn` should just keep returning `None` after panicking"
151151
}
152+
152153
BoundsCheck { .. } | MisalignedPointerDereference { .. } => {
153154
bug!("Unexpected AssertKind")
154155
}

compiler/rustc_mir_transform/src/coroutine.rs

+64-32
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ struct SuspensionPoint<'tcx> {
224224

225225
struct TransformVisitor<'tcx> {
226226
tcx: TyCtxt<'tcx>,
227-
is_async_kind: bool,
227+
coroutine_kind: hir::CoroutineKind,
228228
state_adt_ref: AdtDef<'tcx>,
229229
state_args: GenericArgsRef<'tcx>,
230230

@@ -261,31 +261,53 @@ impl<'tcx> TransformVisitor<'tcx> {
261261
is_return: bool,
262262
statements: &mut Vec<Statement<'tcx>>,
263263
) {
264-
let idx = VariantIdx::new(match (is_return, self.is_async_kind) {
265-
(true, false) => 1, // CoroutineState::Complete
266-
(false, false) => 0, // CoroutineState::Yielded
267-
(true, true) => 0, // Poll::Ready
268-
(false, true) => 1, // Poll::Pending
264+
let idx = VariantIdx::new(match (is_return, self.coroutine_kind) {
265+
(true, hir::CoroutineKind::Coroutine) => 1, // CoroutineState::Complete
266+
(false, hir::CoroutineKind::Coroutine) => 0, // CoroutineState::Yielded
267+
(true, hir::CoroutineKind::Async(_)) => 0, // Poll::Ready
268+
(false, hir::CoroutineKind::Async(_)) => 1, // Poll::Pending
269+
(true, hir::CoroutineKind::Gen(_)) => 0, // Option::None
270+
(false, hir::CoroutineKind::Gen(_)) => 1, // Option::Some
269271
});
270272

271273
let kind = AggregateKind::Adt(self.state_adt_ref.did(), idx, self.state_args, None, None);
272274

273-
// `Poll::Pending`
274-
if self.is_async_kind && idx == VariantIdx::new(1) {
275-
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
275+
match self.coroutine_kind {
276+
// `Poll::Pending`
277+
CoroutineKind::Async(_) => {
278+
if !is_return {
279+
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
276280

277-
// FIXME(swatinem): assert that `val` is indeed unit?
278-
statements.push(Statement {
279-
kind: StatementKind::Assign(Box::new((
280-
Place::return_place(),
281-
Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
282-
))),
283-
source_info,
284-
});
285-
return;
281+
// FIXME(swatinem): assert that `val` is indeed unit?
282+
statements.push(Statement {
283+
kind: StatementKind::Assign(Box::new((
284+
Place::return_place(),
285+
Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
286+
))),
287+
source_info,
288+
});
289+
return;
290+
}
291+
}
292+
// `Option::None`
293+
CoroutineKind::Gen(_) => {
294+
if is_return {
295+
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
296+
297+
statements.push(Statement {
298+
kind: StatementKind::Assign(Box::new((
299+
Place::return_place(),
300+
Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
301+
))),
302+
source_info,
303+
});
304+
return;
305+
}
306+
}
307+
CoroutineKind::Coroutine => {}
286308
}
287309

288-
// else: `Poll::Ready(x)`, `CoroutineState::Yielded(x)` or `CoroutineState::Complete(x)`
310+
// else: `Poll::Ready(x)`, `CoroutineState::Yielded(x)`, `CoroutineState::Complete(x)`, or `Option::Some(x)`
289311
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 1);
290312

291313
statements.push(Statement {
@@ -1439,18 +1461,28 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
14391461
};
14401462

14411463
let is_async_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Async(_)));
1442-
let (state_adt_ref, state_args) = if is_async_kind {
1443-
// Compute Poll<return_ty>
1444-
let poll_did = tcx.require_lang_item(LangItem::Poll, None);
1445-
let poll_adt_ref = tcx.adt_def(poll_did);
1446-
let poll_args = tcx.mk_args(&[body.return_ty().into()]);
1447-
(poll_adt_ref, poll_args)
1448-
} else {
1449-
// Compute CoroutineState<yield_ty, return_ty>
1450-
let state_did = tcx.require_lang_item(LangItem::CoroutineState, None);
1451-
let state_adt_ref = tcx.adt_def(state_did);
1452-
let state_args = tcx.mk_args(&[yield_ty.into(), body.return_ty().into()]);
1453-
(state_adt_ref, state_args)
1464+
let (state_adt_ref, state_args) = match body.coroutine_kind().unwrap() {
1465+
CoroutineKind::Async(_) => {
1466+
// Compute Poll<return_ty>
1467+
let poll_did = tcx.require_lang_item(LangItem::Poll, None);
1468+
let poll_adt_ref = tcx.adt_def(poll_did);
1469+
let poll_args = tcx.mk_args(&[body.return_ty().into()]);
1470+
(poll_adt_ref, poll_args)
1471+
}
1472+
CoroutineKind::Gen(_) => {
1473+
// Compute Option<yield_ty>
1474+
let option_did = tcx.require_lang_item(LangItem::Option, None);
1475+
let option_adt_ref = tcx.adt_def(option_did);
1476+
let option_args = tcx.mk_args(&[body.yield_ty().unwrap().into()]);
1477+
(option_adt_ref, option_args)
1478+
}
1479+
CoroutineKind::Coroutine => {
1480+
// Compute CoroutineState<yield_ty, return_ty>
1481+
let state_did = tcx.require_lang_item(LangItem::CoroutineState, None);
1482+
let state_adt_ref = tcx.adt_def(state_did);
1483+
let state_args = tcx.mk_args(&[yield_ty.into(), body.return_ty().into()]);
1484+
(state_adt_ref, state_args)
1485+
}
14541486
};
14551487
let ret_ty = Ty::new_adt(tcx, state_adt_ref, state_args);
14561488

@@ -1518,7 +1550,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
15181550
// or Poll::Ready(x) and Poll::Pending respectively depending on `is_async_kind`.
15191551
let mut transform = TransformVisitor {
15201552
tcx,
1521-
is_async_kind,
1553+
coroutine_kind: body.coroutine_kind().unwrap(),
15221554
state_adt_ref,
15231555
state_args,
15241556
remap,

compiler/rustc_ty_utils/src/instance.rs

+13
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,19 @@ fn resolve_associated_item<'tcx>(
258258
debug_assert!(tcx.defaultness(trait_item_id).has_value());
259259
Some(Instance::new(trait_item_id, rcvr_args))
260260
}
261+
} else if Some(trait_ref.def_id) == lang_items.iterator_trait() {
262+
let ty::Coroutine(coroutine_def_id, args, _) = *rcvr_args.type_at(0).kind() else {
263+
bug!()
264+
};
265+
if Some(trait_item_id) == tcx.lang_items().next_fn() {
266+
// `Iterator::next` is generated by the compiler.
267+
Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args })
268+
} else {
269+
// All other methods are default methods of the `Iterator` trait.
270+
// (this assumes that `ImplSource::Builtin` is only used for methods on `Iterator`)
271+
debug_assert!(tcx.defaultness(trait_item_id).has_value());
272+
Some(Instance::new(trait_item_id, rcvr_args))
273+
}
261274
} else if Some(trait_ref.def_id) == lang_items.gen_trait() {
262275
let ty::Coroutine(coroutine_def_id, args, _) = *rcvr_args.type_at(0).kind() else {
263276
bug!()
+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// revisions: next old
2+
//compile-flags: --edition 2024 -Zunstable-options
3+
//[next] compile-flags: -Ztrait-solver=next
4+
// run-pass
5+
#![feature(coroutines)]
6+
7+
fn foo() -> impl Iterator<Item = u32> {
8+
gen { yield 42; for x in 3..6 { yield x } }
9+
}
10+
11+
fn main() {
12+
let mut iter = foo();
13+
assert_eq!(iter.next(), Some(42));
14+
assert_eq!(iter.next(), Some(3));
15+
assert_eq!(iter.next(), Some(4));
16+
assert_eq!(iter.next(), Some(5));
17+
assert_eq!(iter.next(), None);
18+
}

0 commit comments

Comments
 (0)