Skip to content

In rustc_mir_transform, iterate over index newtypes instead of ints #139674

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions compiler/rustc_index_macros/src/newtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,13 @@ impl Parse for Newtype {
}
}

impl std::ops::AddAssign<usize> for #name {
#[inline]
fn add_assign(&mut self, other: usize) {
*self = *self + other;
}
}

impl rustc_index::Idx for #name {
#[inline]
fn new(value: usize) -> Self {
Expand Down
53 changes: 22 additions & 31 deletions compiler/rustc_mir_transform/src/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {

let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, None);

for bb in START_BLOCK..body.basic_blocks.next_index() {
for bb in body.basic_blocks.indices() {
let bb_data = &body[bb];
if bb_data.is_cleanup {
continue;
Expand All @@ -556,11 +556,11 @@ fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
match &bb_data.terminator().kind {
TerminatorKind::Call { func, .. } => {
let func_ty = func.ty(body, tcx);
if let ty::FnDef(def_id, _) = *func_ty.kind() {
if def_id == get_context_def_id {
let local = eliminate_get_context_call(&mut body[bb]);
replace_resume_ty_local(tcx, body, local, context_mut_ref);
}
if let ty::FnDef(def_id, _) = *func_ty.kind()
&& def_id == get_context_def_id
{
let local = eliminate_get_context_call(&mut body[bb]);
replace_resume_ty_local(tcx, body, local, context_mut_ref);
}
}
TerminatorKind::Yield { resume_arg, .. } => {
Expand Down Expand Up @@ -1057,7 +1057,7 @@ fn insert_switch<'tcx>(
let blocks = body.basic_blocks_mut().iter_mut();

for target in blocks.flat_map(|b| b.terminator_mut().successors_mut()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can this be changed into a count?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This loops over a bunch of BasicBlocks and increments each one, so something like *target += blocks.flat_map(..).count() would mean something completely different.
Unless I misunderstood you?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I'm blind, lol. I thought we were incrementing a counter, not incrementing a per-bb index.

*target = BasicBlock::new(target.index() + 1);
*target += 1;
}
}

Expand Down Expand Up @@ -1209,14 +1209,8 @@ fn can_return<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, typing_env: ty::Typing
}

// If there's a return terminator the function may return.
for block in body.basic_blocks.iter() {
if let TerminatorKind::Return = block.terminator().kind {
return true;
}
}

body.basic_blocks.iter().any(|block| matches!(block.terminator().kind, TerminatorKind::Return))
// Otherwise the function can't return.
false
}

fn can_unwind<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>) -> bool {
Expand Down Expand Up @@ -1293,12 +1287,12 @@ fn create_coroutine_resume_function<'tcx>(
kind: TerminatorKind::Goto { target: poison_block },
};
}
} else if !block.is_cleanup {
} else if !block.is_cleanup
// Any terminators that *can* unwind but don't have an unwind target set are also
// pointed at our poisoning block (unless they're part of the cleanup path).
if let Some(unwind @ UnwindAction::Continue) = block.terminator_mut().unwind_mut() {
*unwind = UnwindAction::Cleanup(poison_block);
}
&& let Some(unwind @ UnwindAction::Continue) = block.terminator_mut().unwind_mut()
{
*unwind = UnwindAction::Cleanup(poison_block);
}
}
}
Expand Down Expand Up @@ -1340,12 +1334,14 @@ fn create_coroutine_resume_function<'tcx>(
make_coroutine_state_argument_indirect(tcx, body);

match transform.coroutine_kind {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change this back to a match and make it exhaustive, please.

!matches IMO makes it much easier to forget adding a variant in the future.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Like this? It can be be even more exhaustive (i.e. on Movability and CoroutineSource)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is good enough

CoroutineKind::Coroutine(_)
| CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _) =>
{
make_coroutine_state_argument_pinned(tcx, body);
}
// Iterator::next doesn't accept a pinned argument,
// unlike for all other coroutine kinds.
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {}
_ => {
make_coroutine_state_argument_pinned(tcx, body);
}
}

// Make sure we remove dead blocks to remove
Expand Down Expand Up @@ -1408,8 +1404,7 @@ fn create_cases<'tcx>(
let mut statements = Vec::new();

