@@ -535,6 +535,20 @@ def visit_overloaded(self, left: Overloaded) -> bool:
535
535
return False
536
536
537
537
def visit_union_type (self , left : UnionType ) -> bool :
538
+ if isinstance (self .right , Instance ):
539
+ literal_types : Set [Instance ] = set ()
540
+ # avoid redundant check for union of literals
541
+ for item in left .relevant_items ():
542
+ item = get_proper_type (item )
543
+ lit_type = mypy .typeops .simple_literal_type (item )
544
+ if lit_type is not None :
545
+ if lit_type in literal_types :
546
+ continue
547
+ literal_types .add (lit_type )
548
+ item = lit_type
549
+ if not self ._is_subtype (item , self .orig_right ):
550
+ return False
551
+ return True
538
552
return all (self ._is_subtype (item , self .orig_right ) for item in left .items )
539
553
540
554
def visit_partial_type (self , left : PartialType ) -> bool :
@@ -1199,6 +1213,18 @@ def report(*args: Any) -> None:
1199
1213
return applied
1200
1214
1201
1215
1216
+ def try_restrict_literal_union (t : UnionType , s : Type ) -> Optional [List [Type ]]:
1217
+ """Helper function for restrict_subtype_away, allowing a fast path for Union of simple literals"""
1218
+ new_items : List [Type ] = []
1219
+ for i in t .relevant_items ():
1220
+ it = get_proper_type (i )
1221
+ if not mypy .typeops .is_simple_literal (it ):
1222
+ return None
1223
+ if it != s :
1224
+ new_items .append (i )
1225
+ return new_items
1226
+
1227
+
1202
1228
def restrict_subtype_away (t : Type , s : Type , * , ignore_promotions : bool = False ) -> Type :
1203
1229
"""Return t minus s for runtime type assertions.
1204
1230
@@ -1212,10 +1238,13 @@ def restrict_subtype_away(t: Type, s: Type, *, ignore_promotions: bool = False)
1212
1238
s = get_proper_type (s )
1213
1239
1214
1240
if isinstance (t , UnionType ):
1215
- new_items = [restrict_subtype_away (item , s , ignore_promotions = ignore_promotions )
1216
- for item in t .relevant_items ()
1217
- if (isinstance (get_proper_type (item ), AnyType ) or
1218
- not covers_at_runtime (item , s , ignore_promotions ))]
1241
+ new_items = try_restrict_literal_union (t , s ) if isinstance (s , LiteralType ) else []
1242
+ new_items = new_items or [
1243
+ restrict_subtype_away (item , s , ignore_promotions = ignore_promotions )
1244
+ for item in t .relevant_items ()
1245
+ if (isinstance (get_proper_type (item ), AnyType ) or
1246
+ not covers_at_runtime (item , s , ignore_promotions ))
1247
+ ]
1219
1248
return UnionType .make_union (new_items )
1220
1249
elif covers_at_runtime (t , s , ignore_promotions ):
1221
1250
return UninhabitedType ()
@@ -1285,11 +1314,11 @@ def _is_proper_subtype(left: Type, right: Type, *,
1285
1314
right = get_proper_type (right )
1286
1315
1287
1316
if isinstance (right , UnionType ) and not isinstance (left , UnionType ):
1288
- return any ([ is_proper_subtype (orig_left , item ,
1289
- ignore_promotions = ignore_promotions ,
1290
- erase_instances = erase_instances ,
1291
- keep_erased_types = keep_erased_types )
1292
- for item in right .items ] )
1317
+ return any (is_proper_subtype (orig_left , item ,
1318
+ ignore_promotions = ignore_promotions ,
1319
+ erase_instances = erase_instances ,
1320
+ keep_erased_types = keep_erased_types )
1321
+ for item in right .items )
1293
1322
return left .accept (ProperSubtypeVisitor (orig_right ,
1294
1323
ignore_promotions = ignore_promotions ,
1295
1324
erase_instances = erase_instances ,
@@ -1495,7 +1524,7 @@ def visit_overloaded(self, left: Overloaded) -> bool:
1495
1524
return False
1496
1525
1497
1526
def visit_union_type (self , left : UnionType ) -> bool :
1498
- return all ([ self ._is_proper_subtype (item , self .orig_right ) for item in left .items ] )
1527
+ return all (self ._is_proper_subtype (item , self .orig_right ) for item in left .items )
1499
1528
1500
1529
def visit_partial_type (self , left : PartialType ) -> bool :
1501
1530
# TODO: What's the right thing to do here?
0 commit comments