Skip to content

Commit baadb5a

Browse files
authored
[ty] Add some additional type safety to CycleDetector (#19903)
This PR adds a type tag to the `CycleDetector` visitor (and its aliases). There are some places where we implement e.g. an equivalence check by making a disjointness check. Both `is_equivalent_to` and `is_disjoint_from` use a `PairVisitor` to handle cycles, but they should not use the same visitor. I was finding it tedious to remember when it was appropriate to pass on a visitor and when not to. This adds a `PhantomData` type tag to ensure that we can't pass on one method's visitor to a different method. For `has_relation` and `apply_type_mapping`, we have an existing type that we can use as the tag. For the other methods, I've added empty structs (`Normalized`, `IsDisjointFrom`, `IsEquivalentTo`) to use as tags.
1 parent df0648a commit baadb5a

File tree

11 files changed

+155
-105
lines changed

11 files changed

+155
-105
lines changed

crates/ty_python_semantic/src/types.rs

Lines changed: 43 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,24 @@ fn definition_expression_type<'db>(
170170
}
171171
}
172172

173+
/// A [`TypeTransformer`] that is used in `apply_type_mapping` methods.
174+
pub(crate) type ApplyTypeMappingVisitor<'db> = TypeTransformer<'db, TypeMapping<'db, 'db>>;
175+
176+
/// A [`PairVisitor`] that is used in `has_relation_to` methods.
177+
pub(crate) type HasRelationToVisitor<'db> = PairVisitor<'db, TypeRelation>;
178+
179+
/// A [`PairVisitor`] that is used in `is_disjoint_from` methods.
180+
pub(crate) type IsDisjointVisitor<'db> = PairVisitor<'db, IsDisjoint>;
181+
pub(crate) struct IsDisjoint;
182+
183+
/// A [`PairVisitor`] that is used in `is_equivalent` methods.
184+
pub(crate) type IsEquivalentVisitor<'db> = PairVisitor<'db, IsEquivalent>;
185+
pub(crate) struct IsEquivalent;
186+
187+
/// A [`TypeTransformer`] that is used in `normalized` methods.
188+
pub(crate) type NormalizedVisitor<'db> = TypeTransformer<'db, Normalized>;
189+
pub(crate) struct Normalized;
190+
173191
/// The descriptor protocol distinguishes two kinds of descriptors. Non-data descriptors
174192
/// define a `__get__` method, while data descriptors additionally define a `__set__`
175193
/// method or a `__delete__` method. This enum is used to categorize attributes into two
@@ -419,7 +437,7 @@ impl<'db> PropertyInstanceType<'db> {
419437
Self::new(db, getter, setter)
420438
}
421439

