Skip to content

Commit 26a77f9

Browse files
authored
Avoid type size explosion when expanding types (#17842)
If TypedDict A has multiple items that refer to TypedDict B, don't duplicate the types representing B during type expansion (or generally when translating types). If TypedDicts are deeply nested, this could result in a lot of redundant type objects.

Example where this could matter (assume B is a big TypedDict):

```
class B(TypedDict):
    ...

class A(TypedDict):
    a: B
    b: B
    c: B
    ...
    z: B
```

Also deduplicate large unions. It's common to have aliases that are defined as large unions, and again we want to avoid duplicating these unions.

This may help with #17231, but this fix may not be sufficient.
1 parent 1995155 commit 26a77f9

File tree

7 files changed

+66
-8
lines changed

7 files changed

+66
-8
lines changed

mypy/applytype.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ def __init__(
215215
bound_tvars: frozenset[TypeVarLikeType] = frozenset(),
216216
seen_aliases: frozenset[TypeInfo] = frozenset(),
217217
) -> None:
218+
super().__init__()
218219
self.poly_tvars = set(poly_tvars)
219220
# This is a simplified version of TypeVarScope used during semantic analysis.
220221
self.bound_tvars = bound_tvars

mypy/erasetype.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ class TypeVarEraser(TypeTranslator):
161161
"""Implementation of type erasure"""
162162

163163
def __init__(self, erase_id: Callable[[TypeVarId], bool], replacement: Type) -> None:
164+
super().__init__()
164165
self.erase_id = erase_id
165166
self.replacement = replacement
166167

mypy/expandtype.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ class ExpandTypeVisitor(TrivialSyntheticTypeTranslator):
179179
variables: Mapping[TypeVarId, Type] # TypeVar id -> TypeVar value
180180

181181
def __init__(self, variables: Mapping[TypeVarId, Type]) -> None:
182+
super().__init__()
182183
self.variables = variables
183184
self.recursive_tvar_guard: dict[TypeVarId, Type | None] = {}
184185

@@ -454,15 +455,25 @@ def visit_tuple_type(self, t: TupleType) -> Type:
454455
return t.copy_modified(items=items, fallback=fallback)
455456

456457
def visit_typeddict_type(self, t: TypedDictType) -> Type:
458+
if cached := self.get_cached(t):
459+
return cached
457460
fallback = t.fallback.accept(self)
458461
assert isinstance(fallback, ProperType) and isinstance(fallback, Instance)
459-
return t.copy_modified(item_types=self.expand_types(t.items.values()), fallback=fallback)
462+
result = t.copy_modified(item_types=self.expand_types(t.items.values()), fallback=fallback)
463+
self.set_cached(t, result)
464+
return result
460465

461466
def visit_literal_type(self, t: LiteralType) -> Type:
462467
# TODO: Verify this implementation is correct
463468
return t
464469

465470
def visit_union_type(self, t: UnionType) -> Type:
471+
# Use cache to avoid O(n**2) or worse expansion of types during translation
472+
# (only for large unions, since caching adds overhead)
473+
use_cache = len(t.items) > 3
474+
if use_cache and (cached := self.get_cached(t)):
475+
return cached
476+
466477
expanded = self.expand_types(t.items)
467478
# After substituting for type variables in t.items, some resulting types
468479
# might be subtypes of others, however calling make_simplified_union()
@@ -475,7 +486,11 @@ def visit_union_type(self, t: UnionType) -> Type:
475486
# otherwise a single item union of a type alias will break it. Note this should not
476487
# cause infinite recursion since pathological aliases like A = Union[A, B] are
477488
# banned at the semantic analysis level.
478-
return get_proper_type(simplified)
489+
result = get_proper_type(simplified)
490+
491+
if use_cache:
492+
self.set_cached(t, result)
493+
return result
479494

480495
def visit_partial_type(self, t: PartialType) -> Type:
481496
return t

mypy/subtypes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -886,6 +886,8 @@ def visit_typeddict_type(self, left: TypedDictType) -> bool:
886886
if isinstance(right, Instance):
887887
return self._is_subtype(left.fallback, right)
888888
elif isinstance(right, TypedDictType):
889+
if left == right:
890+
return True # Fast path
889891
if not left.names_are_wider_than(right):
890892
return False
891893
for name, l, r in left.zip(right):

mypy/type_visitor.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,26 @@ class TypeTranslator(TypeVisitor[Type]):
181181
182182
Subclass this and override some methods to implement a non-trivial
183183
transformation.
184+
185+
We cache the results of certain translations to avoid
186+
massively expanding the sizes of types.
184187
"""
185188

189+
def __init__(self, cache: dict[Type, Type] | None = None) -> None:
190+
# For deduplication of results
191+
self.cache = cache
192+
193+
def get_cached(self, t: Type) -> Type | None:
194+
if self.cache is None:
195+
return None
196+
return self.cache.get(t)
197+
198+
def set_cached(self, orig: Type, new: Type) -> None:
199+
if self.cache is None:
200+
# Minor optimization: construct lazily
201+
self.cache = {}
202+
self.cache[orig] = new
203+
186204
def visit_unbound_type(self, t: UnboundType) -> Type:
187205
return t
188206

@@ -251,28 +269,42 @@ def visit_tuple_type(self, t: TupleType) -> Type:
251269
)
252270

253271
def visit_typeddict_type(self, t: TypedDictType) -> Type:
272+
# Use cache to avoid O(n**2) or worse expansion of types during translation
273+
if cached := self.get_cached(t):
274+
return cached
254275
items = {item_name: item_type.accept(self) for (item_name, item_type) in t.items.items()}
255-
return TypedDictType(
276+
result = TypedDictType(
256277
items,
257278
t.required_keys,
258279
# TODO: This appears to be unsafe.
259280
cast(Any, t.fallback.accept(self)),
260281
t.line,
261282
t.column,
262283
)
284+
self.set_cached(t, result)
285+
return result
263286

264287
def visit_literal_type(self, t: LiteralType) -> Type:
265288
fallback = t.fallback.accept(self)
266289
assert isinstance(fallback, Instance) # type: ignore[misc]
267290
return LiteralType(value=t.value, fallback=fallback, line=t.line, column=t.column)
268291

269292
def visit_union_type(self, t: UnionType) -> Type:
270-
return UnionType(
293+
# Use cache to avoid O(n**2) or worse expansion of types during translation
294+
# (only for large unions, since caching adds overhead)
295+
use_cache = len(t.items) > 3
296+
if use_cache and (cached := self.get_cached(t)):
297+
return cached
298+
299+
result = UnionType(
271300
self.translate_types(t.items),
272301
t.line,
273302
t.column,
274303
uses_pep604_syntax=t.uses_pep604_syntax,
275304
)
305+
if use_cache:
306+
self.set_cached(t, result)
307+
return result
276308

277309
def translate_types(self, types: Iterable[Type]) -> list[Type]:
278310
return [t.accept(self) for t in types]

mypy/typeanal.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2271,6 +2271,7 @@ def __init__(
22712271
lookup: Callable[[str, Context], SymbolTableNode | None],
22722272
scope: TypeVarLikeScope,
22732273
) -> None:
2274+
super().__init__()
22742275
self.seen_nodes = seen_nodes
22752276
self.lookup = lookup
22762277
self.scope = scope
@@ -2660,6 +2661,7 @@ class TypeVarDefaultTranslator(TrivialSyntheticTypeTranslator):
26602661
def __init__(
26612662
self, api: SemanticAnalyzerInterface, tvar_expr_name: str, context: Context
26622663
) -> None:
2664+
super().__init__()
26632665
self.api = api
26642666
self.tvar_expr_name = tvar_expr_name
26652667
self.context = context

mypy/types.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ def _expand_once(self) -> Type:
357357

358358
def _partial_expansion(self, nothing_args: bool = False) -> tuple[ProperType, bool]:
359359
# Private method mostly for debugging and testing.
360-
unroller = UnrollAliasVisitor(set())
360+
unroller = UnrollAliasVisitor(set(), {})
361361
if nothing_args:
362362
alias = self.copy_modified(args=[UninhabitedType()] * len(self.args))
363363
else:
@@ -2586,7 +2586,8 @@ def __hash__(self) -> int:
25862586
def __eq__(self, other: object) -> bool:
25872587
if not isinstance(other, TypedDictType):
25882588
return NotImplemented
2589-
2589+
if self is other:
2590+
return True
25902591
return (
25912592
frozenset(self.items.keys()) == frozenset(other.items.keys())
25922593
and all(
@@ -3507,7 +3508,11 @@ def visit_type_list(self, t: TypeList) -> Type:
35073508

35083509

35093510
class UnrollAliasVisitor(TrivialSyntheticTypeTranslator):
3510-
def __init__(self, initial_aliases: set[TypeAliasType]) -> None:
3511+
def __init__(
3512+
self, initial_aliases: set[TypeAliasType], cache: dict[Type, Type] | None
3513+
) -> None:
3514+
assert cache is not None
3515+
super().__init__(cache)
35113516
self.recursed = False
35123517
self.initial_aliases = initial_aliases
35133518

@@ -3519,7 +3524,7 @@ def visit_type_alias_type(self, t: TypeAliasType) -> Type:
35193524
# A = Tuple[B, B]
35203525
# B = int
35213526
# will not be detected as recursive on the second encounter of B.
3522-
subvisitor = UnrollAliasVisitor(self.initial_aliases | {t})
3527+
subvisitor = UnrollAliasVisitor(self.initial_aliases | {t}, self.cache)
35233528
result = get_proper_type(t).accept(subvisitor)
35243529
if subvisitor.recursed:
35253530
self.recursed = True

0 commit comments

Comments
 (0)