// Create StorageLive instructions for locals with live storage
for i in 0..(body.local_decls.len()) {
let l = Local::new(i);
for l in body.local_decls.indices() {
let needs_storage_live = point.storage_liveness.contains(l)
&& !transform.remap.contains(l)
&& !transform.always_live_locals.contains(l);
Expand Down Expand Up @@ -1535,15 +1530,10 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
let coroutine_kind = body.coroutine_kind().unwrap();

// Get the discriminant type and args which typeck computed
let (discr_ty, movable) = match *coroutine_ty.kind() {
ty::Coroutine(_, args) => {
let args = args.as_coroutine();
(args.discr_ty(tcx), coroutine_kind.movability() == hir::Movability::Movable)
}
_ => {
tcx.dcx().span_bug(body.span, format!("unexpected coroutine type {coroutine_ty}"));
}
let ty::Coroutine(_, args) = coroutine_ty.kind() else {
tcx.dcx().span_bug(body.span, format!("unexpected coroutine type {coroutine_ty}"));
};
let discr_ty = args.as_coroutine().discr_ty(tcx);

let new_ret_ty = match coroutine_kind {
CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
Expand Down Expand Up @@ -1610,6 +1600,7 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {

let always_live_locals = always_storage_live_locals(body);

let movable = coroutine_kind.movability() == hir::Movability::Movable;
let liveness_info =
locals_live_across_suspend_points(tcx, body, &always_live_locals, movable);

Expand Down
3 changes: 1 addition & 2 deletions compiler/rustc_mir_transform/src/early_otherwise_branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,8 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
let mut should_cleanup = false;

// Also consider newly generated bbs in the same pass
for i in 0..body.basic_blocks.len() {
for parent in body.basic_blocks.indices() {
let bbs = &*body.basic_blocks;
let parent = BasicBlock::from_usize(i);
let Some(opt_data) = evaluate_candidate(tcx, body, parent) else { continue };

trace!("SUCCESS: found optimization possibility to apply: {opt_data:?}");
Expand Down
17 changes: 8 additions & 9 deletions compiler/rustc_mir_transform/src/match_branches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,11 @@ impl<'tcx> crate::MirPass<'tcx> for MatchBranchSimplification {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let typing_env = body.typing_env(tcx);
let mut should_cleanup = false;
for i in 0..body.basic_blocks.len() {
let bbs = &*body.basic_blocks;
let bb_idx = BasicBlock::from_usize(i);
match bbs[bb_idx].terminator().kind {
for bb_idx in body.basic_blocks.indices() {
match &body.basic_blocks[bb_idx].terminator().kind {
TerminatorKind::SwitchInt {
discr: ref _discr @ (Operand::Copy(_) | Operand::Move(_)),
ref targets,
discr: Operand::Copy(_) | Operand::Move(_),
targets,
..
// We require that the possible target blocks don't contain this block.
} if !targets.all_targets().contains(&bb_idx) => {}
Expand Down Expand Up @@ -66,9 +64,10 @@ trait SimplifyMatch<'tcx> {
typing_env: ty::TypingEnv<'tcx>,
) -> Option<()> {
let bbs = &body.basic_blocks;
let (discr, targets) = match bbs[switch_bb_idx].terminator().kind {
TerminatorKind::SwitchInt { ref discr, ref targets, .. } => (discr, targets),
_ => unreachable!(),
let TerminatorKind::SwitchInt { discr, targets, .. } =
&bbs[switch_bb_idx].terminator().kind
else {
unreachable!();
};

let discr_ty = discr.ty(body.local_decls(), tcx);
Expand Down
14 changes: 6 additions & 8 deletions compiler/rustc_mir_transform/src/multiple_return_terminators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,17 @@ impl<'tcx> crate::MirPass<'tcx> for MultipleReturnTerminators {
// find basic blocks with no statement and a return terminator
let mut bbs_simple_returns = DenseBitSet::new_empty(body.basic_blocks.len());
let bbs = body.basic_blocks_mut();
for idx in bbs.indices() {
if bbs[idx].statements.is_empty()
&& bbs[idx].terminator().kind == TerminatorKind::Return
{
for (idx, bb) in bbs.iter_enumerated() {
if bb.statements.is_empty() && bb.terminator().kind == TerminatorKind::Return {
bbs_simple_returns.insert(idx);
}
}

for bb in bbs {
if let TerminatorKind::Goto { target } = bb.terminator().kind {
if bbs_simple_returns.contains(target) {
bb.terminator_mut().kind = TerminatorKind::Return;
}
if let TerminatorKind::Goto { target } = bb.terminator().kind
&& bbs_simple_returns.contains(target)
{
bb.terminator_mut().kind = TerminatorKind::Return;
}
}

Expand Down
5 changes: 2 additions & 3 deletions compiler/rustc_mir_transform/src/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,11 @@ impl<'a, 'tcx> CfgChecker<'a, 'tcx> {

// Check for cycles
let mut stack = FxHashSet::default();
for i in 0..parent.len() {
let mut bb = BasicBlock::from_usize(i);
for (mut bb, parent) in parent.iter_enumerated_mut() {
stack.clear();
stack.insert(bb);
loop {
let Some(parent) = parent[bb].take() else { break };
let Some(parent) = parent.take() else { break };
let no_cycle = stack.insert(parent);
if !no_cycle {
self.fail(
Expand Down
Loading