422-
fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
440+
fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
423441
Self::new(
424442
db,
425443
self.getter(db).map(|ty| ty.normalized_impl(db, visitor)),
@@ -1068,7 +1086,7 @@ impl<'db> Type<'db> {
10681086
}
10691087

10701088
#[must_use]
1071-
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
1089+
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
10721090
match self {
10731091
Type::Union(union) => {
10741092
visitor.visit(self, || Type::Union(union.normalized_impl(db, visitor)))
@@ -1326,7 +1344,7 @@ impl<'db> Type<'db> {
13261344
db: &'db dyn Db,
13271345
target: Type<'db>,
13281346
relation: TypeRelation,
1329-
visitor: &PairVisitor<'db>,
1347+
visitor: &HasRelationToVisitor<'db>,
13301348
) -> bool {
13311349
// Subtyping implies assignability, so if subtyping is reflexive and the two types are
13321350
// equal, it is both a subtype and assignable. Assignability is always reflexive.
@@ -1762,7 +1780,7 @@ impl<'db> Type<'db> {
17621780
self,
17631781
db: &'db dyn Db,
17641782
other: Type<'db>,
1765-
visitor: &PairVisitor<'db>,
1783+
visitor: &IsEquivalentVisitor<'db>,
17661784
) -> bool {
17671785
if self == other {
17681786
return true;
@@ -1848,13 +1866,13 @@ impl<'db> Type<'db> {
18481866
self,
18491867
db: &'db dyn Db,
18501868
other: Type<'db>,
1851-
visitor: &PairVisitor<'db>,
1869+
visitor: &IsDisjointVisitor<'db>,
18521870
) -> bool {
18531871
fn any_protocol_members_absent_or_disjoint<'db>(
18541872
db: &'db dyn Db,
18551873
protocol: ProtocolInstanceType<'db>,
18561874
other: Type<'db>,
1857-
visitor: &PairVisitor<'db>,
1875+
visitor: &IsDisjointVisitor<'db>,
18581876
) -> bool {
18591877
protocol.interface(db).members(db).any(|member| {
18601878
other
@@ -5743,7 +5761,7 @@ impl<'db> Type<'db> {
57435761
self,
57445762
db: &'db dyn Db,
57455763
type_mapping: &TypeMapping<'a, 'db>,
5746-
visitor: &TypeTransformer<'db>,
5764+
visitor: &ApplyTypeMappingVisitor<'db>,
57475765
) -> Type<'db> {
57485766
match self {
57495767
Type::TypeVar(bound_typevar) => match type_mapping {
@@ -6265,7 +6283,7 @@ impl<'db> TypeMapping<'_, 'db> {
62656283
}
62666284
}
62676285

6268-
fn normalized_impl(&self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
6286+
fn normalized_impl(&self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
62696287
match self {
62706288
TypeMapping::Specialization(specialization) => {
62716289
TypeMapping::Specialization(specialization.normalized_impl(db, visitor))
@@ -6351,7 +6369,7 @@ fn walk_known_instance_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>(
63516369
}
63526370

63536371
impl<'db> KnownInstanceType<'db> {
6354-
fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
6372+
fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
63556373
match self {
63566374
Self::SubscriptedProtocol(context) => {
63576375
Self::SubscriptedProtocol(context.normalized_impl(db, visitor))
@@ -6777,7 +6795,7 @@ pub struct FieldInstance<'db> {
67776795
impl get_size2::GetSize for FieldInstance<'_> {}
67786796

67796797
impl<'db> FieldInstance<'db> {
6780-
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
6798+
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
67816799
FieldInstance::new(
67826800
db,
67836801
self.default_type(db).normalized_impl(db, visitor),
@@ -6901,7 +6919,7 @@ impl<'db> TypeVarInstance<'db> {
69016919
}
69026920
}
69036921

6904-
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
6922+
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
69056923
Self::new(
69066924
db,
69076925
self.name(db),
@@ -6997,7 +7015,7 @@ impl<'db> BoundTypeVarInstance<'db> {
69977015
.map(|ty| ty.apply_type_mapping(db, &TypeMapping::BindLegacyTypevars(binding_context)))
69987016
}
69997017

7000-
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
7018+
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
70017019
Self::new(
70027020
db,
70037021
self.typevar(db).normalized_impl(db, visitor),
@@ -7056,7 +7074,7 @@ fn walk_type_var_bounds<'db, V: visitor::TypeVisitor<'db> + ?Sized>(
70567074
}
70577075

70587076
impl<'db> TypeVarBoundOrConstraints<'db> {
7059-
fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
7077+
fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
70607078
match self {
70617079
TypeVarBoundOrConstraints::UpperBound(bound) => {
70627080
TypeVarBoundOrConstraints::UpperBound(bound.normalized_impl(db, visitor))
@@ -8094,7 +8112,7 @@ impl<'db> BoundMethodType<'db> {
80948112
)
80958113
}
80968114

8097-
fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
8115+
fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
80988116
Self::new(
80998117
db,
81008118
self.function(db).normalized_impl(db, visitor),
@@ -8211,7 +8229,7 @@ impl<'db> CallableType<'db> {
82118229
/// Return a "normalized" version of this `Callable` type.
82128230
///
82138231
/// See [`Type::normalized`] for more details.
8214-
fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
8232+
fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
82158233
CallableType::new(
82168234
db,
82178235
self.signatures(db).normalized_impl(db, visitor),
@@ -8375,7 +8393,7 @@ impl<'db> MethodWrapperKind<'db> {
83758393
}
83768394
}
83778395

8378-
fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
8396+
fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
83798397
match self {
83808398
MethodWrapperKind::FunctionTypeDunderGet(function) => {
83818399
MethodWrapperKind::FunctionTypeDunderGet(function.normalized_impl(db, visitor))
@@ -8559,7 +8577,7 @@ impl<'db> PEP695TypeAliasType<'db> {
85598577
definition_expression_type(db, definition, &type_alias_stmt_node.value)
85608578
}
85618579

8562-
fn normalized_impl(self, _db: &'db dyn Db, _visitor: &TypeTransformer<'db>) -> Self {
8580+
fn normalized_impl(self, _db: &'db dyn Db, _visitor: &NormalizedVisitor<'db>) -> Self {
85638581
self
85648582
}
85658583
}
@@ -8601,7 +8619,7 @@ fn walk_bare_type_alias<'db, V: visitor::TypeVisitor<'db> + ?Sized>(
86018619
}
86028620

86038621
impl<'db> BareTypeAliasType<'db> {
8604-
fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
8622+
fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
86058623
Self::new(
86068624
db,
86078625
self.name(db),
@@ -8637,7 +8655,7 @@ fn walk_type_alias_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>(
86378655
}
86388656

86398657
impl<'db> TypeAliasType<'db> {
8640-
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
8658+
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
86418659
match self {
86428660
TypeAliasType::PEP695(type_alias) => {
86438661
TypeAliasType::PEP695(type_alias.normalized_impl(db, visitor))
@@ -8866,7 +8884,7 @@ impl<'db> UnionType<'db> {
88668884
self.normalized_impl(db, &TypeTransformer::default())
88678885
}
88688886

8869-
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
8887+
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
88708888
let mut new_elements: Vec<Type<'db>> = self
88718889
.elements(db)
88728890
.iter()
@@ -8940,11 +8958,11 @@ impl<'db> IntersectionType<'db> {
89408958
self.normalized_impl(db, &TypeTransformer::default())
89418959
}
89428960

8943-
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
8961+
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
89448962
fn normalized_set<'db>(
89458963
db: &'db dyn Db,
89468964
elements: &FxOrderSet<Type<'db>>,
8947-
visitor: &TypeTransformer<'db>,
8965+
visitor: &NormalizedVisitor<'db>,
89488966
) -> FxOrderSet<Type<'db>> {
89498967
let mut elements: FxOrderSet<Type<'db>> = elements
89508968
.iter()
@@ -9194,7 +9212,7 @@ impl<'db> TypedDictType<'db> {
91949212
self,
91959213
db: &'db dyn Db,
91969214
type_mapping: &TypeMapping<'a, 'db>,
9197-
visitor: &TypeTransformer<'db>,
9215+
visitor: &ApplyTypeMappingVisitor<'db>,
91989216
) -> Self {
91999217
Self {
92009218
defining_class: self
@@ -9266,7 +9284,7 @@ pub enum SuperOwnerKind<'db> {
92669284
}
92679285

92689286
impl<'db> SuperOwnerKind<'db> {
9269-
fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
9287+
fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
92709288
match self {
92719289
SuperOwnerKind::Dynamic(dynamic) => SuperOwnerKind::Dynamic(dynamic.normalized()),
92729290
SuperOwnerKind::Class(class) => {
@@ -9538,7 +9556,7 @@ impl<'db> BoundSuperType<'db> {
95389556
}
95399557
}
95409558

9541-
pub(super) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
9559+
pub(super) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
95429560
Self::new(
95439561
db,
95449562
self.pivot_class(db).normalized_impl(db, visitor),

crates/ty_python_semantic/src/types/class.rs

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,11 @@ use crate::types::infer::nearest_enclosing_class;
2222
use crate::types::signatures::{CallableSignature, Parameter, Parameters, Signature};
2323
use crate::types::tuple::{TupleSpec, TupleType};
2424
use crate::types::{
25-
BareTypeAliasType, Binding, BoundSuperError, BoundSuperType, CallableType, DataclassParams,
26-
DeprecatedInstance, KnownInstanceType, StringLiteralType, TypeAliasType, TypeMapping,
27-
TypeRelation, TypeTransformer, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind,
28-
declaration_type, infer_definition_types, todo_type,
25+
ApplyTypeMappingVisitor, BareTypeAliasType, Binding, BoundSuperError, BoundSuperType,
26+
CallableType, DataclassParams, DeprecatedInstance, HasRelationToVisitor, KnownInstanceType,
27+
NormalizedVisitor, StringLiteralType, TypeAliasType, TypeMapping, TypeRelation,
28+
TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, declaration_type,
29+
infer_definition_types, todo_type,
2930
};
3031
use crate::{
3132
Db, FxIndexMap, FxOrderSet, Program,
@@ -231,7 +232,7 @@ pub(super) fn walk_generic_alias<'db, V: super::visitor::TypeVisitor<'db> + ?Siz
231232
impl get_size2::GetSize for GenericAlias<'_> {}
232233

233234
impl<'db> GenericAlias<'db> {
234-
pub(super) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
235+
pub(super) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
235236
Self::new(
236237
db,
237238
self.origin(db),
@@ -255,7 +256,7 @@ impl<'db> GenericAlias<'db> {
255256
self,
256257
db: &'db dyn Db,
257258
type_mapping: &TypeMapping<'a, 'db>,
258-
visitor: &TypeTransformer<'db>,
259+
visitor: &ApplyTypeMappingVisitor<'db>,
259260
) -> Self {
260261
Self::new(
261262
db,
@@ -319,7 +320,7 @@ impl<'db> ClassType<'db> {
319320
}
320321
}
321322

322-
pub(super) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
323+
pub(super) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
323324
match self {
324325
Self::NonGeneric(_) => self,
325326
Self::Generic(generic) => Self::Generic(generic.normalized_impl(db, visitor)),
@@ -406,7 +407,7 @@ impl<'db> ClassType<'db> {
406407
self,
407408
db: &'db dyn Db,
408409
type_mapping: &TypeMapping<'a, 'db>,
409-
visitor: &TypeTransformer<'db>,
410+
visitor: &ApplyTypeMappingVisitor<'db>,
410411
) -> Self {
411412
match self {
412413
Self::NonGeneric(_) => self,
@@ -469,7 +470,7 @@ impl<'db> ClassType<'db> {
469470
db: &'db dyn Db,
470471
other: Self,
471472
relation: TypeRelation,
472-
visitor: &PairVisitor<'db>,
473+
visitor: &HasRelationToVisitor<'db>,
473474
) -> bool {
474475
self.iter_mro(db).any(|base| {
475476
match base {

crates/ty_python_semantic/src/types/class_base.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@ use crate::Db;
22
use crate::types::generics::Specialization;
33
use crate::types::tuple::TupleType;
44
use crate::types::{
5-
ClassLiteral, ClassType, DynamicType, KnownClass, KnownInstanceType, MroError, MroIterator,
6-
SpecialFormType, Type, TypeMapping, TypeTransformer, todo_type,
5+
ApplyTypeMappingVisitor, ClassLiteral, ClassType, DynamicType, KnownClass, KnownInstanceType,
6+
MroError, MroIterator, NormalizedVisitor, SpecialFormType, Type, TypeMapping, TypeTransformer,
7+
todo_type,
78
};
89

910
/// Enumeration of the possible kinds of types we allow in class bases.
@@ -33,7 +34,7 @@ impl<'db> ClassBase<'db> {
3334
Self::Dynamic(DynamicType::Unknown)
3435
}
3536

36-
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
37+
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
3738
match self {
3839
Self::Dynamic(dynamic) => Self::Dynamic(dynamic.normalized()),
3940
Self::Class(class) => Self::Class(class.normalized_impl(db, visitor)),
@@ -269,7 +270,7 @@ impl<'db> ClassBase<'db> {
269270
self,
270271
db: &'db dyn Db,
271272
type_mapping: &TypeMapping<'a, 'db>,
272-
visitor: &TypeTransformer<'db>,
273+
visitor: &ApplyTypeMappingVisitor<'db>,
273274
) -> Self {
274275
match self {
275276
Self::Class(class) => {

crates/ty_python_semantic/src/types/cyclic.rs

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,20 @@
1818
//! `visitor.visit` when visiting a protocol type, and then internal `has_relation_to_impl` methods
1919
//! of the Rust types implementing protocols also call `visitor.visit`. The best way to avoid this
2020
//! is to prefer always calling `visitor.visit` only in the main recursive method on `Type`.
21-
use rustc_hash::FxHashMap;
2221
23-
use crate::FxIndexSet;
24-
use crate::types::Type;
2522
use std::cell::RefCell;
2623
use std::cmp::Eq;
2724
use std::hash::Hash;
25+
use std::marker::PhantomData;
2826

29-
pub(crate) type TypeTransformer<'db> = CycleDetector<Type<'db>, Type<'db>>;
27+
use rustc_hash::FxHashMap;
3028

31-
impl Default for TypeTransformer<'_> {
29+
use crate::FxIndexSet;
30+
use crate::types::Type;
31+
32+
pub(crate) type TypeTransformer<'db, Tag> = CycleDetector<Tag, Type<'db>, Type<'db>>;
33+
34+
impl<Tag> Default for TypeTransformer<'_, Tag> {
3235
fn default() -> Self {
3336
// TODO: proper recursive type handling
3437

@@ -38,10 +41,10 @@ impl Default for TypeTransformer<'_> {
3841
}
3942
}
4043

41-
pub(crate) type PairVisitor<'db> = CycleDetector<(Type<'db>, Type<'db>), bool>;
44+
pub(crate) type PairVisitor<'db, Tag> = CycleDetector<Tag, (Type<'db>, Type<'db>), bool>;
4245

4346
#[derive(Debug)]
44-
pub(crate) struct CycleDetector<T, R> {
47+
pub(crate) struct CycleDetector<Tag, T, R> {
4548
/// If the type we're visiting is present in `seen`, it indicates that we've hit a cycle (due
4649
/// to a recursive type); we need to immediately short circuit the whole operation and return
4750
/// the fallback value. That's why we pop items off the end of `seen` after we've visited them.
@@ -56,14 +59,17 @@ pub(crate) struct CycleDetector<T, R> {
5659
cache: RefCell<FxHashMap<T, R>>,
5760

5861
fallback: R,
62+
63+
_tag: PhantomData<Tag>,
5964
}
6065

61-
impl<T: Hash + Eq + Copy, R: Copy> CycleDetector<T, R> {
66+
impl<Tag, T: Hash + Eq + Copy, R: Copy> CycleDetector<Tag, T, R> {
6267
pub(crate) fn new(fallback: R) -> Self {
6368
CycleDetector {
6469
seen: RefCell::new(FxIndexSet::default()),
6570
cache: RefCell::new(FxHashMap::default()),
6671
fallback,
72+
_tag: PhantomData,
6773
}
6874
}
6975

0 commit comments

Comments
 (0)