Skip to content

Commit 21d6773

Browse files
authored
Narrow for type expr comparisons to type exprs (#20639)
I refactored the logic for type(x) narrowing in #20634 , but left this piece out since it has material semantic impact Fixes #11952 Fixes #20275
1 parent 6d136ed commit 21d6773

File tree

2 files changed

+55
-34
lines changed

2 files changed

+55
-34
lines changed

mypy/checker.py

Lines changed: 29 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6726,37 +6726,37 @@ def narrow_type_by_identity_equality(
67266726
else_map = {} # this is the big difference compared to the above
67276727
partial_type_maps.append((if_map, else_map))
67286728

6729-
exprs_in_type_calls = []
67306729
for i in expr_indices:
6731-
expr = operands[i]
6732-
if isinstance(expr, CallExpr) and is_type_call(expr):
6733-
exprs_in_type_calls.append(expr.args[0])
6734-
6735-
if exprs_in_type_calls:
6736-
for expr_in_type_call in exprs_in_type_calls:
6737-
for i in expr_indices:
6738-
expr = operands[i]
6739-
if isinstance(expr, CallExpr) and is_type_call(expr):
6740-
continue
6741-
6742-
current_type_range = self.get_isinstance_type(expr)
6743-
if_map, else_map = conditional_types_to_typemaps(
6744-
expr_in_type_call,
6745-
*self.conditional_types_with_intersection(
6746-
self.lookup_type(expr_in_type_call),
6747-
current_type_range,
6748-
expr_in_type_call,
6749-
),
6750-
)
6730+
type_expr = operands[i]
6731+
if (
6732+
isinstance(type_expr, CallExpr)
6733+
and refers_to_fullname(type_expr.callee, "builtins.type")
6734+
and len(type_expr.args) == 1
6735+
):
6736+
expr_in_type_expr = type_expr.args[0]
6737+
else:
6738+
continue
6739+
for j in expr_indices:
6740+
if i == j:
6741+
continue
6742+
expr = operands[j]
6743+
6744+
current_type_range = self.get_isinstance_type(expr)
6745+
if_map, else_map = conditional_types_to_typemaps(
6746+
expr_in_type_expr,
6747+
*self.conditional_types_with_intersection(
6748+
self.lookup_type(expr_in_type_expr), current_type_range, expr_in_type_expr
6749+
),
6750+
)
67516751

6752-
is_final = (
6753-
expr.node.is_final
6754-
if isinstance(expr, RefExpr) and isinstance(expr.node, TypeInfo)
6755-
else False
6756-
)
6757-
if not is_final:
6758-
else_map = {}
6759-
partial_type_maps.append((if_map, else_map))
6752+
is_final = (
6753+
expr.node.is_final
6754+
if isinstance(expr, RefExpr) and isinstance(expr.node, TypeInfo)
6755+
else False
6756+
)
6757+
if not is_final:
6758+
else_map = {}
6759+
partial_type_maps.append((if_map, else_map))
67606760

67616761
# We will not have duplicate entries in our type maps if we only have two operands,
67626762
# so we can skip running meets on the intersections
@@ -8562,11 +8562,6 @@ def has_custom_eq_checks(t: Type) -> bool:
85628562
)
85638563

85648564

8565-
def is_type_call(expr: CallExpr) -> bool:
8566-
"""Is expr a call to type with one argument?"""
8567-
return refers_to_fullname(expr.callee, "builtins.type") and len(expr.args) == 1
8568-
8569-
85708565
def convert_to_typetype(type_map: TypeMap) -> TypeMap:
85718566
converted_type_map: dict[Expression, Type] = {}
85728567
if type_map is None:

test-data/unit/check-narrowing.test

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3021,6 +3021,32 @@ if type(x) is not int:
30213021
else:
30223022
reveal_type(x) # N: Revealed type is "builtins.int"
30233023

3024+
[case testTypeNarrowingAgainstType]
3025+
# flags: --strict-equality --warn-unreachable
3026+
class A:
3027+
def foo(self, x: object) -> None:
3028+
reveal_type(self) # N: Revealed type is "__main__.A"
3029+
reveal_type(x) # N: Revealed type is "builtins.object"
3030+
if type(self) is type(x):
3031+
reveal_type(self) # N: Revealed type is "__main__.A"
3032+
reveal_type(x) # N: Revealed type is "__main__.A"
3033+
else:
3034+
reveal_type(self) # N: Revealed type is "__main__.A"
3035+
reveal_type(x) # N: Revealed type is "builtins.object"
3036+
if type(self) == type(x):
3037+
reveal_type(self) # N: Revealed type is "__main__.A"
3038+
reveal_type(x) # N: Revealed type is "__main__.A"
3039+
else:
3040+
reveal_type(self) # N: Revealed type is "__main__.A"
3041+
reveal_type(x) # N: Revealed type is "builtins.object"
3042+
3043+
class B:
3044+
y: int
3045+
3046+
def __eq__(self, other: object) -> bool:
3047+
return type(other) is type(self) and other.y == self.y
3048+
[builtins fixtures/primitives.pyi]
3049+
30243050
[case testNarrowInElseCaseIfFinal]
30253051
# flags: --strict-equality --warn-unreachable
30263052
from typing import final, Union

0 commit comments

Comments
 (0)