Skip to content

Commit 9ca3035

Browse files
authored
Fix strict equality check if operand item type has custom __eq__ (#14513)
Don't complain about comparing lists, variable-length tuples or sets if one of the operands has an item type with a custom `__eq__` method. Fix #14511.
1 parent 4de3f5d commit 9ca3035

File tree

4 files changed

+41
-18
lines changed

4 files changed

+41
-18
lines changed

mypy/checkexpr.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2988,20 +2988,14 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
29882988
# testCustomEqCheckStrictEquality for an example.
29892989
if not w.has_new_errors() and operator in ("==", "!="):
29902990
right_type = self.accept(right)
2991-
# We suppress the error if there is a custom __eq__() method on either
2992-
# side. User defined (or even standard library) classes can define this
2993-
# to return True for comparisons between non-overlapping types.
2994-
if not custom_special_method(
2995-
left_type, "__eq__"
2996-
) and not custom_special_method(right_type, "__eq__"):
2997-
# Also flag non-overlapping literals in situations like:
2998-
# x: Literal['a', 'b']
2999-
# if x == 'c':
3000-
# ...
3001-
left_type = try_getting_literal(left_type)
3002-
right_type = try_getting_literal(right_type)
3003-
if self.dangerous_comparison(left_type, right_type):
3004-
self.msg.dangerous_comparison(left_type, right_type, "equality", e)
2991+
# Also flag non-overlapping literals in situations like:
2992+
# x: Literal['a', 'b']
2993+
# if x == 'c':
2994+
# ...
2995+
left_type = try_getting_literal(left_type)
2996+
right_type = try_getting_literal(right_type)
2997+
if self.dangerous_comparison(left_type, right_type):
2998+
self.msg.dangerous_comparison(left_type, right_type, "equality", e)
30052999

30063000
elif operator == "is" or operator == "is not":
30073001
right_type = self.accept(right) # validate the right operand
@@ -3064,6 +3058,12 @@ def dangerous_comparison(
30643058

30653059
left, right = get_proper_types((left, right))
30663060

3061+
# We suppress the error if there is a custom __eq__() method on either
3062+
# side. User defined (or even standard library) classes can define this
3063+
# to return True for comparisons between non-overlapping types.
3064+
if custom_special_method(left, "__eq__") or custom_special_method(right, "__eq__"):
3065+
return False
3066+
30673067
if self.chk.binder.is_unreachable_warning_suppressed():
30683068
# We are inside a function that contains type variables with value restrictions in
30693069
# its signature. In this case we just suppress all strict-equality checks to avoid
@@ -3094,14 +3094,18 @@ def dangerous_comparison(
30943094
return False
30953095
if isinstance(left, Instance) and isinstance(right, Instance):
30963096
# Special case some builtin implementations of AbstractSet.
3097+
left_name = left.type.fullname
3098+
right_name = right.type.fullname
30973099
if (
3098-
left.type.fullname in OVERLAPPING_TYPES_ALLOWLIST
3099-
and right.type.fullname in OVERLAPPING_TYPES_ALLOWLIST
3100+
left_name in OVERLAPPING_TYPES_ALLOWLIST
3101+
and right_name in OVERLAPPING_TYPES_ALLOWLIST
31003102
):
31013103
abstract_set = self.chk.lookup_typeinfo("typing.AbstractSet")
31023104
left = map_instance_to_supertype(left, abstract_set)
31033105
right = map_instance_to_supertype(right, abstract_set)
3104-
return not is_overlapping_types(left.args[0], right.args[0])
3106+
return self.dangerous_comparison(left.args[0], right.args[0])
3107+
elif left_name in ("builtins.list", "builtins.tuple") and right_name == left_name:
3108+
return self.dangerous_comparison(left.args[0], right.args[0])
31053109
if isinstance(left, LiteralType) and isinstance(right, LiteralType):
31063110
if isinstance(left.value, bool) and isinstance(right.value, bool):
31073111
# Comparing different booleans is not dangerous.

test-data/unit/check-expressions.test

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1985,6 +1985,24 @@ class B:
19851985
A() == B() # E: Unsupported operand types for == ("A" and "B")
19861986
[builtins fixtures/bool.pyi]
19871987

1988+
[case testStrictEqualitySequenceAndCustomEq]
1989+
# flags: --strict-equality
1990+
from typing import Tuple
1991+
1992+
class C: pass
1993+
class D:
1994+
def __eq__(self, other): return True
1995+
1996+
a = [C()]
1997+
b = [D()]
1998+
a == b
1999+
b == a
2000+
t1: Tuple[C, ...]
2001+
t2: Tuple[D, ...]
2002+
t1 == t2
2003+
t2 == t1
2004+
[builtins fixtures/bool.pyi]
2005+
19882006
[case testCustomEqCheckStrictEqualityOKInstance]
19892007
# flags: --strict-equality
19902008
class A:

test-data/unit/fixtures/bool.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,5 @@ class float: pass
1616
class str: pass
1717
class unicode: pass
1818
class ellipsis: pass
19-
class list: pass
19+
class list(Generic[T]): pass
2020
class property: pass

test-data/unit/fixtures/set.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ T = TypeVar('T')
66

77
class object:
88
def __init__(self) -> None: pass
9+
def __eq__(self, other): pass
910

1011
class type: pass
1112
class tuple(Generic[T]): pass

0 commit comments

Comments
 (0)