Skip to content

Commit 87f1571

Browse files
committed
Make reachability code understand chained comparisons
Currently, our reachability code does not understand how to parse chained comparisons like `a == b == c`: the `find_isinstance_check` method only attempts to analyze comparisons that contain a single `==`, `is`, or `in` operator. This pull request generalizes that logic so we can support an arbitrary number of comparisons. Along the way, it also unifies the logic we have for handling `is` and `==` checks: the latter check is now just treated as a weaker variation of the former. (Expressions containing `==` may do arbitrary things if the underlying operands define custom `__eq__` methods.)

As a side-effect, this PR adds support for the following:

    x: Optional[str]
    if x is 'some-string':
        # Previously, the revealed type would be Union[str, None]
        # Now, the revealed type is just 'str'
        reveal_type(x)
    else:
        reveal_type(x)  # N: Revealed type is 'Union[builtins.str, None]'

We previously supported this narrowing logic only when doing equality checks (e.g. when doing `if x == 'some-string'`).

As a second side-effect, this PR adds support for the following:

    class Foo(Enum):
        A = 1
        B = 2

    y: Foo
    if y == Foo.A:
        reveal_type(y)  # N: Revealed type is 'Literal[Foo.A]'
    else:
        reveal_type(y)  # N: Revealed type is 'Literal[Foo.B]'

We previously supported this kind of narrowing only when doing identity checks (e.g. `if y is Foo.A`). To avoid any bad interactions with custom `__eq__` methods, we enable this narrowing check only if both operands do not define custom `__eq__` methods.
1 parent 5b70ff5 commit 87f1571

File tree

5 files changed

+298
-59
lines changed

5 files changed

+298
-59
lines changed

mypy/checker.py

