From 8684e9e47dcc52cc51dccdf1a74bac69deb38207 Mon Sep 17 00:00:00 2001
From: Camille GILLOT <gillot.camille@gmail.com>
Date: Wed, 18 Nov 2020 16:53:39 +0100
Subject: [PATCH 1/7] Merge {get,ensure}_query.

---
 .../rustc_middle/src/ty/query/plumbing.rs     |  4 +-
 .../rustc_query_system/src/query/plumbing.rs  | 46 +++++++++----------
 2 files changed, 24 insertions(+), 26 deletions(-)

diff --git a/compiler/rustc_middle/src/ty/query/plumbing.rs b/compiler/rustc_middle/src/ty/query/plumbing.rs
index 46addcdaead43..f6370452e80b2 100644
--- a/compiler/rustc_middle/src/ty/query/plumbing.rs
+++ b/compiler/rustc_middle/src/ty/query/plumbing.rs
@@ -401,7 +401,7 @@ macro_rules! define_queries {
             $($(#[$attr])*
             #[inline(always)]
             pub fn $name(self, key: query_helper_param_ty!($($K)*)) {
-                ensure_query::<queries::$name<'_>, _>(self.tcx, key.into_query_param())
+                get_query::<queries::$name<'_>, _>(self.tcx, DUMMY_SP, key.into_query_param(), QueryMode::Ensure);
             })*
         }
 
@@ -484,7 +484,7 @@ macro_rules! define_queries {
             pub fn $name(self, key: query_helper_param_ty!($($K)*))
                 -> <queries::$name<$tcx> as QueryConfig>::Stored
             {
-                get_query::<queries::$name<'_>, _>(self.tcx, self.span, key.into_query_param())
+                get_query::<queries::$name<'_>, _>(self.tcx, self.span, key.into_query_param(), QueryMode::Get).unwrap()
             })*
         }
 
diff --git a/compiler/rustc_query_system/src/query/plumbing.rs b/compiler/rustc_query_system/src/query/plumbing.rs
index cbbb449b4f8ab..f2ebf8d7d3d08 100644
--- a/compiler/rustc_query_system/src/query/plumbing.rs
+++ b/compiler/rustc_query_system/src/query/plumbing.rs
@@ -17,7 +17,6 @@ use rustc_data_structures::sharded::Sharded;
 use rustc_data_structures::sync::{Lock, LockGuard};
 use rustc_data_structures::thin_vec::ThinVec;
 use rustc_errors::{Diagnostic, FatalError};
-use rustc_span::source_map::DUMMY_SP;
 use rustc_span::Span;
 use std::collections::hash_map::Entry;
 use std::fmt::Debug;
@@ -641,31 +640,26 @@ where
 
 /// Ensure that either this query has all green inputs or been executed.
 /// Executing `query::ensure(D)` is considered a read of the dep-node `D`.
+/// Returns `true` if the caller must still run the query, `false` if the dep-node is green.
 ///
 /// This function is particularly useful when executing passes for their
 /// side-effects -- e.g., in order to report errors for erroneous programs.
 ///
 /// Note: The optimization is only available during incr. comp.
 #[inline(never)]
-fn ensure_query_impl<CTX, C>(
-    tcx: CTX,
-    state: &QueryState<CTX::DepKind, CTX::Query, C>,
-    key: C::Key,
-    query: &QueryVtable<CTX, C::Key, C::Value>,
-) where
-    C: QueryCache,
-    C::Key: crate::dep_graph::DepNodeParams<CTX>,
+fn ensure_must_run<CTX, K, V>(tcx: CTX, key: &K, query: &QueryVtable<CTX, K, V>) -> bool
+where
+    K: crate::dep_graph::DepNodeParams<CTX>,
     CTX: QueryContext,
 {
     if query.eval_always {
-        let _ = get_query_impl(tcx, state, DUMMY_SP, key, query);
-        return;
+        return true;
     }
 
     // Ensuring an anonymous query makes no sense
     assert!(!query.anon);
 
-    let dep_node = query.to_dep_node(tcx, &key);
+    let dep_node = query.to_dep_node(tcx, key);
 
     match tcx.dep_graph().try_mark_green_and_read(tcx, &dep_node) {
         None => {
@@ -675,10 +669,11 @@ fn ensure_query_impl<CTX, C>(
             // DepNodeIndex. We must invoke the query itself. The performance cost
             // this introduces should be negligible as we'll immediately hit the
             // in-memory cache, or another query down the line will.
-            let _ = get_query_impl(tcx, state, DUMMY_SP, key, query);
+            true
         }
         Some((_, dep_node_index)) => {
             tcx.profiler().query_cache_hit(dep_node_index.into());
+            false
         }
     }
 }
@@ -720,24 +715,27 @@ fn force_query_impl<CTX, C>(
     );
 }
 
-pub fn get_query<Q, CTX>(tcx: CTX, span: Span, key: Q::Key) -> Q::Stored
-where
-    Q: QueryDescription<CTX>,
-    Q::Key: crate::dep_graph::DepNodeParams<CTX>,
-    CTX: QueryContext,
-{
-    debug!("ty::query::get_query<{}>(key={:?}, span={:?})", Q::NAME, key, span);
-
-    get_query_impl(tcx, Q::query_state(tcx), span, key, &Q::VTABLE)
+pub enum QueryMode {
+    Get,
+    Ensure,
 }
 
-pub fn ensure_query<Q, CTX>(tcx: CTX, key: Q::Key)
+pub fn get_query<Q, CTX>(tcx: CTX, span: Span, key: Q::Key, mode: QueryMode) -> Option<Q::Stored>
 where
     Q: QueryDescription<CTX>,
     Q::Key: crate::dep_graph::DepNodeParams<CTX>,
     CTX: QueryContext,
 {
-    ensure_query_impl(tcx, Q::query_state(tcx), key, &Q::VTABLE)
+    let query = &Q::VTABLE;
+    if let QueryMode::Ensure = mode {
+        if !ensure_must_run(tcx, &key, query) {
+            return None;
+        }
+    }
+
+    debug!("ty::query::get_query<{}>(key={:?}, span={:?})", Q::NAME, key, span);
+    let value = get_query_impl(tcx, Q::query_state(tcx), span, key, query);
+    Some(value)
 }
 
 pub fn force_query<Q, CTX>(tcx: CTX, key: Q::Key, span: Span, dep_node: DepNode<CTX::DepKind>)

From 4b42a6d90b850eb697a56bddb9e3239d7e5c72fb Mon Sep 17 00:00:00 2001
From: Camille GILLOT <gillot.camille@gmail.com>
Date: Sun, 17 Jan 2021 14:57:07 +0100
Subject: [PATCH 2/7] Introduce query_stored module.

---
 .../rustc_middle/src/ty/query/plumbing.rs     | 20 +++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/compiler/rustc_middle/src/ty/query/plumbing.rs b/compiler/rustc_middle/src/ty/query/plumbing.rs
index f6370452e80b2..7a46bad0c1fd7 100644
--- a/compiler/rustc_middle/src/ty/query/plumbing.rs
+++ b/compiler/rustc_middle/src/ty/query/plumbing.rs
@@ -342,14 +342,20 @@ macro_rules! define_queries {
 
             $(pub type $name<$tcx> = $V;)*
         }
+        #[allow(nonstandard_style, unused_lifetimes)]
+        pub mod query_stored {
+            use super::*;
+
+            $(pub type $name<$tcx> = <
+                query_storage!([$($modifiers)*][$($K)*, $V])
+                as QueryStorage
+            >::Stored;)*
+        }
 
         $(impl<$tcx> QueryConfig for queries::$name<$tcx> {
             type Key = $($K)*;
             type Value = $V;
-            type Stored = <
-                query_storage!([$($modifiers)*][$($K)*, $V])
-                as QueryStorage
-            >::Stored;
+            type Stored = query_stored::$name<$tcx>;
             const NAME: &'static str = stringify!($name);
         }
 
@@ -442,8 +448,7 @@ macro_rules! define_queries {
             $($(#[$attr])*
             #[inline(always)]
             #[must_use]
-            pub fn $name(self, key: query_helper_param_ty!($($K)*))
-                -> <queries::$name<$tcx> as QueryConfig>::Stored
+            pub fn $name(self, key: query_helper_param_ty!($($K)*)) -> query_stored::$name<$tcx>
             {
                 self.at(DUMMY_SP).$name(key.into_query_param())
             })*
@@ -481,8 +486,7 @@ macro_rules! define_queries {
         impl TyCtxtAt<$tcx> {
             $($(#[$attr])*
             #[inline(always)]
-            pub fn $name(self, key: query_helper_param_ty!($($K)*))
-                -> <queries::$name<$tcx> as QueryConfig>::Stored
+            pub fn $name(self, key: query_helper_param_ty!($($K)*)) -> query_stored::$name<$tcx>
             {
                 get_query::<queries::$name<'_>, _>(self.tcx, self.span, key.into_query_param(), QueryMode::Get).unwrap()
             })*

From f8ab649dfd8866e35e3281e04534fe024e4095f7 Mon Sep 17 00:00:00 2001
From: Camille GILLOT <gillot.camille@gmail.com>
Date: Tue, 19 Jan 2021 20:02:05 +0100
Subject: [PATCH 3/7] Introduce query_storage.

---
 compiler/rustc_middle/src/ty/query/plumbing.rs | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/compiler/rustc_middle/src/ty/query/plumbing.rs b/compiler/rustc_middle/src/ty/query/plumbing.rs
index 7a46bad0c1fd7..dcfc116585b9e 100644
--- a/compiler/rustc_middle/src/ty/query/plumbing.rs
+++ b/compiler/rustc_middle/src/ty/query/plumbing.rs
@@ -343,13 +343,16 @@ macro_rules! define_queries {
             $(pub type $name<$tcx> = $V;)*
         }
         #[allow(nonstandard_style, unused_lifetimes)]
+        pub mod query_storage {
+            use super::*;
+
+            $(pub type $name<$tcx> = query_storage!([$($modifiers)*][$($K)*, $V]);)*
+        }
+        #[allow(nonstandard_style, unused_lifetimes)]
         pub mod query_stored {
             use super::*;
 
-            $(pub type $name<$tcx> = <
-                query_storage!([$($modifiers)*][$($K)*, $V])
-                as QueryStorage
-            >::Stored;)*
+            $(pub type $name<$tcx> = <query_storage::$name<$tcx> as QueryStorage>::Stored;)*
         }
 
         $(impl<$tcx> QueryConfig for queries::$name<$tcx> {
@@ -364,7 +367,7 @@ macro_rules! define_queries {
             const EVAL_ALWAYS: bool = is_eval_always!([$($modifiers)*]);
             const DEP_KIND: dep_graph::DepKind = dep_graph::DepKind::$name;
 
-            type Cache = query_storage!([$($modifiers)*][$($K)*, $V]);
+            type Cache = query_storage::$name<$tcx>;
 
             #[inline(always)]
             fn query_state<'a>(tcx: TyCtxt<$tcx>) -> &'a QueryState<crate::dep_graph::DepKind, <TyCtxt<$tcx> as QueryContext>::Query, Self::Cache> {
@@ -523,7 +526,7 @@ macro_rules! define_queries_struct {
             $($(#[$attr])*  $name: QueryState<
                 crate::dep_graph::DepKind,
                 <TyCtxt<$tcx> as QueryContext>::Query,
-                <queries::$name<$tcx> as QueryAccessors<TyCtxt<'tcx>>>::Cache,
+                query_storage::$name<$tcx>,
             >,)*
         }
 

From 9f46259a7516f0bc453f9a0edb318be11c3d4a28 Mon Sep 17 00:00:00 2001
From: Camille GILLOT <gillot.camille@gmail.com>
Date: Fri, 23 Oct 2020 22:34:32 +0200
Subject: [PATCH 4/7] Return a Result for query cache.

---
 .../rustc_query_system/src/query/caches.rs    | 56 +++++------
 .../rustc_query_system/src/query/plumbing.rs  | 99 ++++++++-----------
 2 files changed, 68 insertions(+), 87 deletions(-)

diff --git a/compiler/rustc_query_system/src/query/caches.rs b/compiler/rustc_query_system/src/query/caches.rs
index 1d2bc1a99a596..1ec32939d9f8d 100644
--- a/compiler/rustc_query_system/src/query/caches.rs
+++ b/compiler/rustc_query_system/src/query/caches.rs
@@ -31,17 +31,15 @@ pub trait QueryCache: QueryStorage {
     /// It returns the shard index and a lock guard to the shard,
     /// which will be used if the query is not in the cache and we need
     /// to compute it.
-    fn lookup<D, Q, R, OnHit, OnMiss>(
+    fn lookup<'s, D, Q, R, OnHit>(
         &self,
-        state: &QueryState<D, Q, Self>,
-        key: Self::Key,
+        state: &'s QueryState<D, Q, Self>,
+        key: &Self::Key,
         // `on_hit` can be called while holding a lock to the query state shard.
         on_hit: OnHit,
-        on_miss: OnMiss,
-    ) -> R
+    ) -> Result<R, QueryLookup<'s, D, Q, Self::Key, Self::Sharded>>
     where
-        OnHit: FnOnce(&Self::Stored, DepNodeIndex) -> R,
-        OnMiss: FnOnce(Self::Key, QueryLookup<'_, D, Q, Self::Key, Self::Sharded>) -> R;
+        OnHit: FnOnce(&Self::Stored, DepNodeIndex) -> R;
 
     fn complete(
         &self,
@@ -95,23 +93,24 @@ where
     type Sharded = FxHashMap<K, (V, DepNodeIndex)>;
 
     #[inline(always)]
-    fn lookup<D, Q, R, OnHit, OnMiss>(
+    fn lookup<'s, D, Q, R, OnHit>(
         &self,
-        state: &QueryState<D, Q, Self>,
-        key: K,
+        state: &'s QueryState<D, Q, Self>,
+        key: &K,
         on_hit: OnHit,
-        on_miss: OnMiss,
-    ) -> R
+    ) -> Result<R, QueryLookup<'s, D, Q, K, Self::Sharded>>
     where
         OnHit: FnOnce(&V, DepNodeIndex) -> R,
-        OnMiss: FnOnce(K, QueryLookup<'_, D, Q, K, Self::Sharded>) -> R,
     {
-        let mut lookup = state.get_lookup(&key);
-        let lock = &mut *lookup.lock;
+        let lookup = state.get_lookup(key);
+        let result = lookup.lock.cache.raw_entry().from_key_hashed_nocheck(lookup.key_hash, key);
 
-        let result = lock.cache.raw_entry().from_key_hashed_nocheck(lookup.key_hash, &key);
-
-        if let Some((_, value)) = result { on_hit(&value.0, value.1) } else { on_miss(key, lookup) }
+        if let Some((_, value)) = result {
+            let hit_result = on_hit(&value.0, value.1);
+            Ok(hit_result)
+        } else {
+            Err(lookup)
+        }
     }
 
     #[inline]
@@ -177,26 +176,23 @@ where
     type Sharded = FxHashMap<K, &'tcx (V, DepNodeIndex)>;
 
     #[inline(always)]
-    fn lookup<D, Q, R, OnHit, OnMiss>(
+    fn lookup<'s, D, Q, R, OnHit>(
         &self,
-        state: &QueryState<D, Q, Self>,
-        key: K,
+        state: &'s QueryState<D, Q, Self>,
+        key: &K,
         on_hit: OnHit,
-        on_miss: OnMiss,
-    ) -> R
+    ) -> Result<R, QueryLookup<'s, D, Q, K, Self::Sharded>>
     where
         OnHit: FnOnce(&&'tcx V, DepNodeIndex) -> R,
-        OnMiss: FnOnce(K, QueryLookup<'_, D, Q, K, Self::Sharded>) -> R,
     {
-        let mut lookup = state.get_lookup(&key);
-        let lock = &mut *lookup.lock;
-
-        let result = lock.cache.raw_entry().from_key_hashed_nocheck(lookup.key_hash, &key);
+        let lookup = state.get_lookup(key);
+        let result = lookup.lock.cache.raw_entry().from_key_hashed_nocheck(lookup.key_hash, key);
 
         if let Some((_, value)) = result {
-            on_hit(&&value.0, value.1)
+            let hit_result = on_hit(&&value.0, value.1);
+            Ok(hit_result)
         } else {
-            on_miss(key, lookup)
+            Err(lookup)
         }
     }
 
diff --git a/compiler/rustc_query_system/src/query/plumbing.rs b/compiler/rustc_query_system/src/query/plumbing.rs
index f2ebf8d7d3d08..4f93017200f59 100644
--- a/compiler/rustc_query_system/src/query/plumbing.rs
+++ b/compiler/rustc_query_system/src/query/plumbing.rs
@@ -248,13 +248,8 @@ where
                 return TryGetJob::Cycle(value);
             }
 
-            let cached = try_get_cached(
-                tcx,
-                state,
-                (*key).clone(),
-                |value, index| (value.clone(), index),
-                |_, _| panic!("value must be in cache after waiting"),
-            );
+            let cached = try_get_cached(tcx, state, key, |value, index| (value.clone(), index))
+                .unwrap_or_else(|_| panic!("value must be in cache after waiting"));
 
             if let Some(prof_timer) = _query_blocked_prof_timer.take() {
                 prof_timer.finish_with_query_invocation_id(cached.1.into());
@@ -356,35 +351,28 @@ where
 /// It returns the shard index and a lock guard to the shard,
 /// which will be used if the query is not in the cache and we need
 /// to compute it.
-fn try_get_cached<CTX, C, R, OnHit, OnMiss>(
+fn try_get_cached<'a, CTX, C, R, OnHit>(
     tcx: CTX,
-    state: &QueryState<CTX::DepKind, CTX::Query, C>,
-    key: C::Key,
+    state: &'a QueryState<CTX::DepKind, CTX::Query, C>,
+    key: &C::Key,
     // `on_hit` can be called while holding a lock to the query cache
     on_hit: OnHit,
-    on_miss: OnMiss,
-) -> R
+) -> Result<R, QueryLookup<'a, CTX::DepKind, CTX::Query, C::Key, C::Sharded>>
 where
     C: QueryCache,
     CTX: QueryContext,
     OnHit: FnOnce(&C::Stored, DepNodeIndex) -> R,
-    OnMiss: FnOnce(C::Key, QueryLookup<'_, CTX::DepKind, CTX::Query, C::Key, C::Sharded>) -> R,
 {
-    state.cache.lookup(
-        state,
-        key,
-        |value, index| {
-            if unlikely!(tcx.profiler().enabled()) {
-                tcx.profiler().query_cache_hit(index.into());
-            }
-            #[cfg(debug_assertions)]
-            {
-                state.cache_hits.fetch_add(1, Ordering::Relaxed);
-            }
-            on_hit(value, index)
-        },
-        on_miss,
-    )
+    state.cache.lookup(state, &key, |value, index| {
+        if unlikely!(tcx.profiler().enabled()) {
+            tcx.profiler().query_cache_hit(index.into());
+        }
+        #[cfg(debug_assertions)]
+        {
+            state.cache_hits.fetch_add(1, Ordering::Relaxed);
+        }
+        on_hit(value, index)
+    })
 }
 
 fn try_execute_query<CTX, C>(
@@ -626,16 +614,14 @@ where
     C: QueryCache,
     C::Key: crate::dep_graph::DepNodeParams<CTX>,
 {
-    try_get_cached(
-        tcx,
-        state,
-        key,
-        |value, index| {
-            tcx.dep_graph().read_index(index);
-            value.clone()
-        },
-        |key, lookup| try_execute_query(tcx, state, span, key, lookup, query),
-    )
+    let cached = try_get_cached(tcx, state, &key, |value, index| {
+        tcx.dep_graph().read_index(index);
+        value.clone()
+    });
+    match cached {
+        Ok(value) => value,
+        Err(lookup) => try_execute_query(tcx, state, span, key, lookup, query),
+    }
 }
 
 /// Ensure that either this query has all green inputs or been executed.
@@ -694,25 +680,24 @@ fn force_query_impl<CTX, C>(
     // We may be concurrently trying both execute and force a query.
     // Ensure that only one of them runs the query.
 
-    try_get_cached(
-        tcx,
-        state,
-        key,
-        |_, _| {
-            // Cache hit, do nothing
-        },
-        |key, lookup| {
-            let job = match JobOwner::<'_, CTX::DepKind, CTX::Query, C>::try_start(
-                tcx, state, span, &key, lookup, query,
-            ) {
-                TryGetJob::NotYetStarted(job) => job,
-                TryGetJob::Cycle(_) => return,
-                #[cfg(parallel_compiler)]
-                TryGetJob::JobCompleted(_) => return,
-            };
-            force_query_with_job(tcx, key, job, dep_node, query);
-        },
-    );
+    let cached = try_get_cached(tcx, state, &key, |_, _| {
+        // Cache hit, do nothing
+    });
+
+    let lookup = match cached {
+        Ok(()) => return,
+        Err(lookup) => lookup,
+    };
+
+    let job = match JobOwner::<'_, CTX::DepKind, CTX::Query, C>::try_start(
+        tcx, state, span, &key, lookup, query,
+    ) {
+        TryGetJob::NotYetStarted(job) => job,
+        TryGetJob::Cycle(_) => return,
+        #[cfg(parallel_compiler)]
+        TryGetJob::JobCompleted(_) => return,
+    };
+    force_query_with_job(tcx, key, job, dep_node, query);
 }
 
 pub enum QueryMode {

From 15b0bc6b8380942fb45f1839b9fd91e66fad8045 Mon Sep 17 00:00:00 2001
From: Camille GILLOT <gillot.camille@gmail.com>
Date: Sat, 6 Feb 2021 13:49:08 +0100
Subject: [PATCH 5/7] Separate the query cache from the query state.

---
 compiler/rustc_data_structures/src/sharded.rs |  30 +--
 compiler/rustc_middle/src/ty/context.rs       |   2 +
 .../src/ty/query/on_disk_cache.rs             |   7 +-
 .../rustc_middle/src/ty/query/plumbing.rs     |  20 +-
 .../src/ty/query/profiling_support.rs         |   8 +-
 compiler/rustc_middle/src/ty/query/stats.rs   |  11 +-
 .../rustc_query_system/src/query/caches.rs    |  24 +--
 .../rustc_query_system/src/query/config.rs    |   9 +-
 .../rustc_query_system/src/query/plumbing.rs  | 197 ++++++++++--------
 9 files changed, 173 insertions(+), 135 deletions(-)

diff --git a/compiler/rustc_data_structures/src/sharded.rs b/compiler/rustc_data_structures/src/sharded.rs
index 485719c517564..14db71cb8f070 100644
--- a/compiler/rustc_data_structures/src/sharded.rs
+++ b/compiler/rustc_data_structures/src/sharded.rs
@@ -63,23 +63,9 @@ impl<T> Sharded<T> {
         if SHARDS == 1 { &self.shards[0].0 } else { self.get_shard_by_hash(make_hash(val)) }
     }
 
-    /// Get a shard with a pre-computed hash value. If `get_shard_by_value` is
-    /// ever used in combination with `get_shard_by_hash` on a single `Sharded`
-    /// instance, then `hash` must be computed with `FxHasher`. Otherwise,
-    /// `hash` can be computed with any hasher, so long as that hasher is used
-    /// consistently for each `Sharded` instance.
-    #[inline]
-    pub fn get_shard_index_by_hash(&self, hash: u64) -> usize {
-        let hash_len = mem::size_of::<usize>();
-        // Ignore the top 7 bits as hashbrown uses these and get the next SHARD_BITS highest bits.
-        // hashbrown also uses the lowest bits, so we can't use those
-        let bits = (hash >> (hash_len * 8 - 7 - SHARD_BITS)) as usize;
-        bits % SHARDS
-    }
-
     #[inline]
     pub fn get_shard_by_hash(&self, hash: u64) -> &Lock<T> {
-        &self.shards[self.get_shard_index_by_hash(hash)].0
+        &self.shards[get_shard_index_by_hash(hash)].0
     }
 
     #[inline]
@@ -166,3 +152,17 @@ fn make_hash<K: Hash + ?Sized>(val: &K) -> u64 {
     val.hash(&mut state);
     state.finish()
 }
+
+/// Get a shard index from a pre-computed hash value. If `get_shard_by_value` is
+/// ever used in combination with `get_shard_by_hash` on a single `Sharded`
+/// instance, then `hash` must be computed with `FxHasher`. Otherwise,
+/// `hash` can be computed with any hasher, so long as that hasher is used
+/// consistently for each `Sharded` instance.
+#[inline]
+pub fn get_shard_index_by_hash(hash: u64) -> usize {
+    let hash_len = mem::size_of::<usize>();
+    // Ignore the top 7 bits as hashbrown uses these and get the next SHARD_BITS highest bits.
+    // hashbrown also uses the lowest bits, so we can't use those
+    let bits = (hash >> (hash_len * 8 - 7 - SHARD_BITS)) as usize;
+    bits % SHARDS
+}
diff --git a/compiler/rustc_middle/src/ty/context.rs b/compiler/rustc_middle/src/ty/context.rs
index f83056ebe2a45..4654a8424706d 100644
--- a/compiler/rustc_middle/src/ty/context.rs
+++ b/compiler/rustc_middle/src/ty/context.rs
@@ -963,6 +963,7 @@ pub struct GlobalCtxt<'tcx> {
     pub(crate) definitions: &'tcx Definitions,
 
     pub queries: query::Queries<'tcx>,
+    pub query_caches: query::QueryCaches<'tcx>,
 
     maybe_unused_trait_imports: FxHashSet<LocalDefId>,
     maybe_unused_extern_crates: Vec<(LocalDefId, Span)>,
@@ -1154,6 +1155,7 @@ impl<'tcx> TyCtxt<'tcx> {
             untracked_crate: krate,
             definitions,
             queries: query::Queries::new(providers, extern_providers, on_disk_query_result_cache),
+            query_caches: query::QueryCaches::default(),
             ty_rcache: Default::default(),
             pred_rcache: Default::default(),
             selection_cache: Default::default(),
diff --git a/compiler/rustc_middle/src/ty/query/on_disk_cache.rs b/compiler/rustc_middle/src/ty/query/on_disk_cache.rs
index cfe47004e01b6..b41edb5deeb2c 100644
--- a/compiler/rustc_middle/src/ty/query/on_disk_cache.rs
+++ b/compiler/rustc_middle/src/ty/query/on_disk_cache.rs
@@ -1244,10 +1244,9 @@ where
         .prof
         .extra_verbose_generic_activity("encode_query_results_for", std::any::type_name::<Q>());
 
-    let state = Q::query_state(tcx);
-    assert!(state.all_inactive());
-
-    state.iter_results(|results| {
+    assert!(Q::query_state(tcx).all_inactive());
+    let cache = Q::query_cache(tcx);
+    cache.iter_results(|results| {
         for (key, value, dep_node) in results {
             if Q::cache_on_disk(tcx, &key, Some(value)) {
                 let dep_node = SerializedDepNodeIndex::new(dep_node.index());
diff --git a/compiler/rustc_middle/src/ty/query/plumbing.rs b/compiler/rustc_middle/src/ty/query/plumbing.rs
index dcfc116585b9e..9a011846fd62d 100644
--- a/compiler/rustc_middle/src/ty/query/plumbing.rs
+++ b/compiler/rustc_middle/src/ty/query/plumbing.rs
@@ -355,6 +355,11 @@ macro_rules! define_queries {
             $(pub type $name<$tcx> = <query_storage::$name<$tcx> as QueryStorage>::Stored;)*
         }
 
+        #[derive(Default)]
+        pub struct QueryCaches<$tcx> {
+            $($(#[$attr])* $name: QueryCacheStore<query_storage::$name<$tcx>>,)*
+        }
+
         $(impl<$tcx> QueryConfig for queries::$name<$tcx> {
             type Key = $($K)*;
             type Value = $V;
@@ -370,10 +375,17 @@ macro_rules! define_queries {
             type Cache = query_storage::$name<$tcx>;
 
             #[inline(always)]
-            fn query_state<'a>(tcx: TyCtxt<$tcx>) -> &'a QueryState<crate::dep_graph::DepKind, <TyCtxt<$tcx> as QueryContext>::Query, Self::Cache> {
+            fn query_state<'a>(tcx: TyCtxt<$tcx>) -> &'a QueryState<crate::dep_graph::DepKind, Query<$tcx>, Self::Key> {
                 &tcx.queries.$name
             }
 
+            #[inline(always)]
+            fn query_cache<'a>(tcx: TyCtxt<$tcx>) -> &'a QueryCacheStore<Self::Cache>
+                where 'tcx:'a
+            {
+                &tcx.query_caches.$name
+            }
+
             #[inline]
             fn compute(tcx: TyCtxt<'tcx>, key: Self::Key) -> Self::Value {
                 let provider = tcx.queries.providers.get(key.query_crate())
@@ -479,7 +491,7 @@ macro_rules! define_queries {
                     alloc_self_profile_query_strings_for_query_cache(
                         self,
                         stringify!($name),
-                        &self.queries.$name,
+                        &self.query_caches.$name,
                         &mut string_cache,
                     );
                 })*
@@ -525,8 +537,8 @@ macro_rules! define_queries_struct {
 
             $($(#[$attr])*  $name: QueryState<
                 crate::dep_graph::DepKind,
-                <TyCtxt<$tcx> as QueryContext>::Query,
-                query_storage::$name<$tcx>,
+                Query<$tcx>,
+                query_keys::$name<$tcx>,
             >,)*
         }
 
diff --git a/compiler/rustc_middle/src/ty/query/profiling_support.rs b/compiler/rustc_middle/src/ty/query/profiling_support.rs
index cbcecb8849188..9976e7885090c 100644
--- a/compiler/rustc_middle/src/ty/query/profiling_support.rs
+++ b/compiler/rustc_middle/src/ty/query/profiling_support.rs
@@ -5,7 +5,7 @@ use rustc_data_structures::fx::FxHashMap;
 use rustc_data_structures::profiling::SelfProfiler;
 use rustc_hir::def_id::{CrateNum, DefId, DefIndex, LocalDefId, CRATE_DEF_INDEX, LOCAL_CRATE};
 use rustc_hir::definitions::DefPathData;
-use rustc_query_system::query::{QueryCache, QueryContext, QueryState};
+use rustc_query_system::query::{QueryCache, QueryCacheStore};
 use std::fmt::Debug;
 use std::io::Write;
 
@@ -230,7 +230,7 @@ where
 pub(super) fn alloc_self_profile_query_strings_for_query_cache<'tcx, C>(
     tcx: TyCtxt<'tcx>,
     query_name: &'static str,
-    query_state: &QueryState<crate::dep_graph::DepKind, <TyCtxt<'tcx> as QueryContext>::Query, C>,
+    query_cache: &QueryCacheStore<C>,
     string_cache: &mut QueryKeyStringCache,
 ) where
     C: QueryCache,
@@ -251,7 +251,7 @@ pub(super) fn alloc_self_profile_query_strings_for_query_cache<'tcx, C>(
             // need to invoke queries itself, we cannot keep the query caches
             // locked while doing so. Instead we copy out the
             // `(query_key, dep_node_index)` pairs and release the lock again.
-            let query_keys_and_indices: Vec<_> = query_state
+            let query_keys_and_indices: Vec<_> = query_cache
                 .iter_results(|results| results.map(|(k, _, i)| (k.clone(), i)).collect());
 
             // Now actually allocate the strings. If allocating the strings
@@ -276,7 +276,7 @@ pub(super) fn alloc_self_profile_query_strings_for_query_cache<'tcx, C>(
             let query_name = profiler.get_or_alloc_cached_string(query_name);
             let event_id = event_id_builder.from_label(query_name).to_string_id();
 
-            query_state.iter_results(|results| {
+            query_cache.iter_results(|results| {
                 let query_invocation_ids: Vec<_> = results.map(|v| v.2.into()).collect();
 
                 profiler.bulk_map_query_invocation_id_to_single_string(
diff --git a/compiler/rustc_middle/src/ty/query/stats.rs b/compiler/rustc_middle/src/ty/query/stats.rs
index e0b44ce23c912..c885a10f80595 100644
--- a/compiler/rustc_middle/src/ty/query/stats.rs
+++ b/compiler/rustc_middle/src/ty/query/stats.rs
@@ -1,10 +1,9 @@
 use crate::ty::query::queries;
 use crate::ty::TyCtxt;
 use rustc_hir::def_id::{DefId, LOCAL_CRATE};
-use rustc_query_system::query::{QueryAccessors, QueryCache, QueryContext, QueryState};
+use rustc_query_system::query::{QueryAccessors, QueryCache, QueryCacheStore};
 
 use std::any::type_name;
-use std::hash::Hash;
 use std::mem;
 #[cfg(debug_assertions)]
 use std::sync::atomic::Ordering;
@@ -37,10 +36,8 @@ struct QueryStats {
     local_def_id_keys: Option<usize>,
 }
 
-fn stats<D, Q, C>(name: &'static str, map: &QueryState<D, Q, C>) -> QueryStats
+fn stats<C>(name: &'static str, map: &QueryCacheStore<C>) -> QueryStats
 where
-    D: Copy + Clone + Eq + Hash,
-    Q: Clone,
     C: QueryCache,
 {
     let mut stats = QueryStats {
@@ -128,12 +125,10 @@ macro_rules! print_stats {
 
             $(
                 queries.push(stats::<
-                    crate::dep_graph::DepKind,
-                    <TyCtxt<'_> as QueryContext>::Query,
                     <queries::$name<'_> as QueryAccessors<TyCtxt<'_>>>::Cache,
                 >(
                     stringify!($name),
-                    &tcx.queries.$name,
+                    &tcx.query_caches.$name,
                 ));
             )*
 
diff --git a/compiler/rustc_query_system/src/query/caches.rs b/compiler/rustc_query_system/src/query/caches.rs
index 1ec32939d9f8d..d589c90fa7b12 100644
--- a/compiler/rustc_query_system/src/query/caches.rs
+++ b/compiler/rustc_query_system/src/query/caches.rs
@@ -1,5 +1,5 @@
 use crate::dep_graph::DepNodeIndex;
-use crate::query::plumbing::{QueryLookup, QueryState};
+use crate::query::plumbing::{QueryCacheStore, QueryLookup};
 
 use rustc_arena::TypedArena;
 use rustc_data_structures::fx::FxHashMap;
@@ -31,13 +31,13 @@ pub trait QueryCache: QueryStorage {
     /// It returns the shard index and a lock guard to the shard,
     /// which will be used if the query is not in the cache and we need
     /// to compute it.
-    fn lookup<'s, D, Q, R, OnHit>(
+    fn lookup<'s, R, OnHit>(
         &self,
-        state: &'s QueryState<D, Q, Self>,
+        state: &'s QueryCacheStore<Self>,
         key: &Self::Key,
         // `on_hit` can be called while holding a lock to the query state shard.
         on_hit: OnHit,
-    ) -> Result<R, QueryLookup<'s, D, Q, Self::Key, Self::Sharded>>
+    ) -> Result<R, QueryLookup<'s, Self::Sharded>>
     where
         OnHit: FnOnce(&Self::Stored, DepNodeIndex) -> R;
 
@@ -93,17 +93,17 @@ where
     type Sharded = FxHashMap<K, (V, DepNodeIndex)>;
 
     #[inline(always)]
-    fn lookup<'s, D, Q, R, OnHit>(
+    fn lookup<'s, R, OnHit>(
         &self,
-        state: &'s QueryState<D, Q, Self>,
+        state: &'s QueryCacheStore<Self>,
         key: &K,
         on_hit: OnHit,
-    ) -> Result<R, QueryLookup<'s, D, Q, K, Self::Sharded>>
+    ) -> Result<R, QueryLookup<'s, Self::Sharded>>
     where
         OnHit: FnOnce(&V, DepNodeIndex) -> R,
     {
         let lookup = state.get_lookup(key);
-        let result = lookup.lock.cache.raw_entry().from_key_hashed_nocheck(lookup.key_hash, key);
+        let result = lookup.lock.raw_entry().from_key_hashed_nocheck(lookup.key_hash, key);
 
         if let Some((_, value)) = result {
             let hit_result = on_hit(&value.0, value.1);
@@ -176,17 +176,17 @@ where
     type Sharded = FxHashMap<K, &'tcx (V, DepNodeIndex)>;
 
     #[inline(always)]
-    fn lookup<'s, D, Q, R, OnHit>(
+    fn lookup<'s, R, OnHit>(
         &self,
-        state: &'s QueryState<D, Q, Self>,
+        state: &'s QueryCacheStore<Self>,
         key: &K,
         on_hit: OnHit,
-    ) -> Result<R, QueryLookup<'s, D, Q, K, Self::Sharded>>
+    ) -> Result<R, QueryLookup<'s, Self::Sharded>>
     where
         OnHit: FnOnce(&&'tcx V, DepNodeIndex) -> R,
     {
         let lookup = state.get_lookup(key);
-        let result = lookup.lock.cache.raw_entry().from_key_hashed_nocheck(lookup.key_hash, key);
+        let result = lookup.lock.raw_entry().from_key_hashed_nocheck(lookup.key_hash, key);
 
         if let Some((_, value)) = result {
             let hit_result = on_hit(&&value.0, value.1);
diff --git a/compiler/rustc_query_system/src/query/config.rs b/compiler/rustc_query_system/src/query/config.rs
index 94e906fc433d5..fecd75049fb7a 100644
--- a/compiler/rustc_query_system/src/query/config.rs
+++ b/compiler/rustc_query_system/src/query/config.rs
@@ -4,7 +4,7 @@ use crate::dep_graph::DepNode;
 use crate::dep_graph::SerializedDepNodeIndex;
 use crate::query::caches::QueryCache;
 use crate::query::plumbing::CycleError;
-use crate::query::{QueryContext, QueryState};
+use crate::query::{QueryCacheStore, QueryContext, QueryState};
 
 use rustc_data_structures::fingerprint::Fingerprint;
 use std::fmt::Debug;
@@ -73,7 +73,12 @@ pub trait QueryAccessors<CTX: QueryContext>: QueryConfig {
     type Cache: QueryCache<Key = Self::Key, Stored = Self::Stored, Value = Self::Value>;
 
     // Don't use this method to access query results, instead use the methods on TyCtxt
-    fn query_state<'a>(tcx: CTX) -> &'a QueryState<CTX::DepKind, CTX::Query, Self::Cache>;
+    fn query_state<'a>(tcx: CTX) -> &'a QueryState<CTX::DepKind, CTX::Query, Self::Key>;
+
+    // Don't use this method to access query results, instead use the methods on TyCtxt
+    fn query_cache<'a>(tcx: CTX) -> &'a QueryCacheStore<Self::Cache>
+    where
+        CTX: 'a;
 
     fn to_dep_node(tcx: CTX, key: &Self::Key) -> DepNode<CTX::DepKind>
     where
diff --git a/compiler/rustc_query_system/src/query/plumbing.rs b/compiler/rustc_query_system/src/query/plumbing.rs
index 4f93017200f59..51a72594b5e0c 100644
--- a/compiler/rustc_query_system/src/query/plumbing.rs
+++ b/compiler/rustc_query_system/src/query/plumbing.rs
@@ -13,7 +13,7 @@ use crate::query::{QueryContext, QueryMap};
 use rustc_data_structures::cold_path;
 use rustc_data_structures::fingerprint::Fingerprint;
 use rustc_data_structures::fx::{FxHashMap, FxHasher};
-use rustc_data_structures::sharded::Sharded;
+use rustc_data_structures::sharded::{get_shard_index_by_hash, Sharded};
 use rustc_data_structures::sync::{Lock, LockGuard};
 use rustc_data_structures::thin_vec::ThinVec;
 use rustc_errors::{Diagnostic, FatalError};
@@ -27,43 +27,73 @@ use std::ptr;
 #[cfg(debug_assertions)]
 use std::sync::atomic::{AtomicUsize, Ordering};
 
-pub(super) struct QueryStateShard<D, Q, K, C> {
-    pub(super) cache: C,
-    active: FxHashMap<K, QueryResult<D, Q>>,
-
-    /// Used to generate unique ids for active jobs.
-    jobs: u32,
+pub struct QueryCacheStore<C: QueryCache> {
+    cache: C,
+    shards: Sharded<C::Sharded>,
+    #[cfg(debug_assertions)]
+    pub cache_hits: AtomicUsize,
 }
 
-impl<D, Q, K, C: Default> Default for QueryStateShard<D, Q, K, C> {
-    fn default() -> QueryStateShard<D, Q, K, C> {
-        QueryStateShard { cache: Default::default(), active: Default::default(), jobs: 0 }
+impl<C: QueryCache> Default for QueryCacheStore<C> {
+    fn default() -> Self {
+        Self {
+            cache: C::default(),
+            shards: Default::default(),
+            #[cfg(debug_assertions)]
+            cache_hits: AtomicUsize::new(0),
+        }
     }
 }
 
-pub struct QueryState<D, Q, C: QueryCache> {
-    cache: C,
-    shards: Sharded<QueryStateShard<D, Q, C::Key, C::Sharded>>,
-    #[cfg(debug_assertions)]
-    pub cache_hits: AtomicUsize,
+/// Values used when checking a query cache which can be reused on a cache-miss to execute the query.
+pub struct QueryLookup<'tcx, C> {
+    pub(super) key_hash: u64,
+    shard: usize,
+    pub(super) lock: LockGuard<'tcx, C>,
 }
 
-impl<D, Q, C: QueryCache> QueryState<D, Q, C> {
-    pub(super) fn get_lookup<'tcx>(
-        &'tcx self,
-        key: &C::Key,
-    ) -> QueryLookup<'tcx, D, Q, C::Key, C::Sharded> {
-        // We compute the key's hash once and then use it for both the
-        // shard lookup and the hashmap lookup. This relies on the fact
-        // that both of them use `FxHasher`.
-        let mut hasher = FxHasher::default();
-        key.hash(&mut hasher);
-        let key_hash = hasher.finish();
-
-        let shard = self.shards.get_shard_index_by_hash(key_hash);
+// We compute the key's hash once and then use it for both the
+// shard lookup and the hashmap lookup. This relies on the fact
+// that both of them use `FxHasher`.
+fn hash_for_shard<K: Hash>(key: &K) -> u64 {
+    let mut hasher = FxHasher::default();
+    key.hash(&mut hasher);
+    hasher.finish()
+}
+
+impl<C: QueryCache> QueryCacheStore<C> {
+    pub(super) fn get_lookup<'tcx>(&'tcx self, key: &C::Key) -> QueryLookup<'tcx, C::Sharded> {
+        let key_hash = hash_for_shard(key);
+        let shard = get_shard_index_by_hash(key_hash);
         let lock = self.shards.get_shard_by_index(shard).lock();
         QueryLookup { key_hash, shard, lock }
     }
+
+    pub fn iter_results<R>(
+        &self,
+        f: impl for<'a> FnOnce(
+            Box<dyn Iterator<Item = (&'a C::Key, &'a C::Value, DepNodeIndex)> + 'a>,
+        ) -> R,
+    ) -> R {
+        self.cache.iter(&self.shards, |shard| &mut *shard, f)
+    }
+}
+
+struct QueryStateShard<D, Q, K> {
+    active: FxHashMap<K, QueryResult<D, Q>>,
+
+    /// Used to generate unique ids for active jobs.
+    jobs: u32,
+}
+
+impl<D, Q, K> Default for QueryStateShard<D, Q, K> {
+    fn default() -> QueryStateShard<D, Q, K> {
+        QueryStateShard { active: Default::default(), jobs: 0 }
+    }
+}
+
+pub struct QueryState<D, Q, K> {
+    shards: Sharded<QueryStateShard<D, Q, K>>,
 }
 
 /// Indicates the state of a query for a given key in a query map.
@@ -76,21 +106,12 @@ enum QueryResult<D, Q> {
     Poisoned,
 }
 
-impl<D, Q, C> QueryState<D, Q, C>
+impl<D, Q, K> QueryState<D, Q, K>
 where
     D: Copy + Clone + Eq + Hash,
     Q: Clone,
-    C: QueryCache,
+    K: Eq + Hash + Clone + Debug,
 {
-    pub fn iter_results<R>(
-        &self,
-        f: impl for<'a> FnOnce(
-            Box<dyn Iterator<Item = (&'a C::Key, &'a C::Value, DepNodeIndex)> + 'a>,
-        ) -> R,
-    ) -> R {
-        self.cache.iter(&self.shards, |shard| &mut shard.cache, f)
-    }
-
     pub fn all_inactive(&self) -> bool {
         let shards = self.shards.lock_shards();
         shards.iter().all(|shard| shard.active.is_empty())
@@ -99,7 +120,7 @@ where
     pub fn try_collect_active_jobs(
         &self,
         kind: D,
-        make_query: fn(C::Key) -> Q,
+        make_query: fn(K) -> Q,
         jobs: &mut QueryMap<D, Q>,
     ) -> Option<()> {
         // We use try_lock_shards here since we are called from the
@@ -122,24 +143,12 @@ where
     }
 }
 
-impl<D, Q, C: QueryCache> Default for QueryState<D, Q, C> {
-    fn default() -> QueryState<D, Q, C> {
-        QueryState {
-            cache: C::default(),
-            shards: Default::default(),
-            #[cfg(debug_assertions)]
-            cache_hits: AtomicUsize::new(0),
-        }
+impl<D, Q, K> Default for QueryState<D, Q, K> {
+    fn default() -> QueryState<D, Q, K> {
+        QueryState { shards: Default::default() }
     }
 }
 
-/// Values used when checking a query cache which can be reused on a cache-miss to execute the query.
-pub struct QueryLookup<'tcx, D, Q, K, C> {
-    pub(super) key_hash: u64,
-    shard: usize,
-    pub(super) lock: LockGuard<'tcx, QueryStateShard<D, Q, K, C>>,
-}
-
 /// A type representing the responsibility to execute the job in the `job` field.
 /// This will poison the relevant query if dropped.
 struct JobOwner<'tcx, D, Q, C>
@@ -148,7 +157,8 @@ where
     Q: Clone,
     C: QueryCache,
 {
-    state: &'tcx QueryState<D, Q, C>,
+    state: &'tcx QueryState<D, Q, C::Key>,
+    cache: &'tcx QueryCacheStore<C>,
     key: C::Key,
     id: QueryJobId<D>,
 }
@@ -170,16 +180,20 @@ where
     #[inline(always)]
     fn try_start<'a, 'b, CTX>(
         tcx: CTX,
-        state: &'b QueryState<CTX::DepKind, CTX::Query, C>,
+        state: &'b QueryState<CTX::DepKind, CTX::Query, C::Key>,
+        cache: &'b QueryCacheStore<C>,
         span: Span,
         key: &C::Key,
-        mut lookup: QueryLookup<'a, CTX::DepKind, CTX::Query, C::Key, C::Sharded>,
+        lookup: QueryLookup<'a, C::Sharded>,
         query: &QueryVtable<CTX, C::Key, C::Value>,
     ) -> TryGetJob<'b, CTX::DepKind, CTX::Query, C>
     where
         CTX: QueryContext,
     {
-        let lock = &mut *lookup.lock;
+        mem::drop(lookup.lock);
+        let shard = lookup.shard;
+        let mut state_lock = state.shards.get_shard_by_index(shard).lock();
+        let lock = &mut *state_lock;
 
         let (latch, mut _query_blocked_prof_timer) = match lock.active.entry((*key).clone()) {
             Entry::Occupied(mut entry) => {
@@ -195,7 +209,7 @@ where
                         };
 
                         // Create the id of the job we're waiting for
-                        let id = QueryJobId::new(job.id, lookup.shard, query.dep_kind);
+                        let id = QueryJobId::new(job.id, shard, query.dep_kind);
 
                         (job.latch(id), _query_blocked_prof_timer)
                     }
@@ -210,18 +224,18 @@ where
                 lock.jobs = id;
                 let id = QueryShardJobId(NonZeroU32::new(id).unwrap());
 
-                let global_id = QueryJobId::new(id, lookup.shard, query.dep_kind);
+                let global_id = QueryJobId::new(id, shard, query.dep_kind);
 
                 let job = tcx.current_query_job();
                 let job = QueryJob::new(id, span, job);
 
                 entry.insert(QueryResult::Started(job));
 
-                let owner = JobOwner { state, id: global_id, key: (*key).clone() };
+                let owner = JobOwner { state, cache, id: global_id, key: (*key).clone() };
                 return TryGetJob::NotYetStarted(owner);
             }
         };
-        mem::drop(lookup.lock);
+        mem::drop(state_lock);
 
         // If we are single-threaded we know that we have cycle error,
         // so we just return the error.
@@ -233,7 +247,7 @@ where
                 span,
             );
             let value = query.handle_cycle_error(tcx, error);
-            state.cache.store_nocache(value)
+            cache.cache.store_nocache(value)
         }));
 
         // With parallel queries we might just have to wait on some other
@@ -244,11 +258,11 @@ where
 
             if let Err(cycle) = result {
                 let value = query.handle_cycle_error(tcx, cycle);
-                let value = state.cache.store_nocache(value);
+                let value = cache.cache.store_nocache(value);
                 return TryGetJob::Cycle(value);
             }
 
-            let cached = try_get_cached(tcx, state, key, |value, index| (value.clone(), index))
+            let cached = try_get_cached(tcx, cache, key, |value, index| (value.clone(), index))
                 .unwrap_or_else(|_| panic!("value must be in cache after waiting"));
 
             if let Some(prof_timer) = _query_blocked_prof_timer.take() {
@@ -265,17 +279,25 @@ where
         // We can move out of `self` here because we `mem::forget` it below
         let key = unsafe { ptr::read(&self.key) };
         let state = self.state;
+        let cache = self.cache;
 
         // Forget ourself so our destructor won't poison the query
         mem::forget(self);
 
         let (job, result) = {
-            let mut lock = state.shards.get_shard_by_value(&key).lock();
-            let job = match lock.active.remove(&key).unwrap() {
-                QueryResult::Started(job) => job,
-                QueryResult::Poisoned => panic!(),
+            let key_hash = hash_for_shard(&key);
+            let shard = get_shard_index_by_hash(key_hash);
+            let job = {
+                let mut lock = state.shards.get_shard_by_index(shard).lock();
+                match lock.active.remove(&key).unwrap() {
+                    QueryResult::Started(job) => job,
+                    QueryResult::Poisoned => panic!(),
+                }
+            };
+            let result = {
+                let mut lock = cache.shards.get_shard_by_index(shard).lock();
+                cache.cache.complete(&mut lock, key, result, dep_node_index)
             };
-            let result = state.cache.complete(&mut lock.cache, key, result, dep_node_index);
             (job, result)
         };
 
@@ -353,23 +375,23 @@ where
 /// to compute it.
 fn try_get_cached<'a, CTX, C, R, OnHit>(
     tcx: CTX,
-    state: &'a QueryState<CTX::DepKind, CTX::Query, C>,
+    cache: &'a QueryCacheStore<C>,
     key: &C::Key,
     // `on_hit` can be called while holding a lock to the query cache
     on_hit: OnHit,
-) -> Result<R, QueryLookup<'a, CTX::DepKind, CTX::Query, C::Key, C::Sharded>>
+) -> Result<R, QueryLookup<'a, C::Sharded>>
 where
     C: QueryCache,
     CTX: QueryContext,
     OnHit: FnOnce(&C::Stored, DepNodeIndex) -> R,
 {
-    state.cache.lookup(state, &key, |value, index| {
+    cache.cache.lookup(cache, &key, |value, index| {
         if unlikely!(tcx.profiler().enabled()) {
             tcx.profiler().query_cache_hit(index.into());
         }
         #[cfg(debug_assertions)]
         {
-            state.cache_hits.fetch_add(1, Ordering::Relaxed);
+            cache.cache_hits.fetch_add(1, Ordering::Relaxed);
         }
         on_hit(value, index)
     })
@@ -377,10 +399,11 @@ where
 
 fn try_execute_query<CTX, C>(
     tcx: CTX,
-    state: &QueryState<CTX::DepKind, CTX::Query, C>,
+    state: &QueryState<CTX::DepKind, CTX::Query, C::Key>,
+    cache: &QueryCacheStore<C>,
     span: Span,
     key: C::Key,
-    lookup: QueryLookup<'_, CTX::DepKind, CTX::Query, C::Key, C::Sharded>,
+    lookup: QueryLookup<'_, C::Sharded>,
     query: &QueryVtable<CTX, C::Key, C::Value>,
 ) -> C::Stored
 where
@@ -389,7 +412,7 @@ where
     CTX: QueryContext,
 {
     let job = match JobOwner::<'_, CTX::DepKind, CTX::Query, C>::try_start(
-        tcx, state, span, &key, lookup, query,
+        tcx, state, cache, span, &key, lookup, query,
     ) {
         TryGetJob::NotYetStarted(job) => job,
         TryGetJob::Cycle(result) => return result,
@@ -604,7 +627,8 @@ where
 #[inline(never)]
 fn get_query_impl<CTX, C>(
     tcx: CTX,
-    state: &QueryState<CTX::DepKind, CTX::Query, C>,
+    state: &QueryState<CTX::DepKind, CTX::Query, C::Key>,
+    cache: &QueryCacheStore<C>,
     span: Span,
     key: C::Key,
     query: &QueryVtable<CTX, C::Key, C::Value>,
@@ -614,13 +638,13 @@ where
     C: QueryCache,
     C::Key: crate::dep_graph::DepNodeParams<CTX>,
 {
-    let cached = try_get_cached(tcx, state, &key, |value, index| {
+    let cached = try_get_cached(tcx, cache, &key, |value, index| {
         tcx.dep_graph().read_index(index);
         value.clone()
     });
     match cached {
         Ok(value) => value,
-        Err(lookup) => try_execute_query(tcx, state, span, key, lookup, query),
+        Err(lookup) => try_execute_query(tcx, state, cache, span, key, lookup, query),
     }
 }
 
@@ -667,7 +691,8 @@ where
 #[inline(never)]
 fn force_query_impl<CTX, C>(
     tcx: CTX,
-    state: &QueryState<CTX::DepKind, CTX::Query, C>,
+    state: &QueryState<CTX::DepKind, CTX::Query, C::Key>,
+    cache: &QueryCacheStore<C>,
     key: C::Key,
     span: Span,
     dep_node: DepNode<CTX::DepKind>,
@@ -680,7 +705,7 @@ fn force_query_impl<CTX, C>(
     // We may be concurrently trying both execute and force a query.
     // Ensure that only one of them runs the query.
 
-    let cached = try_get_cached(tcx, state, &key, |_, _| {
+    let cached = try_get_cached(tcx, cache, &key, |_, _| {
         // Cache hit, do nothing
     });
 
@@ -690,7 +715,7 @@ fn force_query_impl<CTX, C>(
     };
 
     let job = match JobOwner::<'_, CTX::DepKind, CTX::Query, C>::try_start(
-        tcx, state, span, &key, lookup, query,
+        tcx, state, cache, span, &key, lookup, query,
     ) {
         TryGetJob::NotYetStarted(job) => job,
         TryGetJob::Cycle(_) => return,
@@ -719,7 +744,7 @@ where
     }
 
     debug!("ty::query::get_query<{}>(key={:?}, span={:?})", Q::NAME, key, span);
-    let value = get_query_impl(tcx, Q::query_state(tcx), span, key, query);
+    let value = get_query_impl(tcx, Q::query_state(tcx), Q::query_cache(tcx), span, key, query);
     Some(value)
 }
 
@@ -729,5 +754,5 @@ where
     Q::Key: crate::dep_graph::DepNodeParams<CTX>,
     CTX: QueryContext,
 {
-    force_query_impl(tcx, Q::query_state(tcx), key, span, dep_node, &Q::VTABLE)
+    force_query_impl(tcx, Q::query_state(tcx), Q::query_cache(tcx), key, span, dep_node, &Q::VTABLE)
 }

From 280a2866d502747b51bd81390be760973c54e719 Mon Sep 17 00:00:00 2001
From: Camille GILLOT <gillot.camille@gmail.com>
Date: Sat, 6 Feb 2021 14:04:20 +0100
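Subject: [PATCH 6/7] Drop the cache lock earlier.

`QueryLookup` no longer carries the cache shard's `LockGuard`.
`QueryCacheStore::get_lookup` now returns the guard alongside the lookup,
so the guard is a local of `QueryCache::lookup` and is released when that
call returns, instead of being held until `JobOwner::try_start` dropped it
explicitly with `mem::drop(lookup.lock)`.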
---
 .../rustc_query_system/src/query/caches.rs    | 14 +++++++-------
 .../rustc_query_system/src/query/plumbing.rs  | 19 ++++++++++---------
 2 files changed, 17 insertions(+), 16 deletions(-)

diff --git a/compiler/rustc_query_system/src/query/caches.rs b/compiler/rustc_query_system/src/query/caches.rs
index d589c90fa7b12..ec71c8685804f 100644
--- a/compiler/rustc_query_system/src/query/caches.rs
+++ b/compiler/rustc_query_system/src/query/caches.rs
@@ -37,7 +37,7 @@ pub trait QueryCache: QueryStorage {
         key: &Self::Key,
         // `on_hit` can be called while holding a lock to the query state shard.
         on_hit: OnHit,
-    ) -> Result<R, QueryLookup<'s, Self::Sharded>>
+    ) -> Result<R, QueryLookup>
     where
         OnHit: FnOnce(&Self::Stored, DepNodeIndex) -> R;
 
@@ -98,12 +98,12 @@ where
         state: &'s QueryCacheStore<Self>,
         key: &K,
         on_hit: OnHit,
-    ) -> Result<R, QueryLookup<'s, Self::Sharded>>
+    ) -> Result<R, QueryLookup>
     where
         OnHit: FnOnce(&V, DepNodeIndex) -> R,
     {
-        let lookup = state.get_lookup(key);
-        let result = lookup.lock.raw_entry().from_key_hashed_nocheck(lookup.key_hash, key);
+        let (lookup, lock) = state.get_lookup(key);
+        let result = lock.raw_entry().from_key_hashed_nocheck(lookup.key_hash, key);
 
         if let Some((_, value)) = result {
             let hit_result = on_hit(&value.0, value.1);
@@ -181,12 +181,12 @@ where
         state: &'s QueryCacheStore<Self>,
         key: &K,
         on_hit: OnHit,
-    ) -> Result<R, QueryLookup<'s, Self::Sharded>>
+    ) -> Result<R, QueryLookup>
     where
         OnHit: FnOnce(&&'tcx V, DepNodeIndex) -> R,
     {
-        let lookup = state.get_lookup(key);
-        let result = lookup.lock.raw_entry().from_key_hashed_nocheck(lookup.key_hash, key);
+        let (lookup, lock) = state.get_lookup(key);
+        let result = lock.raw_entry().from_key_hashed_nocheck(lookup.key_hash, key);
 
         if let Some((_, value)) = result {
             let hit_result = on_hit(&&value.0, value.1);
diff --git a/compiler/rustc_query_system/src/query/plumbing.rs b/compiler/rustc_query_system/src/query/plumbing.rs
index 51a72594b5e0c..c2e89e131b3fe 100644
--- a/compiler/rustc_query_system/src/query/plumbing.rs
+++ b/compiler/rustc_query_system/src/query/plumbing.rs
@@ -46,10 +46,9 @@ impl<C: QueryCache> Default for QueryCacheStore<C> {
 }
 
 /// Values used when checking a query cache which can be reused on a cache-miss to execute the query.
-pub struct QueryLookup<'tcx, C> {
+pub struct QueryLookup {
     pub(super) key_hash: u64,
     shard: usize,
-    pub(super) lock: LockGuard<'tcx, C>,
 }
 
 // We compute the key's hash once and then use it for both the
@@ -62,11 +61,14 @@ fn hash_for_shard<K: Hash>(key: &K) -> u64 {
 }
 
 impl<C: QueryCache> QueryCacheStore<C> {
-    pub(super) fn get_lookup<'tcx>(&'tcx self, key: &C::Key) -> QueryLookup<'tcx, C::Sharded> {
+    pub(super) fn get_lookup<'tcx>(
+        &'tcx self,
+        key: &C::Key,
+    ) -> (QueryLookup, LockGuard<'tcx, C::Sharded>) {
         let key_hash = hash_for_shard(key);
         let shard = get_shard_index_by_hash(key_hash);
         let lock = self.shards.get_shard_by_index(shard).lock();
-        QueryLookup { key_hash, shard, lock }
+        (QueryLookup { key_hash, shard }, lock)
     }
 
     pub fn iter_results<R>(
@@ -178,19 +180,18 @@ where
     /// This function is inlined because that results in a noticeable speed-up
     /// for some compile-time benchmarks.
     #[inline(always)]
-    fn try_start<'a, 'b, CTX>(
+    fn try_start<'b, CTX>(
         tcx: CTX,
         state: &'b QueryState<CTX::DepKind, CTX::Query, C::Key>,
         cache: &'b QueryCacheStore<C>,
         span: Span,
         key: &C::Key,
-        lookup: QueryLookup<'a, C::Sharded>,
+        lookup: QueryLookup,
         query: &QueryVtable<CTX, C::Key, C::Value>,
     ) -> TryGetJob<'b, CTX::DepKind, CTX::Query, C>
     where
         CTX: QueryContext,
     {
-        mem::drop(lookup.lock);
         let shard = lookup.shard;
         let mut state_lock = state.shards.get_shard_by_index(shard).lock();
         let lock = &mut *state_lock;
@@ -379,7 +380,7 @@ fn try_get_cached<'a, CTX, C, R, OnHit>(
     key: &C::Key,
     // `on_hit` can be called while holding a lock to the query cache
     on_hit: OnHit,
-) -> Result<R, QueryLookup<'a, C::Sharded>>
+) -> Result<R, QueryLookup>
 where
     C: QueryCache,
     CTX: QueryContext,
@@ -403,7 +404,7 @@ fn try_execute_query<CTX, C>(
     cache: &QueryCacheStore<C>,
     span: Span,
     key: C::Key,
-    lookup: QueryLookup<'_, C::Sharded>,
+    lookup: QueryLookup,
     query: &QueryVtable<CTX, C::Key, C::Value>,
 ) -> C::Stored
 where

From 3fc8ed68e99034ad5410cef47e8cd94828ef8946 Mon Sep 17 00:00:00 2001
From: Camille GILLOT <gillot.camille@gmail.com>
Date: Sat, 6 Feb 2021 14:52:04 +0100
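Subject: [PATCH 7/7] Check query cache before calling into the query engine.

The query accessor methods generated by `define_queries!` now call
`try_get_cached` themselves and return directly on a hit. Only on a miss do
they call `get_query`, passing along the `QueryLookup` produced by the
failed lookup so the key is not hashed again.

To support this, `try_get_cached` becomes `pub` and registers the dep-graph
read itself (`on_hit` no longer receives the `DepNodeIndex`), and
`get_query_impl` no longer re-checks the cache. `JobOwner::try_start` and
`force_query_impl` call `cache.cache.lookup` directly, so they keep their
previous behaviour of not registering a dep-graph read.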
---
 .../rustc_middle/src/ty/query/plumbing.rs     | 24 +++++++--
 .../rustc_query_system/src/query/plumbing.rs  | 52 +++++++++++++------
 2 files changed, 56 insertions(+), 20 deletions(-)

diff --git a/compiler/rustc_middle/src/ty/query/plumbing.rs b/compiler/rustc_middle/src/ty/query/plumbing.rs
index 9a011846fd62d..0961d4d0091d0 100644
--- a/compiler/rustc_middle/src/ty/query/plumbing.rs
+++ b/compiler/rustc_middle/src/ty/query/plumbing.rs
@@ -422,7 +422,15 @@ macro_rules! define_queries {
             $($(#[$attr])*
             #[inline(always)]
             pub fn $name(self, key: query_helper_param_ty!($($K)*)) {
-                get_query::<queries::$name<'_>, _>(self.tcx, DUMMY_SP, key.into_query_param(), QueryMode::Ensure);
+                let key = key.into_query_param();
+                let cached = try_get_cached(self.tcx, &self.tcx.query_caches.$name, &key, |_| {});
+
+                let lookup = match cached {
+                    Ok(()) => return,
+                    Err(lookup) => lookup,
+                };
+
+                get_query::<queries::$name<'_>, _>(self.tcx, DUMMY_SP, key, lookup, QueryMode::Ensure);
             })*
         }
 
@@ -465,7 +473,7 @@ macro_rules! define_queries {
             #[must_use]
             pub fn $name(self, key: query_helper_param_ty!($($K)*)) -> query_stored::$name<$tcx>
             {
-                self.at(DUMMY_SP).$name(key.into_query_param())
+                self.at(DUMMY_SP).$name(key)
             })*
 
             /// All self-profiling events generated by the query engine use
@@ -503,7 +511,17 @@ macro_rules! define_queries {
             #[inline(always)]
             pub fn $name(self, key: query_helper_param_ty!($($K)*)) -> query_stored::$name<$tcx>
             {
-                get_query::<queries::$name<'_>, _>(self.tcx, self.span, key.into_query_param(), QueryMode::Get).unwrap()
+                let key = key.into_query_param();
+                let cached = try_get_cached(self.tcx, &self.tcx.query_caches.$name, &key, |value| {
+                    value.clone()
+                });
+
+                let lookup = match cached {
+                    Ok(value) => return value,
+                    Err(lookup) => lookup,
+                };
+
+                get_query::<queries::$name<'_>, _>(self.tcx, self.span, key, lookup, QueryMode::Get).unwrap()
             })*
         }
 
diff --git a/compiler/rustc_query_system/src/query/plumbing.rs b/compiler/rustc_query_system/src/query/plumbing.rs
index c2e89e131b3fe..2610ce83e4d3e 100644
--- a/compiler/rustc_query_system/src/query/plumbing.rs
+++ b/compiler/rustc_query_system/src/query/plumbing.rs
@@ -263,7 +263,18 @@ where
                 return TryGetJob::Cycle(value);
             }
 
-            let cached = try_get_cached(tcx, cache, key, |value, index| (value.clone(), index))
+            let cached = cache
+                .cache
+                .lookup(cache, &key, |value, index| {
+                    if unlikely!(tcx.profiler().enabled()) {
+                        tcx.profiler().query_cache_hit(index.into());
+                    }
+                    #[cfg(debug_assertions)]
+                    {
+                        cache.cache_hits.fetch_add(1, Ordering::Relaxed);
+                    }
+                    (value.clone(), index)
+                })
                 .unwrap_or_else(|_| panic!("value must be in cache after waiting"));
 
             if let Some(prof_timer) = _query_blocked_prof_timer.take() {
@@ -374,7 +385,7 @@ where
 /// It returns the shard index and a lock guard to the shard,
 /// which will be used if the query is not in the cache and we need
 /// to compute it.
-fn try_get_cached<'a, CTX, C, R, OnHit>(
+pub fn try_get_cached<'a, CTX, C, R, OnHit>(
     tcx: CTX,
     cache: &'a QueryCacheStore<C>,
     key: &C::Key,
@@ -384,7 +395,7 @@ fn try_get_cached<'a, CTX, C, R, OnHit>(
 where
     C: QueryCache,
     CTX: QueryContext,
-    OnHit: FnOnce(&C::Stored, DepNodeIndex) -> R,
+    OnHit: FnOnce(&C::Stored) -> R,
 {
     cache.cache.lookup(cache, &key, |value, index| {
         if unlikely!(tcx.profiler().enabled()) {
@@ -394,7 +405,8 @@ where
         {
             cache.cache_hits.fetch_add(1, Ordering::Relaxed);
         }
-        on_hit(value, index)
+        tcx.dep_graph().read_index(index);
+        on_hit(value)
     })
 }
 
@@ -632,6 +644,7 @@ fn get_query_impl<CTX, C>(
     cache: &QueryCacheStore<C>,
     span: Span,
     key: C::Key,
+    lookup: QueryLookup,
     query: &QueryVtable<CTX, C::Key, C::Value>,
 ) -> C::Stored
 where
@@ -639,14 +652,7 @@ where
     C: QueryCache,
     C::Key: crate::dep_graph::DepNodeParams<CTX>,
 {
-    let cached = try_get_cached(tcx, cache, &key, |value, index| {
-        tcx.dep_graph().read_index(index);
-        value.clone()
-    });
-    match cached {
-        Ok(value) => value,
-        Err(lookup) => try_execute_query(tcx, state, cache, span, key, lookup, query),
-    }
+    try_execute_query(tcx, state, cache, span, key, lookup, query)
 }
 
 /// Ensure that either this query has all green inputs or been executed.
@@ -705,9 +711,14 @@ fn force_query_impl<CTX, C>(
 {
     // We may be concurrently trying both execute and force a query.
     // Ensure that only one of them runs the query.
-
-    let cached = try_get_cached(tcx, cache, &key, |_, _| {
-        // Cache hit, do nothing
+    let cached = cache.cache.lookup(cache, &key, |_, index| {
+        if unlikely!(tcx.profiler().enabled()) {
+            tcx.profiler().query_cache_hit(index.into());
+        }
+        #[cfg(debug_assertions)]
+        {
+            cache.cache_hits.fetch_add(1, Ordering::Relaxed);
+        }
     });
 
     let lookup = match cached {
@@ -731,7 +742,13 @@ pub enum QueryMode {
     Ensure,
 }
 
-pub fn get_query<Q, CTX>(tcx: CTX, span: Span, key: Q::Key, mode: QueryMode) -> Option<Q::Stored>
+pub fn get_query<Q, CTX>(
+    tcx: CTX,
+    span: Span,
+    key: Q::Key,
+    lookup: QueryLookup,
+    mode: QueryMode,
+) -> Option<Q::Stored>
 where
     Q: QueryDescription<CTX>,
     Q::Key: crate::dep_graph::DepNodeParams<CTX>,
@@ -745,7 +762,8 @@ where
     }
 
     debug!("ty::query::get_query<{}>(key={:?}, span={:?})", Q::NAME, key, span);
-    let value = get_query_impl(tcx, Q::query_state(tcx), Q::query_cache(tcx), span, key, query);
+    let value =
+        get_query_impl(tcx, Q::query_state(tcx), Q::query_cache(tcx), span, key, lookup, query);
     Some(value)
 }