Skip to content

Prevent crashing when match arms use name of existing callable #18449

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5402,17 +5402,21 @@ def _get_recursive_sub_patterns_map(

return sub_patterns_map

def infer_variable_types_from_type_maps(self, type_maps: list[TypeMap]) -> dict[Var, Type]:
all_captures: dict[Var, list[tuple[NameExpr, Type]]] = defaultdict(list)
def infer_variable_types_from_type_maps(
self, type_maps: list[TypeMap]
) -> dict[SymbolNode, Type]:
# Type maps may contain variables inherited from previous code which are not
# necessary `Var`s (e.g. a function defined earlier with the same name).
all_captures: dict[SymbolNode, list[tuple[NameExpr, Type]]] = defaultdict(list)
for tm in type_maps:
if tm is not None:
for expr, typ in tm.items():
if isinstance(expr, NameExpr):
node = expr.node
assert isinstance(node, Var)
assert node is not None
all_captures[node].append((expr, typ))

inferred_types: dict[Var, Type] = {}
inferred_types: dict[SymbolNode, Type] = {}
for var, captures in all_captures.items():
already_exists = False
types: list[Type] = []
Expand All @@ -5436,16 +5440,19 @@ def infer_variable_types_from_type_maps(self, type_maps: list[TypeMap]) -> dict[
new_type = UnionType.make_union(types)
# Infer the union type at the first occurrence
first_occurrence, _ = captures[0]
# If it didn't exist before ``match``, it's a Var.
assert isinstance(var, Var)
inferred_types[var] = new_type
self.infer_variable_type(var, first_occurrence, new_type, first_occurrence)
return inferred_types

def remove_capture_conflicts(self, type_map: TypeMap, inferred_types: dict[Var, Type]) -> None:
def remove_capture_conflicts(
self, type_map: TypeMap, inferred_types: dict[SymbolNode, Type]
) -> None:
if type_map:
for expr, typ in list(type_map.items()):
if isinstance(expr, NameExpr):
node = expr.node
assert isinstance(node, Var)
if node not in inferred_types or not is_subtype(typ, inferred_types[node]):
del type_map[expr]

Expand Down
51 changes: 51 additions & 0 deletions test-data/unit/check-python310.test
Original file line number Diff line number Diff line change
Expand Up @@ -2471,3 +2471,54 @@ def nested_in_dict(d: dict[str, Any]) -> int:
return 0

[builtins fixtures/dict.pyi]

[case testMatchRebindsOuterFunctionName]
# flags: --warn-unreachable
from typing_extensions import Literal

def x() -> tuple[Literal["test"]]: ...

match x():
case (x,) if x == "test": # E: Incompatible types in capture pattern (pattern captures type "Literal['test']", variable has type "Callable[[], Tuple[Literal['test']]]")
reveal_type(x) # N: Revealed type is "def () -> Tuple[Literal['test']]"
case foo:
foo

[builtins fixtures/dict.pyi]

[case testMatchRebindsInnerFunctionName]
# flags: --warn-unreachable
class Some:
value: int | str
__match_args__ = ("value",)

def fn1(x: Some | int | str) -> None:
match x:
case int():
def value():
return 1
reveal_type(value) # N: Revealed type is "def () -> Any"
case str():
def value():
return 1
reveal_type(value) # N: Revealed type is "def () -> Any"
case Some(value): # E: Incompatible types in capture pattern (pattern captures type "Union[int, str]", variable has type "Callable[[], Any]")
pass

def fn2(x: Some | int | str) -> None:
match x:
case int():
def value() -> str:
return ""
reveal_type(value) # N: Revealed type is "def () -> builtins.str"
case str():
def value() -> int: # E: All conditional function variants must have identical signatures \
# N: Original: \
# N: def value() -> str \
# N: Redefinition: \
# N: def value() -> int
return 1
reveal_type(value) # N: Revealed type is "def () -> builtins.str"
case Some(value): # E: Incompatible types in capture pattern (pattern captures type "Union[int, str]", variable has type "Callable[[], str]")
pass
[builtins fixtures/dict.pyi]
Loading