Skip to content

refactor infer var storage #118742

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 18 additions & 17 deletions compiler/rustc_hir_typeck/src/fallback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
self.fulfillment_cx.borrow_mut().pending_obligations()
);

let fallback_occured = self.fallback_types() | self.fallback_effects();
let fallback_occurred = self.fallback_types() | self.fallback_effects();

if !fallback_occured {
if !fallback_occurred {
return;
}

Expand Down Expand Up @@ -57,24 +57,25 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
}

fn fallback_types(&self) -> bool {
// Check if we have any unsolved variables. If not, no need for fallback.
let unsolved_variables = self.unsolved_variables();
// Check if we have any unresolved variables. If not, no need for fallback.
let unresolved_variables = self.unresolved_variables();

if unsolved_variables.is_empty() {
if unresolved_variables.is_empty() {
return false;
}

let diverging_fallback = self.calculate_diverging_fallback(&unsolved_variables);
let diverging_fallback = self.calculate_diverging_fallback(&unresolved_variables);

// We do fallback in two passes, to try to generate
// better error messages.
// The first time, we do *not* replace opaque types.
for ty in unsolved_variables {
let mut fallback_occurred = false;
for ty in unresolved_variables {
debug!("unsolved_variable = {:?}", ty);
self.fallback_if_possible(ty, &diverging_fallback);
fallback_occurred |= self.fallback_if_possible(ty, &diverging_fallback);
}

true
fallback_occurred
}

fn fallback_effects(&self) -> bool {
Expand All @@ -84,9 +85,8 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
return false;
}

// not setting `fallback_has_occured` here because that field is only used for type fallback
// diagnostics.

// not setting the `fallback_has_occured` field here because
// that field is only used for type fallback diagnostics.
for effect in unsolved_effects {
let expected = self.tcx.consts.true_;
let cause = self.misc(rustc_span::DUMMY_SP);
Expand Down Expand Up @@ -122,7 +122,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
&self,
ty: Ty<'tcx>,
diverging_fallback: &UnordMap<Ty<'tcx>, Ty<'tcx>>,
) {
) -> bool {
// Careful: we do NOT shallow-resolve `ty`. We know that `ty`
// is an unsolved variable, and we determine its fallback
// based solely on how it was created, not what other type
Expand All @@ -147,7 +147,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
ty::Infer(ty::FloatVar(_)) => self.tcx.types.f64,
_ => match diverging_fallback.get(&ty) {
Some(&fallback_ty) => fallback_ty,
None => return,
None => return false,
},
};
debug!("fallback_if_possible(ty={:?}): defaulting to `{:?}`", ty, fallback);
Expand All @@ -159,6 +159,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
.unwrap_or(rustc_span::DUMMY_SP);
self.demand_eqtype(span, ty, fallback);
self.fallback_has_occurred.set(true);
true
}

/// The "diverging fallback" system is rather complicated. This is
Expand Down Expand Up @@ -230,17 +231,17 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
/// any variable that has an edge into `D`.
fn calculate_diverging_fallback(
&self,
unsolved_variables: &[Ty<'tcx>],
unresolved_variables: &[Ty<'tcx>],
) -> UnordMap<Ty<'tcx>, Ty<'tcx>> {
debug!("calculate_diverging_fallback({:?})", unsolved_variables);
debug!("calculate_diverging_fallback({:?})", unresolved_variables);

// Construct a coercion graph where an edge `A -> B` indicates
// a type variable is that is coerced
let coercion_graph = self.create_coercion_graph();

// Extract the unsolved type inference variable vids; note that some
// unsolved variables are integer/float variables and are excluded.
let unsolved_vids = unsolved_variables.iter().filter_map(|ty| ty.ty_vid());
let unsolved_vids = unresolved_variables.iter().filter_map(|ty| ty.ty_vid());

// Compute the diverging root vids D -- that is, the root vid of
// those type variables that (a) are the target of a coercion from
Expand Down
5 changes: 1 addition & 4 deletions compiler/rustc_infer/src/infer/canonical/query_response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,7 @@ impl<'tcx> InferCtxt<'tcx> {
}

fn take_opaque_types_for_query_response(&self) -> Vec<(ty::OpaqueTypeKey<'tcx>, Ty<'tcx>)> {
std::mem::take(&mut self.inner.borrow_mut().opaque_type_storage.opaque_types)
.into_iter()
.map(|(k, v)| (k, v.hidden_type.ty))
.collect()
self.take_opaque_types().into_iter().map(|(k, v)| (k, v.hidden_type.ty)).collect()
}

/// Given the (canonicalized) result to a canonical query,
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_infer/src/infer/generalize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ where
ty::Covariant | ty::Contravariant => (),
}

let origin = *inner.type_variables().var_origin(vid);
let origin = inner.type_variables().var_origin(vid);
let new_var_id =
inner.type_variables().new_var(self.for_universe, origin);
let u = Ty::new_var(self.tcx(), new_var_id);
Expand Down
20 changes: 8 additions & 12 deletions compiler/rustc_infer/src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ pub(crate) type UnificationTable<'a, 'tcx, T> = ut::UnificationTable<
/// call to `start_snapshot` and `rollback_to`.
#[derive(Clone)]
pub struct InferCtxtInner<'tcx> {
undo_log: InferCtxtUndoLogs<'tcx>,

/// Cache for projections.
///
/// This cache is snapshotted along with the infcx.
Expand Down Expand Up @@ -162,18 +164,17 @@ pub struct InferCtxtInner<'tcx> {
/// that all type inference variables have been bound and so forth.
region_obligations: Vec<RegionObligation<'tcx>>,

undo_log: InferCtxtUndoLogs<'tcx>,

/// Caches for opaque type inference.
opaque_type_storage: OpaqueTypeStorage<'tcx>,
}

impl<'tcx> InferCtxtInner<'tcx> {
fn new() -> InferCtxtInner<'tcx> {
InferCtxtInner {
undo_log: InferCtxtUndoLogs::default(),

projection_cache: Default::default(),
type_variable_storage: type_variable::TypeVariableStorage::new(),
undo_log: InferCtxtUndoLogs::default(),
const_unification_storage: ut::UnificationTableStorage::new(),
int_unification_storage: ut::UnificationTableStorage::new(),
float_unification_storage: ut::UnificationTableStorage::new(),
Expand Down Expand Up @@ -759,7 +760,7 @@ impl<'tcx> InferCtxt<'tcx> {
pub fn type_var_origin(&self, ty: Ty<'tcx>) -> Option<TypeVariableOrigin> {
match *ty.kind() {
ty::Infer(ty::TyVar(vid)) => {
Some(*self.inner.borrow_mut().type_variables().var_origin(vid))
Some(self.inner.borrow_mut().type_variables().var_origin(vid))
}
_ => None,
}
Expand All @@ -769,11 +770,11 @@ impl<'tcx> InferCtxt<'tcx> {
freshen::TypeFreshener::new(self)
}

pub fn unsolved_variables(&self) -> Vec<Ty<'tcx>> {
pub fn unresolved_variables(&self) -> Vec<Ty<'tcx>> {
let mut inner = self.inner.borrow_mut();
let mut vars: Vec<Ty<'_>> = inner
.type_variables()
.unsolved_variables()
.unresolved_variables()
.into_iter()
.map(|t| Ty::new_var(self.tcx, t))
.collect();
Expand Down Expand Up @@ -1282,12 +1283,7 @@ impl<'tcx> InferCtxt<'tcx> {
pub fn region_var_origin(&self, vid: ty::RegionVid) -> RegionVariableOrigin {
let mut inner = self.inner.borrow_mut();
let inner = &mut *inner;
inner
.region_constraint_storage
.as_mut()
.expect("regions already resolved")
.with_log(&mut inner.undo_log)
.var_origin(vid)
inner.unwrap_region_constraints().var_origin(vid)
}

/// Clone the list of variable regions. This is used only during NLL processing
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_infer/src/infer/resolve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ impl<'a, 'tcx> TypeVisitor<TyCtxt<'tcx>> for UnresolvedTypeOrConstFinder<'a, 'tc
if let TypeVariableOrigin {
kind: TypeVariableOriginKind::TypeParameterDefinition(_, _),
span,
} = *ty_vars.var_origin(ty_vid)
} = ty_vars.var_origin(ty_vid)
{
Some(span)
} else {
Expand Down
87 changes: 19 additions & 68 deletions compiler/rustc_infer/src/infer/type_variable.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use rustc_hir::def_id::DefId;
use rustc_index::IndexVec;
use rustc_middle::ty::{self, Ty, TyVid};
use rustc_span::symbol::Symbol;
use rustc_span::Span;
Expand All @@ -11,14 +12,13 @@ use std::cmp;
use std::marker::PhantomData;
use std::ops::Range;

use rustc_data_structures::undo_log::{Rollback, UndoLogs};
use rustc_data_structures::undo_log::Rollback;

/// Represents a single undo-able action that affects a type inference variable.
#[derive(Clone)]
pub(crate) enum UndoLog<'tcx> {
EqRelation(sv::UndoLog<ut::Delegate<TyVidEqKey<'tcx>>>),
SubRelation(sv::UndoLog<ut::Delegate<ty::TyVid>>),
Values(sv::UndoLog<Delegate>),
}

/// Convert from a specific kind of undo to the more general UndoLog
Expand All @@ -35,34 +35,19 @@ impl<'tcx> From<sv::UndoLog<ut::Delegate<ty::TyVid>>> for UndoLog<'tcx> {
}
}

/// Convert from a specific kind of undo to the more general UndoLog
impl<'tcx> From<sv::UndoLog<Delegate>> for UndoLog<'tcx> {
fn from(l: sv::UndoLog<Delegate>) -> Self {
UndoLog::Values(l)
}
}

/// Convert from a specific kind of undo to the more general UndoLog
impl<'tcx> From<Instantiate> for UndoLog<'tcx> {
fn from(l: Instantiate) -> Self {
UndoLog::Values(sv::UndoLog::Other(l))
}
}

impl<'tcx> Rollback<UndoLog<'tcx>> for TypeVariableStorage<'tcx> {
fn reverse(&mut self, undo: UndoLog<'tcx>) {
match undo {
UndoLog::EqRelation(undo) => self.eq_relations.reverse(undo),
UndoLog::SubRelation(undo) => self.sub_relations.reverse(undo),
UndoLog::Values(undo) => self.values.reverse(undo),
}
}
}

#[derive(Clone)]
pub struct TypeVariableStorage<'tcx> {
values: sv::SnapshotVecStorage<Delegate>,

/// The origins of each type variable.
values: IndexVec<TyVid, TypeVariableData>,
/// Two variables are unified in `eq_relations` when we have a
/// constraint `?X == ?Y`. This table also stores, for each key,
/// the known value.
Expand Down Expand Up @@ -168,15 +153,10 @@ impl<'tcx> TypeVariableValue<'tcx> {
}
}

#[derive(Clone)]
pub(crate) struct Instantiate;

pub(crate) struct Delegate;

impl<'tcx> TypeVariableStorage<'tcx> {
pub fn new() -> TypeVariableStorage<'tcx> {
TypeVariableStorage {
values: sv::SnapshotVecStorage::new(),
values: Default::default(),
eq_relations: ut::UnificationTableStorage::new(),
sub_relations: ut::UnificationTableStorage::new(),
}
Expand All @@ -194,15 +174,20 @@ impl<'tcx> TypeVariableStorage<'tcx> {
pub(crate) fn eq_relations_ref(&self) -> &ut::UnificationTableStorage<TyVidEqKey<'tcx>> {
&self.eq_relations
}

pub(super) fn finalize_rollback(&mut self) {
debug_assert!(self.values.len() >= self.eq_relations.len());
self.values.truncate(self.eq_relations.len());
}
}

impl<'tcx> TypeVariableTable<'_, 'tcx> {
/// Returns the origin that was given when `vid` was created.
///
/// Note that this function does not return care whether
/// `vid` has been unified with something else or not.
pub fn var_origin(&self, vid: ty::TyVid) -> &TypeVariableOrigin {
&self.storage.values.get(vid.as_usize()).origin
pub fn var_origin(&self, vid: ty::TyVid) -> TypeVariableOrigin {
self.storage.values[vid].origin
}

/// Records that `a == b`, depending on `dir`.
Expand Down Expand Up @@ -237,11 +222,6 @@ impl<'tcx> TypeVariableTable<'_, 'tcx> {
self.eq_relations().probe_value(vid)
);
self.eq_relations().union_value(vid, TypeVariableValue::Known { value: ty });

// Hack: we only need this so that `types_escaping_snapshot`
// can see what has been unified; see the Delegate impl for
// more details.
self.undo_log.push(Instantiate);
}

/// Creates a new type variable.
Expand All @@ -262,14 +242,14 @@ impl<'tcx> TypeVariableTable<'_, 'tcx> {
let eq_key = self.eq_relations().new_key(TypeVariableValue::Unknown { universe });

let sub_key = self.sub_relations().new_key(());
assert_eq!(eq_key.vid, sub_key);
debug_assert_eq!(eq_key.vid, sub_key);

let index = self.values().push(TypeVariableData { origin });
assert_eq!(eq_key.vid.as_u32(), index as u32);
let index = self.storage.values.push(TypeVariableData { origin });
debug_assert_eq!(eq_key.vid, index);

debug!("new_var(index={:?}, universe={:?}, origin={:?})", eq_key.vid, universe, origin);

eq_key.vid
index
}

/// Returns the number of type variables created thus far.
Expand Down Expand Up @@ -329,13 +309,6 @@ impl<'tcx> TypeVariableTable<'_, 'tcx> {
}
}

