Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 69 additions & 60 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6625,11 +6625,6 @@ def narrow_type_by_identity_equality(
the output TypeMaps.

"""
# should_narrow_by_identity_equality:
# If operator is "==" or "!=", we cannot narrow if we detect the presence of a user defined
# custom __eq__ or __ne__ method
should_narrow_by_identity_equality: bool

# is_target_for_value_narrowing:
# If the operator returns True when compared to this target, do we narrow in else branch?
# E.g. if operator is "==", then:
Expand All @@ -6646,7 +6641,7 @@ def narrow_type_by_identity_equality(
if operator in {"is", "is not"}:
is_target_for_value_narrowing = is_singleton_identity_type
should_coerce_literals = True
should_narrow_by_identity_equality = True
custom_eq_indices = set()
enum_comparison_is_ambiguous = False

elif operator in {"==", "!="}:
Expand All @@ -6659,19 +6654,11 @@ def narrow_type_by_identity_equality(
should_coerce_literals = True
break

expr_types = [operand_types[i] for i in expr_indices]
should_narrow_by_identity_equality = not any(map(has_custom_eq_checks, expr_types))
custom_eq_indices = {i for i in expr_indices if has_custom_eq_checks(operand_types[i])}
enum_comparison_is_ambiguous = True
else:
raise AssertionError

if not should_narrow_by_identity_equality:
# This is a bit of a legacy code path that might be a little unsound since it ignores
# custom __eq__. We should see if we can get rid of it in favour of `return {}, {}`
return self.refine_away_none_in_comparison(
operands, operand_types, expr_indices, narrowable_indices
)

value_targets = []
type_targets = []
for i in expr_indices:
Expand All @@ -6683,6 +6670,10 @@ def narrow_type_by_identity_equality(
# `x` to `Literal[Foo.A]` iff `Foo` has exactly one member.
# See testMatchEnumSingleChoice
expr_type = coerce_to_literal(expr_type)
if i in custom_eq_indices:
# We can't use types with custom __eq__ as targets for narrowing
# E.g. if (x: int | None) == (y: CustomEq | None), we cannot narrow x to None
continue
if is_target_for_value_narrowing(get_proper_type(expr_type)):
value_targets.append((i, TypeRange(expr_type, is_upper_bound=False)))
else:
Expand All @@ -6694,7 +6685,11 @@ def narrow_type_by_identity_equality(
for i in expr_indices:
if i not in narrowable_indices:
continue
expr_type = coerce_to_literal(operand_types[i])
if i in custom_eq_indices:
# Handled later
continue
expr_type = operand_types[i]
expr_type = coerce_to_literal(expr_type)
Comment on lines +6691 to +6692
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Huh, why did these lines get added? Maybe they're right (not sure what it's doing...) but seems weird to add code to this existing codepath?

Copy link
Collaborator Author

@hauntsaninja hauntsaninja Jan 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no semantic change in the expr_type stuff, the important change here is if i in custom_eq_indices: continue

As for why I have a stray line: managing a stack of like twenty commits is a little fiddly

expr_type = try_expanding_sum_type_to_union(expr_type, None)
expr_enum_keys = ambiguous_enum_equality_keys(expr_type)
for j, target in value_targets:
Expand All @@ -6715,6 +6710,9 @@ def narrow_type_by_identity_equality(
for i in expr_indices:
if i not in narrowable_indices:
continue
if i in custom_eq_indices:
# Handled later
continue
expr_type = operand_types[i]
for j, target in type_targets:
if i == j:
Expand All @@ -6723,9 +6721,63 @@ def narrow_type_by_identity_equality(
operands[i], *conditional_types(expr_type, [target])
)
if if_map:
else_map = {} # this is the big difference compared to the above
# For type_targets, we cannot narrow in the negative case
# e.g. if (x: str | None) != (y: str), we cannot narrow x to None
else_map = {}
Comment on lines +6724 to +6726
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part of the diff also seems unrelated?

partial_type_maps.append((if_map, else_map))

for i in custom_eq_indices:
Copy link
Collaborator Author

@hauntsaninja hauntsaninja Jan 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a later PR that condenses this a little bit / adds more comments

Maybe worth looking at the version on my dev branch: https://github.com/hauntsaninja/mypy/pull/5/files#diff-f96a2d6138bc6cdf2a07c4d37f6071cc25c1631afc107e277a28d5b59fc0ef04R6699

if i not in narrowable_indices:
continue
union_expr_type = get_proper_type(operand_types[i])
if not isinstance(union_expr_type, UnionType):
expr_type = operand_types[i]
for j, target in value_targets:
_if_map, else_map = conditional_types_to_typemaps(
operands[i], *conditional_types(expr_type, [target])
)
if else_map:
partial_type_maps.append(({}, else_map))
continue

or_if_maps: list[TypeMap] = []
or_else_maps: list[TypeMap] = []
for expr_type in union_expr_type.items:
if has_custom_eq_checks(expr_type):
or_if_maps.append({operands[i]: expr_type})
Copy link
Collaborator

@A5rocks A5rocks Jan 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

-snip-

Nevermind, I see what this does.


for j in expr_indices:
if j in custom_eq_indices:
continue
target_type = operand_types[j]
if should_coerce_literals:
target_type = coerce_to_literal(target_type)
target = TypeRange(target_type, is_upper_bound=False)
is_value_target = is_target_for_value_narrowing(get_proper_type(target_type))

if is_value_target:
expr_type = coerce_to_literal(expr_type)
expr_type = try_expanding_sum_type_to_union(expr_type, None)
if_map, else_map = conditional_types_to_typemaps(
operands[i], *conditional_types(expr_type, [target], default=expr_type)
)
or_if_maps.append(if_map)
if is_value_target:
or_else_maps.append(else_map)

final_if_map: TypeMap = {}
final_else_map: TypeMap = {}
if or_if_maps:
final_if_map = or_if_maps[0]
for if_map in or_if_maps[1:]:
final_if_map = or_conditional_maps(final_if_map, if_map)
if or_else_maps:
final_else_map = or_else_maps[0]
for else_map in or_else_maps[1:]:
final_else_map = or_conditional_maps(final_else_map, else_map)

partial_type_maps.append((final_if_map, final_else_map))

for i in expr_indices:
type_expr = operands[i]
if (
Expand Down Expand Up @@ -6943,49 +6995,6 @@ def _propagate_walrus_assignments(
return parent_expr
return expr

def refine_away_none_in_comparison(
self,
operands: list[Expression],
operand_types: list[Type],
chain_indices: list[int],
narrowable_operand_indices: AbstractSet[int],
) -> tuple[TypeMap, TypeMap]:
"""Produces conditional type maps refining away None in an identity/equality chain.

For more details about what the different arguments mean, see the
docstring of 'narrow_type_by_identity_equality' up above.
"""

non_optional_types = []
for i in chain_indices:
typ = operand_types[i]
if not is_overlapping_none(typ):
non_optional_types.append(typ)

if_map, else_map = {}, {}

if not non_optional_types or (len(non_optional_types) != len(chain_indices)):

# Narrow e.g. `Optional[A] == "x"` or `Optional[A] is "x"` to `A` (which may be
# convenient but is strictly not type-safe):
for i in narrowable_operand_indices:
expr_type = operand_types[i]
if not is_overlapping_none(expr_type):
continue
if any(is_overlapping_erased_types(expr_type, t) for t in non_optional_types):
if_map[operands[i]] = remove_optional(expr_type)

# Narrow e.g. `Optional[A] != None` to `A` (which is stricter than the above step and
# so type-safe but less convenient, because e.g. `Optional[A] == None` still results
# in `Optional[A]`):
if any(isinstance(get_proper_type(ot), NoneType) for ot in operand_types):
for i in narrowable_operand_indices:
expr_type = operand_types[i]
if is_overlapping_none(expr_type):
else_map[operands[i]] = remove_optional(expr_type)

return if_map, else_map

def is_len_of_tuple(self, expr: Expression) -> bool:
"""Is this expression a `len(x)` call where x is a tuple or union of tuples?"""
if not isinstance(expr, CallExpr):
Expand Down
115 changes: 93 additions & 22 deletions test-data/unit/check-narrowing.test
Original file line number Diff line number Diff line change
Expand Up @@ -823,9 +823,8 @@ def bar(x: Union[SingletonFoo, Foo], y: SingletonFoo) -> None:
reveal_type(x) # N: Revealed type is "Literal[__main__.SingletonFoo.A]"
[builtins fixtures/primitives.pyi]

[case testNarrowingEqualityDisabledForCustomEquality]
[case testNarrowingEqualityCustomEqualityDisabled]
from typing import Literal, Union
from enum import Enum

class Custom:
def __eq__(self, other: object) -> bool: return True
Expand All @@ -834,15 +833,20 @@ class Default: pass

x1: Union[Custom, Literal[1], Literal[2]]
if x1 == 1:
reveal_type(x1) # N: Revealed type is "__main__.Custom | Literal[1] | Literal[2]"
reveal_type(x1) # N: Revealed type is "__main__.Custom | Literal[1]"
else:
reveal_type(x1) # N: Revealed type is "__main__.Custom | Literal[1] | Literal[2]"
reveal_type(x1) # N: Revealed type is "__main__.Custom | Literal[2]"

x2: Union[Default, Literal[1], Literal[2]]
if x2 == 1:
reveal_type(x2) # N: Revealed type is "Literal[1]"
else:
reveal_type(x2) # N: Revealed type is "__main__.Default | Literal[2]"
[builtins fixtures/primitives.pyi]

[case testNarrowingEqualityCustomEqualityEnum]
from typing import Literal, Union
from enum import Enum

class CustomEnum(Enum):
A = 1
Expand All @@ -855,7 +859,7 @@ key: Literal[CustomEnum.A]
if x3 == key:
reveal_type(x3) # N: Revealed type is "__main__.CustomEnum"
else:
reveal_type(x3) # N: Revealed type is "__main__.CustomEnum"
reveal_type(x3) # N: Revealed type is "Literal[__main__.CustomEnum.B]"

# For comparison, this narrows since we bypass __eq__
if x3 is key:
Expand All @@ -864,7 +868,7 @@ else:
reveal_type(x3) # N: Revealed type is "Literal[__main__.CustomEnum.B]"
[builtins fixtures/primitives.pyi]

[case testNarrowingEqualityDisabledForCustomEqualityChain]
[case testNarrowingEqualityCustomEqualityDisabledChainedComparison]
# flags: --strict-equality --warn-unreachable
from typing import Literal, Union

Expand All @@ -877,21 +881,13 @@ x: Literal[1, 2, None]
y: Custom
z: Default

# We could maybe try doing something clever, but for simplicity we
# treat the whole chain as contaminated and mostly disable narrowing.
#
# The only exception is that we do at least strip away the 'None'. We
# (perhaps optimistically) assume no custom class would be pathological
# enough to declare itself to be equal to None and so permit this narrowing,
# since it's often convenient in practice.
if 1 == x == y:
reveal_type(x) # N: Revealed type is "Literal[1] | Literal[2]"
reveal_type(x) # N: Revealed type is "Literal[1]"
reveal_type(y) # N: Revealed type is "__main__.Custom"
else:
reveal_type(x) # N: Revealed type is "Literal[1] | Literal[2] | None"
reveal_type(x) # N: Revealed type is "Literal[2] | None"
reveal_type(y) # N: Revealed type is "__main__.Custom"

# No contamination here
if 1 == x == z: # E: Non-overlapping equality check (left operand type: "Literal[1, 2] | None", right operand type: "Default")
reveal_type(x) # E: Statement is unreachable
reveal_type(z)
Expand All @@ -900,6 +896,75 @@ else:
reveal_type(z) # N: Revealed type is "__main__.Default"
[builtins fixtures/primitives.pyi]

[case testNarrowingCustomEqualityLiteralElseBranch]
# flags: --strict-equality --warn-unreachable
from __future__ import annotations
from typing import Literal

class Custom:
def __eq__(self, other: object) -> bool:
raise

def f(v: Custom | Literal["text"]) -> Custom | None:
if v == "text":
reveal_type(v) # N: Revealed type is "__main__.Custom | Literal['text']"
return None
else:
reveal_type(v) # N: Revealed type is "__main__.Custom"
return v

def g(v: Custom | Literal["text"]) -> Custom | None:
if v != "text":
reveal_type(v) # N: Revealed type is "__main__.Custom"
return None
else:
reveal_type(v) # N: Revealed type is "__main__.Custom | Literal['text']"
return v # E: Incompatible return value type (got "Custom | Literal['text']", expected "Custom | None")
[builtins fixtures/primitives.pyi]

[case testNarrowingCustomEqualityUnion]
# flags: --strict-equality --warn-unreachable
from __future__ import annotations
from typing import Any

def realistic(x: dict[str, Any]):
val = x.get("hey")
if val == 12:
reveal_type(val) # N: Revealed type is "Any | Literal[12]?"

def f1(x: Any | None):
if x == 12:
reveal_type(x) # N: Revealed type is "Any | Literal[12]?"

class Custom:
def __eq__(self, other: object) -> bool:
raise

def f2(x: Custom | None):
if x == 12:
reveal_type(x) # N: Revealed type is "__main__.Custom"
else:
reveal_type(x) # N: Revealed type is "__main__.Custom | None"
[builtins fixtures/dict.pyi]

[case testNarrowingCustomEqualityUnionTypeTarget]
# flags: --strict-equality --warn-unreachable
from __future__ import annotations
from typing import Any

class Custom:
def __eq__(self, other: object) -> bool:
raise

def f(x: Custom | None, y: int | None):
if x == y:
reveal_type(x) # N: Revealed type is "__main__.Custom | None"
reveal_type(y) # N: Revealed type is "builtins.int | None"
else:
reveal_type(x) # N: Revealed type is "__main__.Custom | None"
reveal_type(y) # N: Revealed type is "builtins.int | None"
[builtins fixtures/primitives.pyi]

[case testNarrowingUnreachableCases]
# flags: --strict-equality --warn-unreachable
from typing import Literal, Union
Expand Down Expand Up @@ -2157,7 +2222,7 @@ def f3(x: object) -> None:

def f4(x: int | Any) -> None:
if x == IE.X:
reveal_type(x) # N: Revealed type is "builtins.int | Any"
reveal_type(x) # N: Revealed type is "Literal[__main__.IE.X] | Any"
else:
reveal_type(x) # N: Revealed type is "builtins.int | Any"

Expand Down Expand Up @@ -2232,9 +2297,9 @@ def f5(x: E | str | int) -> None:

def f6(x: IE | Any) -> None:
if x == IE.X:
reveal_type(x) # N: Revealed type is "__main__.IE | Any"
reveal_type(x) # N: Revealed type is "Literal[__main__.IE.X] | Any"
else:
reveal_type(x) # N: Revealed type is "__main__.IE | Any"
reveal_type(x) # N: Revealed type is "Literal[__main__.IE.Y] | Any"

def f7(x: IE | None) -> None:
if x == IE.X:
Expand Down Expand Up @@ -2321,7 +2386,7 @@ def f(x: str | int) -> None:
z = y
[builtins fixtures/primitives.pyi]

[case testConsistentNarrowingInWithCustomEq]
[case testConsistentNarrowingEqAndInWithCustomEq]
# flags: --python-version 3.10

# https://github.com/python/mypy/issues/17864
Expand All @@ -2339,11 +2404,17 @@ class C:
class D(C):
pass

def f(x: C) -> None:
def f1(x: C) -> None:
if x in [D(5)]:
reveal_type(x) # D # N: Revealed type is "__main__.C"

f(C(5))
f1(C(5))

def f2(x: C) -> None:
if x == D(5):
reveal_type(x) # D # N: Revealed type is "__main__.C"

f2(C(5))
[builtins fixtures/primitives.pyi]

[case testNarrowingTypeVarNone]
Expand Down