From d2834525ba9378bb69214f3c3f7e868b9f6d4880 Mon Sep 17 00:00:00 2001
From: Michael Goulet <michael@errs.io>
Date: Sat, 11 Mar 2023 22:05:50 +0000
Subject: [PATCH] Use TyCtxt::trait_solver_next in some places

---
 .../src/traits/query/evaluate_obligation.rs       | 15 +++++++--------
 .../src/traits/select/mod.rs                      | 13 ++++++-------
 2 files changed, 13 insertions(+), 15 deletions(-)

diff --git a/compiler/rustc_trait_selection/src/traits/query/evaluate_obligation.rs b/compiler/rustc_trait_selection/src/traits/query/evaluate_obligation.rs
index 1420c25c92280..f84b2f4428d1a 100644
--- a/compiler/rustc_trait_selection/src/traits/query/evaluate_obligation.rs
+++ b/compiler/rustc_trait_selection/src/traits/query/evaluate_obligation.rs
@@ -1,6 +1,5 @@
 use rustc_middle::traits::solve::{Certainty, Goal, MaybeCause};
 use rustc_middle::ty;
-use rustc_session::config::TraitSolver;
 
 use crate::infer::canonical::OriginalQueryValues;
 use crate::infer::InferCtxt;
@@ -80,13 +79,7 @@ impl<'tcx> InferCtxtExt<'tcx> for InferCtxt<'tcx> {
             _ => obligation.param_env.without_const(),
         };
 
-        if self.tcx.sess.opts.unstable_opts.trait_solver != TraitSolver::Next {
-            let c_pred = self.canonicalize_query_keep_static(
-                param_env.and(obligation.predicate),
-                &mut _orig_values,
-            );
-            self.tcx.at(obligation.cause.span()).evaluate_obligation(c_pred)
-        } else {
+        if self.tcx.trait_solver_next() {
             self.probe(|snapshot| {
                 if let Ok((_, certainty)) =
                     self.evaluate_root_goal(Goal::new(self.tcx, param_env, obligation.predicate))
@@ -111,6 +104,12 @@ impl<'tcx> InferCtxtExt<'tcx> for InferCtxt<'tcx> {
                     Ok(EvaluationResult::EvaluatedToErr)
                 }
             })
+        } else {
+            let c_pred = self.canonicalize_query_keep_static(
+                param_env.and(obligation.predicate),
+                &mut _orig_values,
+            );
+            self.tcx.at(obligation.cause.span()).evaluate_obligation(c_pred)
         }
     }
 
diff --git a/compiler/rustc_trait_selection/src/traits/select/mod.rs b/compiler/rustc_trait_selection/src/traits/select/mod.rs
index ba4df2b3720ef..d7ce007812450 100644
--- a/compiler/rustc_trait_selection/src/traits/select/mod.rs
+++ b/compiler/rustc_trait_selection/src/traits/select/mod.rs
@@ -50,7 +50,6 @@ use rustc_middle::ty::relate::TypeRelation;
 use rustc_middle::ty::SubstsRef;
 use rustc_middle::ty::{self, EarlyBinder, PolyProjectionPredicate, ToPolyTraitRef, ToPredicate};
 use rustc_middle::ty::{Ty, TyCtxt, TypeFoldable, TypeVisitableExt};
-use rustc_session::config::TraitSolver;
 use rustc_span::symbol::sym;
 
 use std::cell::{Cell, RefCell};
@@ -545,13 +544,13 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         obligation: &PredicateObligation<'tcx>,
     ) -> Result<EvaluationResult, OverflowError> {
         self.evaluation_probe(|this| {
-            if this.tcx().sess.opts.unstable_opts.trait_solver != TraitSolver::Next {
+            if this.tcx().trait_solver_next() {
+                this.evaluate_predicates_recursively_in_new_solver([obligation.clone()])
+            } else {
                 this.evaluate_predicate_recursively(
                     TraitObligationStackList::empty(&ProvisionalEvaluationCache::default()),
                     obligation.clone(),
                 )
-            } else {
-                this.evaluate_predicates_recursively_in_new_solver([obligation.clone()])
             }
         })
     }
@@ -591,7 +590,9 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
     where
         I: IntoIterator<Item = PredicateObligation<'tcx>> + std::fmt::Debug,
     {
-        if self.tcx().sess.opts.unstable_opts.trait_solver != TraitSolver::Next {
+        if self.tcx().trait_solver_next() {
+            self.evaluate_predicates_recursively_in_new_solver(predicates)
+        } else {
             let mut result = EvaluatedToOk;
             for obligation in predicates {
                 let eval = self.evaluate_predicate_recursively(stack, obligation.clone())?;
@@ -604,8 +605,6 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
                 }
             }
             Ok(result)
-        } else {
-            self.evaluate_predicates_recursively_in_new_solver(predicates)
         }
     }