Skip to content

Commit b91f53c

Browse files
authored
Narrow types based on collection containment (#20602)
We've wanted to do this for a long time, but previous attempt was reverted due to issues like #17841 , #17864 , #17869 Following #20492 we should now be in a position to do this narrowing Fixes #3229 Fixes #20234 Fixes #18208 Fixes #16774
1 parent 3703273 commit b91f53c

File tree

6 files changed

+120
-18
lines changed

6 files changed

+120
-18
lines changed

mypy/checker.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6599,6 +6599,9 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa
65996599

66006600
partial_type_maps = []
66016601
for operator, expr_indices in simplified_operator_list:
6602+
if_map: TypeMap
6603+
else_map: TypeMap
6604+
66026605
if operator in {"is", "is not", "==", "!="}:
66036606
if_map, else_map = self.equality_type_narrowing_helper(
66046607
node,
@@ -6614,14 +6617,24 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa
66146617
item_type = operand_types[left_index]
66156618
iterable_type = operand_types[right_index]
66166619

6617-
if_map, else_map = {}, {}
6620+
if_map = {}
6621+
else_map = {}
66186622

66196623
if left_index in narrowable_operand_index_to_hash:
6620-
# We only try and narrow away 'None' for now
6621-
if is_overlapping_none(item_type):
6622-
collection_item_type = get_proper_type(builtin_item_type(iterable_type))
6624+
collection_item_type = get_proper_type(builtin_item_type(iterable_type))
6625+
if collection_item_type is not None:
6626+
if_map, else_map = self.narrow_type_by_equality(
6627+
"==",
6628+
operands=[operands[left_index], operands[right_index]],
6629+
operand_types=[item_type, collection_item_type],
6630+
expr_indices=[left_index, right_index],
6631+
narrowable_indices={0},
6632+
)
6633+
6634+
# We only try and narrow away 'None' for now
66236635
if (
6624-
collection_item_type is not None
6636+
if_map is not None
6637+
and is_overlapping_none(item_type)
66256638
and not is_overlapping_none(collection_item_type)
66266639
and not (
66276640
isinstance(collection_item_type, Instance)
@@ -6638,11 +6651,11 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa
66386651
expr = operands[right_index]
66396652
if if_type is None:
66406653
if_map = None
6641-
else:
6654+
elif if_map is not None:
66426655
if_map[expr] = if_type
66436656
if else_type is None:
66446657
else_map = None
6645-
else:
6658+
elif else_map is not None:
66466659
else_map[expr] = else_type
66476660

66486661
else:

mypy/constraints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def infer_constraints_for_callable(
123123
param_spec = callee.param_spec()
124124
param_spec_arg_types = []
125125
param_spec_arg_names = []
126-
param_spec_arg_kinds = []
126+
param_spec_arg_kinds: list[ArgKind] = []
127127

128128
incomplete_star_mapping = False
129129
for i, actuals in enumerate(formal_to_actual): # TODO: isn't this `enumerate(arg_types)`?

mypyc/irbuild/util.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from typing import Any, Final, Literal, TypedDict, cast
5+
from typing import Any, Final, Literal, TypedDict
66
from typing_extensions import NotRequired
77

88
from mypy.nodes import (
@@ -138,7 +138,6 @@ def get_mypyc_attrs(
138138

139139
def set_mypyc_attr(key: str, value: Any, line: int) -> None:
140140
if key in MYPYC_ATTRS:
141-
key = cast(MypycAttr, key)
142141
attrs[key] = value
143142
lines[key] = line
144143
else:

test-data/unit/check-narrowing.test

Lines changed: 89 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1373,13 +1373,13 @@ else:
13731373
reveal_type(val) # N: Revealed type is "None"
13741374

13751375
if val in (None,):
1376-
reveal_type(val) # N: Revealed type is "__main__.A | None"
1376+
reveal_type(val) # N: Revealed type is "None"
13771377
else:
1378-
reveal_type(val) # N: Revealed type is "__main__.A | None"
1378+
reveal_type(val) # N: Revealed type is "__main__.A"
13791379
if val not in (None,):
1380-
reveal_type(val) # N: Revealed type is "__main__.A | None"
1380+
reveal_type(val) # N: Revealed type is "__main__.A"
13811381
else:
1382-
reveal_type(val) # N: Revealed type is "__main__.A | None"
1382+
reveal_type(val) # N: Revealed type is "None"
13831383

13841384
class Hmm:
13851385
def __eq__(self, other) -> bool: ...
@@ -2315,9 +2315,8 @@ def f(x: str | int) -> None:
23152315
y = x
23162316

23172317
if x in ["x"]:
2318-
# TODO: we should fix this reveal https://github.com/python/mypy/issues/3229
2319-
reveal_type(x) # N: Revealed type is "builtins.str | builtins.int"
2320-
y = x # E: Incompatible types in assignment (expression has type "str | int", variable has type "str")
2318+
reveal_type(x) # N: Revealed type is "builtins.str"
2319+
y = x
23212320
z = x
23222321
z = y
23232322
[builtins fixtures/primitives.pyi]
@@ -2838,3 +2837,86 @@ class X:
28382837
reveal_type(self.y) # N: Revealed type is "builtins.list[builtins.str]"
28392838
self.y[0].does_not_exist # E: "str" has no attribute "does_not_exist"
28402839
[builtins fixtures/dict.pyi]
2840+
2841+
2842+
[case testTypeNarrowingStringInLiteralContainer]
2843+
# flags: --strict-equality --warn-unreachable
2844+
from typing import Literal
2845+
2846+
def narrow_tuple(x: str, t: tuple[Literal['a', 'b']]):
2847+
if x in t:
2848+
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']"
2849+
else:
2850+
reveal_type(x) # N: Revealed type is "builtins.str"
2851+
2852+
if x not in t:
2853+
reveal_type(x) # N: Revealed type is "builtins.str"
2854+
else:
2855+
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']"
2856+
2857+
def narrow_homo_tuple(x: str, t: tuple[Literal['a', 'b'], ...]):
2858+
if x in t:
2859+
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']"
2860+
else:
2861+
reveal_type(x) # N: Revealed type is "builtins.str"
2862+
2863+
def narrow_list(x: str, t: list[Literal['a', 'b']]):
2864+
if x in t:
2865+
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']"
2866+
else:
2867+
reveal_type(x) # N: Revealed type is "builtins.str"
2868+
2869+
def narrow_set(x: str, t: set[Literal['a', 'b']]):
2870+
if x in t:
2871+
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']"
2872+
else:
2873+
reveal_type(x) # N: Revealed type is "builtins.str"
2874+
[builtins fixtures/primitives.pyi]
2875+
2876+
2877+
[case testNarrowingLiteralInLiteralContainer]
2878+
# flags: --strict-equality --warn-unreachable
2879+
from typing import Literal
2880+
2881+
def narrow_tuple(x: Literal['c'], overlap: list[Literal['b', 'c']], no_overlap: list[Literal['a', 'b']]):
2882+
if x in overlap:
2883+
reveal_type(x) # N: Revealed type is "Literal['c']"
2884+
else:
2885+
reveal_type(x) # N: Revealed type is "Literal['c']"
2886+
2887+
if x in no_overlap:
2888+
reveal_type(x) # N: Revealed type is "Literal['c']"
2889+
else:
2890+
reveal_type(x) # N: Revealed type is "Literal['c']"
2891+
[builtins fixtures/tuple.pyi]
2892+
2893+
[case testTypeNarrowingUnionInContainer]
2894+
# flags: --strict-equality --warn-unreachable
2895+
from typing import Union, Literal
2896+
2897+
def f1(x: Union[str, float], t1: list[Literal['a', 'b']], t2: list[str]):
2898+
if x in t1:
2899+
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']"
2900+
else:
2901+
reveal_type(x) # N: Revealed type is "builtins.str | builtins.float"
2902+
2903+
if x in t2:
2904+
reveal_type(x) # N: Revealed type is "builtins.str"
2905+
else:
2906+
reveal_type(x) # N: Revealed type is "builtins.str | builtins.float"
2907+
[builtins fixtures/primitives.pyi]
2908+
2909+
[case testNarrowAnyWithEqualityOrContainment]
2910+
# https://github.com/python/mypy/issues/17841
2911+
from typing import Any
2912+
2913+
def f1(x: Any) -> None:
2914+
if x is not None and x not in ["x"]:
2915+
return
2916+
reveal_type(x) # N: Revealed type is "Any"
2917+
2918+
def f2(x: Any) -> None:
2919+
if x is not None and x != "x":
2920+
return
2921+
reveal_type(x) # N: Revealed type is "Any"
2922+
[builtins fixtures/tuple.pyi]

test-data/unit/fixtures/narrowing.pyi

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Builtins stub used in check-narrowing test cases.
2-
from typing import Generic, Sequence, Tuple, Type, TypeVar, Union
2+
from typing import Generic, Sequence, Tuple, Type, TypeVar, Union, Iterable
33

44

55
Tco = TypeVar('Tco', covariant=True)
@@ -15,6 +15,13 @@ class function: pass
1515
class ellipsis: pass
1616
class int: pass
1717
class str: pass
18+
class float: pass
1819
class dict(Generic[KT, VT]): pass
1920

2021
def isinstance(x: object, t: Union[Type[object], Tuple[Type[object], ...]]) -> bool: pass
22+
23+
class list(Sequence[Tco]):
24+
def __contains__(self, other: object) -> bool: pass
25+
class set(Iterable[Tco], Generic[Tco]):
26+
def __init__(self, iterable: Iterable[Tco] = ...) -> None: ...
27+
def __contains__(self, item: object) -> bool: pass

test-data/unit/fixtures/primitives.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ class dict(Mapping[T, V]):
6363
def __iter__(self) -> Iterator[T]: pass
6464
class set(Iterable[T]):
6565
def __iter__(self) -> Iterator[T]: pass
66+
def __contains__(self, o: object, /) -> bool: pass
6667
class frozenset(Iterable[T]):
6768
def __iter__(self) -> Iterator[T]: pass
6869
class function: pass

0 commit comments

Comments
 (0)