#[inline]
fn values(
&mut self,
) -> sv::SnapshotVec<Delegate, &mut Vec<TypeVariableData>, &mut InferCtxtUndoLogs<'tcx>> {
self.storage.values.with_log(self.undo_log)
}

#[inline]
fn eq_relations(&mut self) -> super::UnificationTable<'_, 'tcx, TyVidEqKey<'tcx>> {
self.storage.eq_relations.with_log(self.undo_log)
Expand All @@ -354,16 +327,14 @@ impl<'tcx> TypeVariableTable<'_, 'tcx> {
let range = TyVid::from_usize(value_count)..TyVid::from_usize(self.num_vars());
(
range.start..range.end,
(range.start.as_usize()..range.end.as_usize())
.map(|index| self.storage.values.get(index).origin)
.collect(),
(range.start..range.end).map(|index| self.var_origin(index)).collect(),
)
}

/// Returns indices of all variables that are not yet
/// instantiated.
pub fn unsolved_variables(&mut self) -> Vec<ty::TyVid> {
(0..self.storage.values.len())
pub fn unresolved_variables(&mut self) -> Vec<ty::TyVid> {
(0..self.num_vars())
.filter_map(|i| {
let vid = ty::TyVid::from_usize(i);
match self.probe(vid) {
Expand All @@ -375,26 +346,6 @@ impl<'tcx> TypeVariableTable<'_, 'tcx> {
}
}

impl sv::SnapshotVecDelegate for Delegate {
type Value = TypeVariableData;
type Undo = Instantiate;

fn reverse(_values: &mut Vec<TypeVariableData>, _action: Instantiate) {
// We don't actually have to *do* anything to reverse an
// instantiation; the value for a variable is stored in the
// `eq_relations` and hence its rollback code will handle
// it. In fact, we could *almost* just remove the
// `SnapshotVec` entirely, except that we would have to
// reproduce *some* of its logic, since we want to know which
// type variables have been instantiated since the snapshot
// was started, so we can implement `types_escaping_snapshot`.
//
// (If we extended the `UnificationTable` to let us see which
// values have been unified and so forth, that might also
// suffice.)
}
}

///////////////////////////////////////////////////////////////////////////

/// These structs (a newtyped TyVid) are used as the unification key
Expand Down
Loading