From c36e8fcc3c128e31eaa643904c9b8b33d9a5c1a6 Mon Sep 17 00:00:00 2001
From: Yotam Ofek <yotam.ofek@gmail.com>
Date: Fri, 11 Apr 2025 14:26:26 +0000
Subject: [PATCH] In `rustc_mir_tranform`, iterate over index newtypes instead
 of ints

---
 compiler/rustc_index_macros/src/newtype.rs    |  7 +++
 compiler/rustc_mir_transform/src/coroutine.rs | 53 ++++++++-----------
 .../src/early_otherwise_branch.rs             |  3 +-
 .../rustc_mir_transform/src/match_branches.rs | 17 +++---
 .../src/multiple_return_terminators.rs        | 14 +++--
 compiler/rustc_mir_transform/src/validate.rs  |  5 +-
 6 files changed, 46 insertions(+), 53 deletions(-)

diff --git a/compiler/rustc_index_macros/src/newtype.rs b/compiler/rustc_index_macros/src/newtype.rs
index f0b58eabbff9a..eedbe630cf2c4 100644
--- a/compiler/rustc_index_macros/src/newtype.rs
+++ b/compiler/rustc_index_macros/src/newtype.rs
@@ -257,6 +257,13 @@ impl Parse for Newtype {
                 }
             }
 
+            impl std::ops::AddAssign<usize> for #name {
+                #[inline]
+                fn add_assign(&mut self, other: usize) {
+                    *self = *self + other;
+                }
+            }
+
             impl rustc_index::Idx for #name {
                 #[inline]
                 fn new(value: usize) -> Self {
diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs
index 04d96f117072f..80c729d66b1ec 100644
--- a/compiler/rustc_mir_transform/src/coroutine.rs
+++ b/compiler/rustc_mir_transform/src/coroutine.rs
@@ -547,7 +547,7 @@ fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
 
     let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, None);
 
-    for bb in START_BLOCK..body.basic_blocks.next_index() {
+    for bb in body.basic_blocks.indices() {
         let bb_data = &body[bb];
         if bb_data.is_cleanup {
             continue;
@@ -556,11 +556,11 @@ fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
         match &bb_data.terminator().kind {
             TerminatorKind::Call { func, .. } => {
                 let func_ty = func.ty(body, tcx);
-                if let ty::FnDef(def_id, _) = *func_ty.kind() {
-                    if def_id == get_context_def_id {
-                        let local = eliminate_get_context_call(&mut body[bb]);
-                        replace_resume_ty_local(tcx, body, local, context_mut_ref);
-                    }
+                if let ty::FnDef(def_id, _) = *func_ty.kind()
+                    && def_id == get_context_def_id
+                {
+                    let local = eliminate_get_context_call(&mut body[bb]);
+                    replace_resume_ty_local(tcx, body, local, context_mut_ref);
                 }
             }
             TerminatorKind::Yield { resume_arg, .. } => {
@@ -1057,7 +1057,7 @@ fn insert_switch<'tcx>(
     let blocks = body.basic_blocks_mut().iter_mut();
 
     for target in blocks.flat_map(|b| b.terminator_mut().successors_mut()) {
-        *target = BasicBlock::new(target.index() + 1);
+        *target += 1;
     }
 }
 
@@ -1209,14 +1209,8 @@ fn can_return<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, typing_env: ty::Typing
     }
 
     // If there's a return terminator the function may return.
-    for block in body.basic_blocks.iter() {
-        if let TerminatorKind::Return = block.terminator().kind {
-            return true;
-        }
-    }
-
+    body.basic_blocks.iter().any(|block| matches!(block.terminator().kind, TerminatorKind::Return))
     // Otherwise the function can't return.
-    false
 }
 
 fn can_unwind<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>) -> bool {
@@ -1293,12 +1287,12 @@ fn create_coroutine_resume_function<'tcx>(
                         kind: TerminatorKind::Goto { target: poison_block },
                     };
                 }
-            } else if !block.is_cleanup {
+            } else if !block.is_cleanup
                 // Any terminators that *can* unwind but don't have an unwind target set are also
                 // pointed at our poisoning block (unless they're part of the cleanup path).
