diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index c2cf226ef210..8dea7d0e8551 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2988,20 +2988,14 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: # testCustomEqCheckStrictEquality for an example. if not w.has_new_errors() and operator in ("==", "!="): right_type = self.accept(right) - # We suppress the error if there is a custom __eq__() method on either - # side. User defined (or even standard library) classes can define this - # to return True for comparisons between non-overlapping types. - if not custom_special_method( - left_type, "__eq__" - ) and not custom_special_method(right_type, "__eq__"): - # Also flag non-overlapping literals in situations like: - # x: Literal['a', 'b'] - # if x == 'c': - # ... - left_type = try_getting_literal(left_type) - right_type = try_getting_literal(right_type) - if self.dangerous_comparison(left_type, right_type): - self.msg.dangerous_comparison(left_type, right_type, "equality", e) + # Also flag non-overlapping literals in situations like: + # x: Literal['a', 'b'] + # if x == 'c': + # ... + left_type = try_getting_literal(left_type) + right_type = try_getting_literal(right_type) + if self.dangerous_comparison(left_type, right_type): + self.msg.dangerous_comparison(left_type, right_type, "equality", e) elif operator == "is" or operator == "is not": right_type = self.accept(right) # validate the right operand @@ -3064,6 +3058,12 @@ def dangerous_comparison( left, right = get_proper_types((left, right)) + # We suppress the error if there is a custom __eq__() method on either + # side. User defined (or even standard library) classes can define this + # to return True for comparisons between non-overlapping types. + if custom_special_method(left, "__eq__") or custom_special_method(right, "__eq__"): + return False + if self.chk.binder.is_unreachable_warning_suppressed(): # We are inside a function that contains type variables with value restrictions in # its signature. In this case we just suppress all strict-equality checks to avoid @@ -3094,14 +3094,18 @@ def dangerous_comparison( return False if isinstance(left, Instance) and isinstance(right, Instance): # Special case some builtin implementations of AbstractSet. + left_name = left.type.fullname + right_name = right.type.fullname if ( - left.type.fullname in OVERLAPPING_TYPES_ALLOWLIST - and right.type.fullname in OVERLAPPING_TYPES_ALLOWLIST + left_name in OVERLAPPING_TYPES_ALLOWLIST + and right_name in OVERLAPPING_TYPES_ALLOWLIST ): abstract_set = self.chk.lookup_typeinfo("typing.AbstractSet") left = map_instance_to_supertype(left, abstract_set) right = map_instance_to_supertype(right, abstract_set) - return not is_overlapping_types(left.args[0], right.args[0]) + return self.dangerous_comparison(left.args[0], right.args[0]) + elif left_name in ("builtins.list", "builtins.tuple") and right_name == left_name: + return self.dangerous_comparison(left.args[0], right.args[0]) if isinstance(left, LiteralType) and isinstance(right, LiteralType): if isinstance(left.value, bool) and isinstance(right.value, bool): # Comparing different booleans is not dangerous. diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index 78ef78e9ad98..20ccbb17d5d5 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -1985,6 +1985,24 @@ class B: A() == B() # E: Unsupported operand types for == ("A" and "B") [builtins fixtures/bool.pyi] +[case testStrictEqualitySequenceAndCustomEq] +# flags: --strict-equality +from typing import Tuple + +class C: pass +class D: + def __eq__(self, other): return True + +a = [C()] +b = [D()] +a == b +b == a +t1: Tuple[C, ...] +t2: Tuple[D, ...] +t1 == t2 +t2 == t1 +[builtins fixtures/bool.pyi] + [case testCustomEqCheckStrictEqualityOKInstance] # flags: --strict-equality class A: diff --git a/test-data/unit/fixtures/bool.pyi b/test-data/unit/fixtures/bool.pyi index 245526d78907..c48efcbd7269 100644 --- a/test-data/unit/fixtures/bool.pyi +++ b/test-data/unit/fixtures/bool.pyi @@ -16,5 +16,5 @@ class float: pass class str: pass class unicode: pass class ellipsis: pass -class list: pass +class list(Generic[T]): pass class property: pass diff --git a/test-data/unit/fixtures/set.pyi b/test-data/unit/fixtures/set.pyi index 9852bbc9fcc6..d397d4f54af2 100644 --- a/test-data/unit/fixtures/set.pyi +++ b/test-data/unit/fixtures/set.pyi @@ -6,6 +6,7 @@ T = TypeVar('T') class object: def __init__(self) -> None: pass + def __eq__(self, other): pass class type: pass class tuple(Generic[T]): pass