@@ -535,6 +535,20 @@ def visit_overloaded(self, left: Overloaded) -> bool:
535535 return False
536536
537537 def visit_union_type (self , left : UnionType ) -> bool :
538+ if isinstance (self .right , Instance ):
539+ literal_types : Set [Instance ] = set ()
540+ # avoid redundant check for union of literals
541+ for item in left .relevant_items ():
542+ item = get_proper_type (item )
543+ lit_type = mypy .typeops .simple_literal_type (item )
544+ if lit_type is not None :
545+ if lit_type in literal_types :
546+ continue
547+ literal_types .add (lit_type )
548+ item = lit_type
549+ if not self ._is_subtype (item , self .orig_right ):
550+ return False
551+ return True
538552 return all (self ._is_subtype (item , self .orig_right ) for item in left .items )
539553
540554 def visit_partial_type (self , left : PartialType ) -> bool :
@@ -1199,6 +1213,18 @@ def report(*args: Any) -> None:
11991213 return applied
12001214
12011215
1216+ def try_restrict_literal_union (t : UnionType , s : Type ) -> Optional [List [Type ]]:
1217+ """Helper function for restrict_subtype_away, allowing a fast path for Union of simple literals"""
1218+ new_items : List [Type ] = []
1219+ for i in t .relevant_items ():
1220+ it = get_proper_type (i )
1221+ if not mypy .typeops .is_simple_literal (it ):
1222+ return None
1223+ if it != s :
1224+ new_items .append (i )
1225+ return new_items
1226+
1227+
12021228def restrict_subtype_away (t : Type , s : Type , * , ignore_promotions : bool = False ) -> Type :
12031229 """Return t minus s for runtime type assertions.
12041230
@@ -1212,10 +1238,13 @@ def restrict_subtype_away(t: Type, s: Type, *, ignore_promotions: bool = False)
12121238 s = get_proper_type (s )
12131239
12141240 if isinstance (t , UnionType ):
1215- new_items = [restrict_subtype_away (item , s , ignore_promotions = ignore_promotions )
1216- for item in t .relevant_items ()
1217- if (isinstance (get_proper_type (item ), AnyType ) or
1218- not covers_at_runtime (item , s , ignore_promotions ))]
1241+ new_items = try_restrict_literal_union (t , s ) if isinstance (s , LiteralType ) else []
1242+ new_items = new_items or [
1243+ restrict_subtype_away (item , s , ignore_promotions = ignore_promotions )
1244+ for item in t .relevant_items ()
1245+ if (isinstance (get_proper_type (item ), AnyType ) or
1246+ not covers_at_runtime (item , s , ignore_promotions ))
1247+ ]
12191248 return UnionType .make_union (new_items )
12201249 elif covers_at_runtime (t , s , ignore_promotions ):
12211250 return UninhabitedType ()
@@ -1285,11 +1314,11 @@ def _is_proper_subtype(left: Type, right: Type, *,
12851314 right = get_proper_type (right )
12861315
12871316 if isinstance (right , UnionType ) and not isinstance (left , UnionType ):
1288- return any ([ is_proper_subtype (orig_left , item ,
1289- ignore_promotions = ignore_promotions ,
1290- erase_instances = erase_instances ,
1291- keep_erased_types = keep_erased_types )
1292- for item in right .items ] )
1317+ return any (is_proper_subtype (orig_left , item ,
1318+ ignore_promotions = ignore_promotions ,
1319+ erase_instances = erase_instances ,
1320+ keep_erased_types = keep_erased_types )
1321+ for item in right .items )
12931322 return left .accept (ProperSubtypeVisitor (orig_right ,
12941323 ignore_promotions = ignore_promotions ,
12951324 erase_instances = erase_instances ,
@@ -1495,7 +1524,7 @@ def visit_overloaded(self, left: Overloaded) -> bool:
14951524 return False
14961525
14971526 def visit_union_type (self , left : UnionType ) -> bool :
1498- return all ([ self ._is_proper_subtype (item , self .orig_right ) for item in left .items ] )
1527+ return all (self ._is_proper_subtype (item , self .orig_right ) for item in left .items )
14991528
15001529 def visit_partial_type (self , left : PartialType ) -> bool :
15011530 # TODO: What's the right thing to do here?
0 commit comments