Skip to content

Commit 774b95c

Browse files
committed
more enum-related speedups
As a followup to #9394 address a few more O(n**2) behaviors caused by decomposing enums into unions of literals.
1 parent 0df8cf5 commit 774b95c

File tree

4 files changed

+101
-16
lines changed

4 files changed

+101
-16
lines changed

mypy/meet.py

+29
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ def narrow_declared_type(declared: Type, narrowed: Type) -> Type:
6464
if isinstance(declared, UnionType):
6565
return make_simplified_union([narrow_declared_type(x, narrowed)
6666
for x in declared.relevant_items()])
67+
if is_enum_overlapping_union(declared, narrowed):
68+
return narrowed
6769
elif not is_overlapping_types(declared, narrowed,
6870
prohibit_none_typevar_overlap=True):
6971
if state.strict_optional:
@@ -137,6 +139,21 @@ def get_possible_variants(typ: Type) -> List[Type]:
137139
return [typ]
138140

139141

142+
def is_enum_overlapping_union(x: ProperType, y: ProperType) -> bool:
143+
return (
144+
isinstance(x, Instance) and x.type.is_enum and
145+
isinstance(y, UnionType) and
146+
all(x.type == p.fallback.type
147+
for p in (get_proper_type(z) for z in y.relevant_items())
148+
if isinstance(p, LiteralType))
149+
)
150+
151+
152+
def is_literal_in_union(x: ProperType, y: ProperType) -> bool:
153+
return (isinstance(x, LiteralType) and isinstance(y, UnionType) and
154+
any(x == get_proper_type(z) for z in y.items))
155+
156+
140157
def is_overlapping_types(left: Type,
141158
right: Type,
142159
ignore_promotions: bool = False,
@@ -198,6 +215,18 @@ def _is_overlapping_types(left: Type, right: Type) -> bool:
198215
#
199216
# These checks will also handle the NoneType and UninhabitedType cases for us.
200217

218+
# enums are sometimes expanded into an Union of Literals
219+
# when that happens we want to make sure we treat the two as overlapping
220+
# and crucially, we want to do that *fast* in case the enum is large
221+
# so we do it before expanding variants below to avoid O(n**2) behavior
222+
if (
223+
is_enum_overlapping_union(left, right)
224+
or is_enum_overlapping_union(right, left)
225+
or is_literal_in_union(left, right)
226+
or is_literal_in_union(right, left)
227+
):
228+
return True
229+
201230
if (is_proper_subtype(left, right, ignore_promotions=ignore_promotions)
202231
or is_proper_subtype(right, left, ignore_promotions=ignore_promotions)):
203232
return True

mypy/sametypes.py

+24-6
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
from typing import Sequence
1+
from typing import Sequence, Tuple, Set, List
22

33
from mypy.types import (
44
Type, UnboundType, AnyType, NoneType, TupleType, TypedDictType,
55
UnionType, CallableType, TypeVarType, Instance, TypeVisitor, ErasedType,
66
Overloaded, PartialType, DeletedType, UninhabitedType, TypeType, LiteralType,
77
ProperType, get_proper_type, TypeAliasType, ParamSpecType, Parameters, UnpackType
88
)
9-
from mypy.typeops import tuple_fallback, make_simplified_union
9+
from mypy.typeops import tuple_fallback, make_simplified_union, is_simple_literal
1010

1111

1212
def is_same_type(left: Type, right: Type) -> bool:
@@ -153,14 +153,32 @@ def visit_literal_type(self, left: LiteralType) -> bool:
153153

154154
def visit_union_type(self, left: UnionType) -> bool:
155155
if isinstance(self.right, UnionType):
156+
# fast path for simple literals
157+
def _extract_literals(u: UnionType) -> Tuple[Set[Type], List[Type]]:
158+
lit: Set[Type] = set()
159+
rem: List[Type] = []
160+
for i in u.relevant_items():
161+
i = get_proper_type(i)
162+
if is_simple_literal(i):
163+
lit.add(i)
164+
else:
165+
rem.append(i)
166+
return lit, rem
167+
168+
left_lit, left_rem = _extract_literals(left)
169+
right_lit, right_rem = _extract_literals(self.right)
170+
171+
if left_lit != right_lit:
172+
return False
173+
156174
# Check that everything in left is in right
157-
for left_item in left.items:
158-
if not any(is_same_type(left_item, right_item) for right_item in self.right.items):
175+
for left_item in left_rem:
176+
if not any(is_same_type(left_item, right_item) for right_item in right_rem):
159177
return False
160178

161179
# Check that everything in right is in left
162-
for right_item in self.right.items:
163-
if not any(is_same_type(right_item, left_item) for left_item in left.items):
180+
for right_item in right_rem:
181+
if not any(is_same_type(right_item, left_item) for left_item in left_rem):
164182
return False
165183

166184
return True

mypy/subtypes.py

+39-10
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,20 @@ def visit_overloaded(self, left: Overloaded) -> bool:
535535
return False
536536

537537
def visit_union_type(self, left: UnionType) -> bool:
538+
if isinstance(self.right, Instance):
539+
literal_types: Set[Instance] = set()
540+
# avoid redundant check for union of literals
541+
for item in left.relevant_items():
542+
item = get_proper_type(item)
543+
lit_type = mypy.typeops.simple_literal_type(item)
544+
if lit_type is not None:
545+
if lit_type in literal_types:
546+
continue
547+
literal_types.add(lit_type)
548+
item = lit_type
549+
if not self._is_subtype(item, self.orig_right):
550+
return False
551+
return True
538552
return all(self._is_subtype(item, self.orig_right) for item in left.items)
539553

540554
def visit_partial_type(self, left: PartialType) -> bool:
@@ -1199,6 +1213,18 @@ def report(*args: Any) -> None:
11991213
return applied
12001214

12011215

1216+
def try_restrict_literal_union(t: UnionType, s: Type) -> Optional[List[Type]]:
1217+
"""Helper function for restrict_subtype_away, allowing a fast path for Union of simple literals"""
1218+
new_items: List[Type] = []
1219+
for i in t.relevant_items():
1220+
it = get_proper_type(i)
1221+
if not mypy.typeops.is_simple_literal(it):
1222+
return None
1223+
if it != s:
1224+
new_items.append(i)
1225+
return new_items
1226+
1227+
12021228
def restrict_subtype_away(t: Type, s: Type, *, ignore_promotions: bool = False) -> Type:
12031229
"""Return t minus s for runtime type assertions.
12041230
@@ -1212,10 +1238,13 @@ def restrict_subtype_away(t: Type, s: Type, *, ignore_promotions: bool = False)
12121238
s = get_proper_type(s)
12131239

12141240
if isinstance(t, UnionType):
1215-
new_items = [restrict_subtype_away(item, s, ignore_promotions=ignore_promotions)
1216-
for item in t.relevant_items()
1217-
if (isinstance(get_proper_type(item), AnyType) or
1218-
not covers_at_runtime(item, s, ignore_promotions))]
1241+
new_items = try_restrict_literal_union(t, s) if isinstance(s, LiteralType) else []
1242+
new_items = new_items or [
1243+
restrict_subtype_away(item, s, ignore_promotions=ignore_promotions)
1244+
for item in t.relevant_items()
1245+
if (isinstance(get_proper_type(item), AnyType) or
1246+
not covers_at_runtime(item, s, ignore_promotions))
1247+
]
12191248
return UnionType.make_union(new_items)
12201249
elif covers_at_runtime(t, s, ignore_promotions):
12211250
return UninhabitedType()
@@ -1285,11 +1314,11 @@ def _is_proper_subtype(left: Type, right: Type, *,
12851314
right = get_proper_type(right)
12861315

12871316
if isinstance(right, UnionType) and not isinstance(left, UnionType):
1288-
return any([is_proper_subtype(orig_left, item,
1289-
ignore_promotions=ignore_promotions,
1290-
erase_instances=erase_instances,
1291-
keep_erased_types=keep_erased_types)
1292-
for item in right.items])
1317+
return any(is_proper_subtype(orig_left, item,
1318+
ignore_promotions=ignore_promotions,
1319+
erase_instances=erase_instances,
1320+
keep_erased_types=keep_erased_types)
1321+
for item in right.items)
12931322
return left.accept(ProperSubtypeVisitor(orig_right,
12941323
ignore_promotions=ignore_promotions,
12951324
erase_instances=erase_instances,
@@ -1495,7 +1524,7 @@ def visit_overloaded(self, left: Overloaded) -> bool:
14951524
return False
14961525

14971526
def visit_union_type(self, left: UnionType) -> bool:
1498-
return all([self._is_proper_subtype(item, self.orig_right) for item in left.items])
1527+
return all(self._is_proper_subtype(item, self.orig_right) for item in left.items)
14991528

15001529
def visit_partial_type(self, left: PartialType) -> bool:
15011530
# TODO: What's the right thing to do here?

mypy/typeops.py

+9
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,15 @@ def simple_literal_value_key(t: ProperType) -> Optional[Tuple[str, ...]]:
318318
return None
319319

320320

321+
def simple_literal_type(t: ProperType) -> Optional[Instance]:
322+
"""Extract the underlying fallback Instance type for a simple Literal"""
323+
if isinstance(t, Instance) and t.last_known_value is not None:
324+
t = t.last_known_value
325+
if isinstance(t, LiteralType):
326+
return t.fallback
327+
return None
328+
329+
321330
def is_simple_literal(t: ProperType) -> bool:
322331
"""Fast way to check if simple_literal_value_key() would return a non-None value."""
323332
if isinstance(t, LiteralType):

0 commit comments

Comments
 (0)