diff --git a/compiler/rustc_errors/src/lib.rs b/compiler/rustc_errors/src/lib.rs
index f5f7618285e1f..75bb0e8e7b43f 100644
--- a/compiler/rustc_errors/src/lib.rs
+++ b/compiler/rustc_errors/src/lib.rs
@@ -589,7 +589,8 @@ struct DiagCtxtInner {
     /// add more information). All stashed diagnostics must be emitted with
     /// `emit_stashed_diagnostics` by the time the `DiagCtxtInner` is dropped,
     /// otherwise an assertion failure will occur.
-    stashed_diagnostics: FxIndexMap<(Span, StashKey), (DiagInner, Option<ErrorGuaranteed>)>,
+    stashed_diagnostics:
+        FxIndexMap<StashKey, FxIndexMap<Span, (DiagInner, Option<ErrorGuaranteed>)>>,
 
     future_breakage_diagnostics: Vec<DiagInner>,
 
@@ -912,8 +913,12 @@ impl<'a> DiagCtxtHandle<'a> {
         // FIXME(Centril, #69537): Consider reintroducing panic on overwriting a stashed diagnostic
         // if/when we have a more robust macro-friendly replacement for `(span, key)` as a key.
         // See the PR for a discussion.
-        let key = (span.with_parent(None), key);
-        self.inner.borrow_mut().stashed_diagnostics.insert(key, (diag, guar));
+        self.inner
+            .borrow_mut()
+            .stashed_diagnostics
+            .entry(key)
+            .or_default()
+            .insert(span.with_parent(None), (diag, guar));
 
         guar
     }
@@ -922,9 +927,10 @@ impl<'a> DiagCtxtHandle<'a> {
     /// and [`StashKey`] as the key. Panics if the found diagnostic is an
     /// error.
     pub fn steal_non_err(self, span: Span, key: StashKey) -> Option<Diag<'a, ()>> {
-        let key = (span.with_parent(None), key);
         // FIXME(#120456) - is `swap_remove` correct?
-        let (diag, guar) = self.inner.borrow_mut().stashed_diagnostics.swap_remove(&key)?;
+        let (diag, guar) = self.inner.borrow_mut().stashed_diagnostics.get_mut(&key).and_then(
+            |stashed_diagnostics| stashed_diagnostics.swap_remove(&span.with_parent(None)),
+        )?;
         assert!(!diag.is_error());
         assert!(guar.is_none());
         Some(Diag::new_diagnostic(self, diag))
@@ -943,9 +949,10 @@ impl<'a> DiagCtxtHandle<'a> {
     where
         F: FnMut(&mut Diag<'_>),
     {
-        let key = (span.with_parent(None), key);
         // FIXME(#120456) - is `swap_remove` correct?
-        let err = self.inner.borrow_mut().stashed_diagnostics.swap_remove(&key);
+        let err = self.inner.borrow_mut().stashed_diagnostics.get_mut(&key).and_then(
+            |stashed_diagnostics| stashed_diagnostics.swap_remove(&span.with_parent(None)),
+        );
         err.map(|(err, guar)| {
             // The use of `::<ErrorGuaranteed>` is safe because level is `Level::Error`.
             assert_eq!(err.level, Error);
@@ -966,9 +973,10 @@ impl<'a> DiagCtxtHandle<'a> {
         key: StashKey,
         new_err: Diag<'_>,
     ) -> ErrorGuaranteed {
-        let key = (span.with_parent(None), key);
         // FIXME(#120456) - is `swap_remove` correct?
-        let old_err = self.inner.borrow_mut().stashed_diagnostics.swap_remove(&key);
+        let old_err = self.inner.borrow_mut().stashed_diagnostics.get_mut(&key).and_then(
+            |stashed_diagnostics| stashed_diagnostics.swap_remove(&span.with_parent(None)),
+        );
         match old_err {
             Some((old_err, guar)) => {
                 assert_eq!(old_err.level, Error);
@@ -983,7 +991,14 @@ impl<'a> DiagCtxtHandle<'a> {
     }
 
     pub fn has_stashed_diagnostic(&self, span: Span, key: StashKey) -> bool {
-        self.inner.borrow().stashed_diagnostics.get(&(span.with_parent(None), key)).is_some()
+        let inner = self.inner.borrow();
+        if let Some(stashed_diagnostics) = inner.stashed_diagnostics.get(&key)
+            && !stashed_diagnostics.is_empty()
+        {
+            stashed_diagnostics.contains_key(&span.with_parent(None))
+        } else {
+            false
+        }
     }
 
     /// Emit all stashed diagnostics.
@@ -997,7 +1012,11 @@ impl<'a> DiagCtxtHandle<'a> {
         let inner = self.inner.borrow();
         inner.err_guars.len()
             + inner.lint_err_guars.len()
-            + inner.stashed_diagnostics.values().filter(|(_diag, guar)| guar.is_some()).count()
+            + inner
+                .stashed_diagnostics
+                .values()
+                .map(|a| a.values().filter(|(_, guar)| guar.is_some()).count())
+                .sum::<usize>()
     }
 
     /// This excludes lint errors and delayed bugs. Unless absolutely
@@ -1486,16 +1505,18 @@ impl DiagCtxtInner {
     fn emit_stashed_diagnostics(&mut self) -> Option<ErrorGuaranteed> {
         let mut guar = None;
         let has_errors = !self.err_guars.is_empty();
-        for (_, (diag, _guar)) in std::mem::take(&mut self.stashed_diagnostics).into_iter() {
-            if !diag.is_error() {
-                // Unless they're forced, don't flush stashed warnings when
-                // there are errors, to avoid causing warning overload. The
-                // stash would've been stolen already if it were important.
-                if !diag.is_force_warn() && has_errors {
-                    continue;
+        for (_, stashed_diagnostics) in std::mem::take(&mut self.stashed_diagnostics).into_iter() {
+            for (_, (diag, _guar)) in stashed_diagnostics {
+                if !diag.is_error() {
+                    // Unless they're forced, don't flush stashed warnings when
+                    // there are errors, to avoid causing warning overload. The
+                    // stash would've been stolen already if it were important.
+                    if !diag.is_force_warn() && has_errors {
+                        continue;
+                    }
                 }
+                guar = guar.or(self.emit_diagnostic(diag, None));
             }
-            guar = guar.or(self.emit_diagnostic(diag, None));
         }
         guar
     }
@@ -1688,6 +1709,7 @@ impl DiagCtxtInner {
             if let Some((_diag, guar)) = self
                 .stashed_diagnostics
                 .values()
+                .flat_map(|stashed_diagnostics| stashed_diagnostics.values())
                 .find(|(diag, guar)| guar.is_some() && diag.is_lint.is_none())
             {
                 *guar
@@ -1700,13 +1722,9 @@ impl DiagCtxtInner {
     fn has_errors(&self) -> Option<ErrorGuaranteed> {
         self.err_guars.get(0).copied().or_else(|| self.lint_err_guars.get(0).copied()).or_else(
             || {
-                if let Some((_diag, guar)) =
-                    self.stashed_diagnostics.values().find(|(_diag, guar)| guar.is_some())
-                {
-                    *guar
-                } else {
-                    None
-                }
+                self.stashed_diagnostics.values().find_map(|stashed_diagnostics| {
+                    stashed_diagnostics.values().find_map(|(_, guar)| *guar)
+                })
             },
         )
     }