-                if let Some(unwind @ UnwindAction::Continue) = block.terminator_mut().unwind_mut() {
-                    *unwind = UnwindAction::Cleanup(poison_block);
-                }
+                && let Some(unwind @ UnwindAction::Continue) = block.terminator_mut().unwind_mut()
+            {
+                *unwind = UnwindAction::Cleanup(poison_block);
             }
         }
     }
@@ -1340,12 +1334,14 @@ fn create_coroutine_resume_function<'tcx>(
     make_coroutine_state_argument_indirect(tcx, body);
 
     match transform.coroutine_kind {
+        CoroutineKind::Coroutine(_)
+        | CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _) =>
+        {
+            make_coroutine_state_argument_pinned(tcx, body);
+        }
         // Iterator::next doesn't accept a pinned argument,
         // unlike for all other coroutine kinds.
         CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {}
-        _ => {
-            make_coroutine_state_argument_pinned(tcx, body);
-        }
     }
 
     // Make sure we remove dead blocks to remove
@@ -1408,8 +1404,7 @@ fn create_cases<'tcx>(
                 let mut statements = Vec::new();
 
                 // Create StorageLive instructions for locals with live storage
-                for i in 0..(body.local_decls.len()) {
-                    let l = Local::new(i);
+                for l in body.local_decls.indices() {
                     let needs_storage_live = point.storage_liveness.contains(l)
                         && !transform.remap.contains(l)
                         && !transform.always_live_locals.contains(l);
@@ -1535,15 +1530,10 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
         let coroutine_kind = body.coroutine_kind().unwrap();
 
         // Get the discriminant type and args which typeck computed
-        let (discr_ty, movable) = match *coroutine_ty.kind() {
-            ty::Coroutine(_, args) => {
-                let args = args.as_coroutine();
-                (args.discr_ty(tcx), coroutine_kind.movability() == hir::Movability::Movable)
-            }
-            _ => {
-                tcx.dcx().span_bug(body.span, format!("unexpected coroutine type {coroutine_ty}"));
-            }
+        let ty::Coroutine(_, args) = coroutine_ty.kind() else {
+            tcx.dcx().span_bug(body.span, format!("unexpected coroutine type {coroutine_ty}"));
         };
+        let discr_ty = args.as_coroutine().discr_ty(tcx);
 
         let new_ret_ty = match coroutine_kind {
             CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
@@ -1610,6 +1600,7 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
 
         let always_live_locals = always_storage_live_locals(body);
 
+        let movable = coroutine_kind.movability() == hir::Movability::Movable;
         let liveness_info =
             locals_live_across_suspend_points(tcx, body, &always_live_locals, movable);
 
diff --git a/compiler/rustc_mir_transform/src/early_otherwise_branch.rs b/compiler/rustc_mir_transform/src/early_otherwise_branch.rs
index 57f7893be1b8c..d49f5d9f9c385 100644
--- a/compiler/rustc_mir_transform/src/early_otherwise_branch.rs
+++ b/compiler/rustc_mir_transform/src/early_otherwise_branch.rs
@@ -103,9 +103,8 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
         let mut should_cleanup = false;
 
         // Also consider newly generated bbs in the same pass
-        for i in 0..body.basic_blocks.len() {
+        for parent in body.basic_blocks.indices() {
             let bbs = &*body.basic_blocks;
-            let parent = BasicBlock::from_usize(i);
             let Some(opt_data) = evaluate_candidate(tcx, body, parent) else { continue };
 
             trace!("SUCCESS: found optimization possibility to apply: {opt_data:?}");
diff --git a/compiler/rustc_mir_transform/src/match_branches.rs b/compiler/rustc_mir_transform/src/match_branches.rs
index 0d9d0368d3729..5059837328e24 100644
--- a/compiler/rustc_mir_transform/src/match_branches.rs
+++ b/compiler/rustc_mir_transform/src/match_branches.rs
@@ -20,13 +20,11 @@ impl<'tcx> crate::MirPass<'tcx> for MatchBranchSimplification {
     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
         let typing_env = body.typing_env(tcx);
         let mut should_cleanup = false;
-        for i in 0..body.basic_blocks.len() {
-            let bbs = &*body.basic_blocks;
-            let bb_idx = BasicBlock::from_usize(i);
-            match bbs[bb_idx].terminator().kind {
+        for bb_idx in body.basic_blocks.indices() {
+            match &body.basic_blocks[bb_idx].terminator().kind {
                 TerminatorKind::SwitchInt {
-                    discr: ref _discr @ (Operand::Copy(_) | Operand::Move(_)),
-                    ref targets,
+                    discr: Operand::Copy(_) | Operand::Move(_),
+                    targets,
                     ..
                     // We require that the possible target blocks don't contain this block.
                 } if !targets.all_targets().contains(&bb_idx) => {}
@@ -66,9 +64,10 @@ trait SimplifyMatch<'tcx> {
         typing_env: ty::TypingEnv<'tcx>,
     ) -> Option<()> {
         let bbs = &body.basic_blocks;
-        let (discr, targets) = match bbs[switch_bb_idx].terminator().kind {
-            TerminatorKind::SwitchInt { ref discr, ref targets, .. } => (discr, targets),
-            _ => unreachable!(),
+        let TerminatorKind::SwitchInt { discr, targets, .. } =
+            &bbs[switch_bb_idx].terminator().kind
+        else {
+            unreachable!();
         };
 
         let discr_ty = discr.ty(body.local_decls(), tcx);
diff --git a/compiler/rustc_mir_transform/src/multiple_return_terminators.rs b/compiler/rustc_mir_transform/src/multiple_return_terminators.rs
index c63bfdcee8559..f59b849e85c62 100644
--- a/compiler/rustc_mir_transform/src/multiple_return_terminators.rs
+++ b/compiler/rustc_mir_transform/src/multiple_return_terminators.rs
@@ -18,19 +18,17 @@ impl<'tcx> crate::MirPass<'tcx> for MultipleReturnTerminators {
         // find basic blocks with no statement and a return terminator
         let mut bbs_simple_returns = DenseBitSet::new_empty(body.basic_blocks.len());
         let bbs = body.basic_blocks_mut();
-        for idx in bbs.indices() {
-            if bbs[idx].statements.is_empty()
-                && bbs[idx].terminator().kind == TerminatorKind::Return
-            {
+        for (idx, bb) in bbs.iter_enumerated() {
+            if bb.statements.is_empty() && bb.terminator().kind == TerminatorKind::Return {
                 bbs_simple_returns.insert(idx);
             }
         }
 
         for bb in bbs {
-            if let TerminatorKind::Goto { target } = bb.terminator().kind {
-                if bbs_simple_returns.contains(target) {
-                    bb.terminator_mut().kind = TerminatorKind::Return;
-                }
+            if let TerminatorKind::Goto { target } = bb.terminator().kind
+                && bbs_simple_returns.contains(target)
+            {
+                bb.terminator_mut().kind = TerminatorKind::Return;
             }
         }
 
diff --git a/compiler/rustc_mir_transform/src/validate.rs b/compiler/rustc_mir_transform/src/validate.rs
index e7930f0a1e3f6..66fe3ef4141f5 100644
--- a/compiler/rustc_mir_transform/src/validate.rs
+++ b/compiler/rustc_mir_transform/src/validate.rs
@@ -221,12 +221,11 @@ impl<'a, 'tcx> CfgChecker<'a, 'tcx> {
 
         // Check for cycles
         let mut stack = FxHashSet::default();
-        for i in 0..parent.len() {
-            let mut bb = BasicBlock::from_usize(i);
+        for (mut bb, parent) in parent.iter_enumerated_mut() {
             stack.clear();
             stack.insert(bb);
             loop {
-                let Some(parent) = parent[bb].take() else { break };
+                let Some(parent) = parent.take() else { break };
                 let no_cycle = stack.insert(parent);
                 if !no_cycle {
                     self.fail(