+150-58
Original file line numberDiff line numberDiff line change
@@ -3536,67 +3536,61 @@ def find_isinstance_check(self, node: Expression
35363536
vartype = type_map[expr]
35373537
return self.conditional_callable_type_map(expr, vartype)
35383538
elif isinstance(node, ComparisonExpr):
3539-
operand_types = [coerce_to_literal(type_map[expr])
3540-
for expr in node.operands if expr in type_map]
3541-
3542-
is_not = node.operators == ['is not']
3543-
if (is_not or node.operators == ['is']) and len(operand_types) == len(node.operands):
3544-
if_vars = {} # type: TypeMap
3545-
else_vars = {} # type: TypeMap
3546-
3547-
for i, expr in enumerate(node.operands):
3548-
var_type = operand_types[i]
3549-
other_type = operand_types[1 - i]
3550-
3551-
if literal(expr) == LITERAL_TYPE and is_singleton_type(other_type):
3552-
# This should only be true at most once: there should be
3553-
# exactly two elements in node.operands and if the 'other type' is
3554-
# a singleton type, it by definition does not need to be narrowed:
3555-
# it already has the most precise type possible so does not need to
3556-
# be narrowed/included in the output map.
3557-
#
3558-
# TODO: Generalize this to handle the case where 'other_type' is
3559-
# a union of singleton types.
3539+
operand_types = []
3540+
for expr in node.operands:
3541+
if expr not in type_map:
3542+
return {}, {}
3543+
operand_types.append(coerce_to_literal(type_map[expr]))
3544+
3545+
type_maps = []
3546+
for i, (operator, left_expr, right_expr) in enumerate(node.pairwise()):
3547+
left_type = operand_types[i]
3548+
right_type = operand_types[i + 1]
3549+
3550+
if_map = {} # type: TypeMap
3551+
else_map = {} # type: TypeMap
3552+
if operator in {'in', 'not in'}:
3553+
right_item_type = builtin_item_type(right_type)
3554+
if right_item_type is None or is_optional(right_item_type):
3555+
continue
3556+
if (isinstance(right_item_type, Instance)
3557+
and right_item_type.type.fullname() == 'builtins.object'):
3558+
continue
3559+
3560+
if (is_optional(left_type) and literal(left_expr) == LITERAL_TYPE
3561+
and not is_literal_none(left_expr) and
3562+
is_overlapping_erased_types(left_type, right_item_type)):
3563+
if_map, else_map = {left_expr: remove_optional(left_type)}, {}
3564+
else:
3565+
continue
3566+
elif operator in {'==', '!='}:
3567+
if_map, else_map = self.narrow_given_equality(
3568+
left_expr, left_type, right_expr, right_type, assume_identity=False)
3569+
elif operator in {'is', 'is not'}:
3570+
if_map, else_map = self.narrow_given_equality(
3571+
left_expr, left_type, right_expr, right_type, assume_identity=True)
3572+
else:
3573+
continue
35603574

3561-
if isinstance(other_type, LiteralType) and other_type.is_enum_literal():
3562-
fallback_name = other_type.fallback.type.fullname()
3563-
var_type = try_expanding_enum_to_union(var_type, fallback_name)
3575+
if operator in {'not in', '!=', 'is not'}:
3576+
if_map, else_map = else_map, if_map
35643577

3565-
target_type = [TypeRange(other_type, is_upper_bound=False)]
3566-
if_vars, else_vars = conditional_type_map(expr, var_type, target_type)
3567-
break
3578+
type_maps.append((if_map, else_map))
35683579

3569-
if is_not:
3570-
if_vars, else_vars = else_vars, if_vars
3571-
return if_vars, else_vars
3572-
# Check for `x == y` where x is of type Optional[T] and y is of type T
3573-
# or a type that overlaps with T (or vice versa).
3574-
elif node.operators == ['==']:
3575-
first_type = type_map[node.operands[0]]
3576-
second_type = type_map[node.operands[1]]
3577-
if is_optional(first_type) != is_optional(second_type):
3578-
if is_optional(first_type):
3579-
optional_type, comp_type = first_type, second_type
3580-
optional_expr = node.operands[0]
3581-
else:
3582-
optional_type, comp_type = second_type, first_type
3583-
optional_expr = node.operands[1]
3584-
if is_overlapping_erased_types(optional_type, comp_type):
3585-
return {optional_expr: remove_optional(optional_type)}, {}
3586-
elif node.operators in [['in'], ['not in']]:
3587-
expr = node.operands[0]
3588-
left_type = type_map[expr]
3589-
right_type = builtin_item_type(type_map[node.operands[1]])
3590-
right_ok = right_type and (not is_optional(right_type) and
3591-
(not isinstance(right_type, Instance) or
3592-
right_type.type.fullname() != 'builtins.object'))
3593-
if (right_type and right_ok and is_optional(left_type) and
3594-
literal(expr) == LITERAL_TYPE and not is_literal_none(expr) and
3595-
is_overlapping_erased_types(left_type, right_type)):
3596-
if node.operators == ['in']:
3597-
return {expr: remove_optional(left_type)}, {}
3598-
if node.operators == ['not in']:
3599-
return {}, {expr: remove_optional(left_type)}
3580+
if len(type_maps) == 0:
3581+
return {}, {}
3582+
elif len(type_maps) == 1:
3583+
return type_maps[0]
3584+
else:
3585+
# A comparison like 'a == b == c is d' is the same thing as
3586+
# '(a == b) and (b == c) and (c is d)'. So after generating each
3587+
# individual comparison's typemaps, we "and" them together here.
3588+
# (Also see comments below where we handle the 'and' OpExpr.)
3589+
final_if_map, final_else_map = type_maps[0]
3590+
for if_map, else_map in type_maps[1:]:
3591+
final_if_map = and_conditional_maps(final_if_map, if_map)
3592+
final_else_map = or_conditional_maps(final_else_map, else_map)
3593+
return final_if_map, final_else_map
36003594
elif isinstance(node, RefExpr):
36013595
# Restrict the type of the variable to True-ish/False-ish in the if and else branches
36023596
# respectively
@@ -3630,6 +3624,78 @@ def find_isinstance_check(self, node: Expression
36303624
# Not a supported isinstance check
36313625
return {}, {}
36323626

3627+
def narrow_given_equality(self,
                          left_expr: Expression,
                          left_type: Type,
                          right_expr: Expression,
                          right_type: Type,
                          assume_identity: bool,
                          ) -> Tuple[TypeMap, TypeMap]:
    """Derive (if-branch, else-branch) TypeMaps from 'left_expr' being equal to 'right_expr'.

    The returned pair describes how the type of one operand can be refined when
    the comparison holds (first map) and when it does not (second map). Both
    maps are empty when nothing can be learned from the comparison.

    For more details about what TypeMaps are, see the docstring in
    find_isinstance_check.

    Pass 'assume_identity=True' when the comparison is an identity check
    (left_expr is right_expr) rather than a plain equality check
    (left_expr == right_expr). Identity checks are not overridable, so more
    can be inferred in that case.
    """

    # To keep things simple we refine at most one operand: the left one is
    # tried first, then the right one, and the first successful refinement is
    # returned immediately.
    candidates = (
        (left_expr, left_type, right_type),
        (right_expr, right_type, left_type),
    )

    for subject_expr, subject_type, peer_type in candidates:
        # Only expressions whose type we are able to refine are worth considering.
        if literal(subject_expr) != LITERAL_TYPE:
            continue

        # Case 1: the peer has a singleton type (exactly one value inhabits it),
        # so try to narrow the subject down to that single type.
        if is_singleton_type(peer_type):
            if isinstance(peer_type, LiteralType) and peer_type.is_enum_literal():
                if not assume_identity:
                    # '==' and '!=' can be hijacked by a custom `__eq__` or
                    # `__ne__` on either operand, so stay conservative: give up
                    # on this operand unless BOTH the subject type and the
                    # enum's fallback type use the default equality checks.
                    both_default = (uses_default_equality_checks(subject_type)
                                    and uses_default_equality_checks(peer_type.fallback))
                    if not both_default:
                        continue
                subject_type = try_expanding_enum_to_union(
                    subject_type, peer_type.fallback.type.fullname())

            return conditional_type_map(
                subject_expr, subject_type, [TypeRange(peer_type, is_upper_bound=False)])

        # Case 2: subject_type=Union[A, None] compared against peer_type=A
        # narrows the subject to just 'A' in the if-branch.
        #
        # Note: strictly speaking this is unsafe. Dropping 'None' is unsound
        # if, for example, A defines an '__eq__' that always returns 'True'.
        #
        # The check is kept partly for backwards compatibility and partly
        # because such degenerate '__eq__' implementations are probably rare
        # enough for this to be fine in practice.
        #
        # This could probably be generalized to strip away *any* singleton
        # type, at the price of a bit more unsoundness.
        if (is_optional(subject_type) and not is_optional(peer_type)
                and is_overlapping_erased_types(subject_type, peer_type)):
            return {subject_expr: remove_optional(subject_type)}, {}

    return {}, {}
3698+
36333699
#
36343700
# Helpers
36353701
#
@@ -4505,6 +4571,32 @@ def is_private(node_name: str) -> bool:
45054571
return node_name.startswith('__') and not node_name.endswith('__')
45064572

45074573

4574+
def uses_default_equality_checks(typ: Type) -> bool:
    """Report whether 'typ' is known to rely on object's own __eq__ and __ne__.

    The answer is True only when both the '__eq__' and the '__ne__' symbol
    found for the type are the ones defined in 'builtins.object'. Callers can
    use that guarantee to make more aggressive inferences when analyzing
    things like equality checks.

    A union qualifies only if every one of its items does, and a literal type
    is judged by its fallback. Whenever the answer is uncertain, this function
    conservatively returns False.
    """
    if isinstance(typ, UnionType):
        return all(uses_default_equality_checks(item) for item in typ.items)

    # TODO: Generalize this so it'll handle other types with fallbacks
    if isinstance(typ, LiteralType):
        typ = typ.fallback

    # Anything that is not a plain instance is treated as "unknown".
    if not isinstance(typ, Instance):
        return False

    info = typ.type
    eq_sym = info.get('__eq__')
    ne_sym = info.get('__ne__')
    if eq_sym is None or ne_sym is None:
        return False

    eq_is_default = eq_sym.fullname == 'builtins.object.__eq__'
    return eq_is_default and ne_sym.fullname == 'builtins.object.__ne__'
4598+
4599+
45084600
def is_singleton_type(typ: Type) -> bool:
45094601
"""Returns 'true' if this type is a "singleton type" -- if there exists
45104602
exactly only one runtime value associated with this type.

mypy/nodes.py

+8
Original file line numberDiff line numberDiff line change
@@ -1718,13 +1718,21 @@ class ComparisonExpr(Expression):
17181718

17191719
def __init__(self, operators: List[str], operands: List[Expression]) -> None:
    """Build a comparison node from its operator strings and operand expressions.

    Each operator sits between two neighbouring operands, so there must be
    exactly one more operand than there are operators.
    """
    super().__init__()
    assert len(operands) == len(operators) + 1
    self.operators = operators
    self.operands = operands
    self.method_types = []
17241725

17251726
def accept(self, visitor: ExpressionVisitor[T]) -> T:
    """Hand this node to the visitor's comparison-expression hook and return its result."""
    outcome = visitor.visit_comparison_expr(self)
    return outcome
17271728

1729+
def pairwise(self) -> Iterator[Tuple[str, Expression, Expression]]:
    """Yield an (operator, left operand, right operand) triple for each comparison.

    For the comparison expr "a < b is c == d" the generated sequence is
        ("<", a, b), ("is", b, c), ("==", c, d)
    """
    for index, operator in enumerate(self.operators):
        # The operator at position 'index' sits between operands 'index' and 'index + 1'.
        left, right = self.operands[index], self.operands[index + 1]
        yield (operator, left, right)
1735+
17281736

17291737
class SliceExpr(Expression):
17301738
"""Slice expression (e.g. 'x:y', 'x:', '::2' or ':').

test-data/unit/check-enum.test

+127-1
Original file line numberDiff line numberDiff line change
@@ -611,7 +611,7 @@ main:2: note: Revealed type is 'builtins.int'
611611
[out2]
612612
main:2: note: Revealed type is 'builtins.str'
613613

614-
[case testEnumReachabilityChecksBasic]
614+
[case testEnumReachabilityChecksBasicIdentity]
615615
from enum import Enum
616616
from typing_extensions import Literal
617617

@@ -659,6 +659,54 @@ else:
659659
reveal_type(y) # No output here: this branch is unreachable
660660
[builtins fixtures/bool.pyi]
661661

662+
[case testEnumReachabilityChecksBasicEquality]
663+
from enum import Enum
664+
from typing_extensions import Literal
665+
666+
class Foo(Enum):
667+
A = 1
668+
B = 2
669+
C = 3
670+
671+
x: Literal[Foo.A, Foo.B, Foo.C]
672+
if x == Foo.A:
673+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
674+
elif x == Foo.B:
675+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'
676+
elif x == Foo.C:
677+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]'
678+
else:
679+
reveal_type(x) # No output here: this branch is unreachable
680+
681+
if Foo.A == x:
682+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
683+
elif Foo.B == x:
684+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'
685+
elif Foo.C == x:
686+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]'
687+
else:
688+
reveal_type(x) # No output here: this branch is unreachable
689+
690+
y: Foo
691+
if y == Foo.A:
692+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
693+
elif y == Foo.B:
694+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]'
695+
elif y == Foo.C:
696+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]'
697+
else:
698+
reveal_type(y) # No output here: this branch is unreachable
699+
700+
if Foo.A == y:
701+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
702+
elif Foo.B == y:
703+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]'
704+
elif Foo.C == y:
705+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]'
706+
else:
707+
reveal_type(y) # No output here: this branch is unreachable
708+
[builtins fixtures/bool.pyi]
709+
662710
[case testEnumReachabilityChecksIndirect]
663711
from enum import Enum
664712
from typing_extensions import Literal, Final
@@ -854,3 +902,81 @@ def process(response: Union[str, Reason] = '') -> str:
854902
return 'PROCESSED: ' + response
855903

856904
[builtins fixtures/primitives.pyi]
905+
906+
[case testEnumReachabilityDisabledGivenCustomEquality]
907+
from typing import Union
908+
from enum import Enum
909+
910+
class Parent(Enum):
911+
def __ne__(self, other: object) -> bool: return True
912+
913+
class Foo(Enum):
914+
A = 1
915+
B = 2
916+
def __eq__(self, other: object) -> bool: return True
917+
918+
class Bar(Parent):
919+
A = 1
920+
B = 2
921+
922+
class Ok(Enum):
923+
A = 1
924+
B = 2
925+
926+
x: Foo
927+
if x is Foo.A:
928+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
929+
else:
930+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'
931+
932+
if x == Foo.A:
933+
reveal_type(x) # N: Revealed type is '__main__.Foo'
934+
else:
935+
reveal_type(x) # N: Revealed type is '__main__.Foo'
936+
937+
y: Bar
938+
if y is Bar.A:
939+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Bar.A]'
940+
else:
941+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Bar.B]'
942+
943+
if y == Bar.A:
944+
reveal_type(y) # N: Revealed type is '__main__.Bar'
945+
else:
946+
reveal_type(y) # N: Revealed type is '__main__.Bar'
947+
948+
z1: Union[Bar, Ok]
949+
if z1 is Ok.A:
950+
reveal_type(z1) # N: Revealed type is 'Literal[__main__.Ok.A]'
951+
else:
952+
reveal_type(z1) # N: Revealed type is 'Union[__main__.Bar, Literal[__main__.Ok.B]]'
953+
954+
z2: Union[Bar, Ok]
955+
if z2 == Ok.A:
956+
reveal_type(z2) # N: Revealed type is 'Union[__main__.Bar, __main__.Ok]'
957+
else:
958+
reveal_type(z2) # N: Revealed type is 'Union[__main__.Bar, __main__.Ok]'
959+
[builtins fixtures/primitives.pyi]
960+
961+
[case testEnumReachabilityWithChaining]
962+
from enum import Enum
963+
class Foo(Enum):
964+
A = 1
965+
B = 2
966+
967+
x: Foo
968+
y: Foo
969+
if x is Foo.A is y:
970+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
971+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
972+
else:
973+
reveal_type(x) # N: Revealed type is '__main__.Foo'
974+
reveal_type(y) # N: Revealed type is '__main__.Foo'
975+
976+
if x == Foo.A == y:
977+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
978+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
979+
else:
980+
reveal_type(x) # N: Revealed type is '__main__.Foo'
981+
reveal_type(y) # N: Revealed type is '__main__.Foo'
982+
[builtins fixtures/primitives.pyi]

0 commit comments

Comments
 (0)