Skip to content

Commit 17271e5

Browse files
Fix narrowing on match with function subject (#16503)
Fixes #12998 mypy can't narrow match statements with functions subjects because the callexpr node is not a literal node. This adds a 'dummy' literal node that the match statement visitor can use to do the type narrowing. The python grammar describes the the match subject as a named expression so this uses that nameexpr node as it's literal. --------- Co-authored-by: hauntsaninja <[email protected]>
1 parent bfbac5e commit 17271e5

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

mypy/checker.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5053,6 +5053,19 @@ def visit_continue_stmt(self, s: ContinueStmt) -> None:
50535053
return
50545054

50555055
def visit_match_stmt(self, s: MatchStmt) -> None:
5056+
named_subject: Expression
5057+
if isinstance(s.subject, CallExpr):
5058+
# Create a dummy subject expression to handle cases where a match statement's subject
5059+
# is not a literal value. This lets us correctly narrow types and check exhaustivity
5060+
# This is hack!
5061+
id = s.subject.callee.fullname if isinstance(s.subject.callee, RefExpr) else ""
5062+
name = "dummy-match-" + id
5063+
v = Var(name)
5064+
named_subject = NameExpr(name)
5065+
named_subject.node = v
5066+
else:
5067+
named_subject = s.subject
5068+
50565069
with self.binder.frame_context(can_skip=False, fall_through=0):
50575070
subject_type = get_proper_type(self.expr_checker.accept(s.subject))
50585071

@@ -5071,7 +5084,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
50715084
# The second pass narrows down the types and type checks bodies.
50725085
for p, g, b in zip(s.patterns, s.guards, s.bodies):
50735086
current_subject_type = self.expr_checker.narrow_type_from_binder(
5074-
s.subject, subject_type
5087+
named_subject, subject_type
50755088
)
50765089
pattern_type = self.pattern_checker.accept(p, current_subject_type)
50775090
with self.binder.frame_context(can_skip=True, fall_through=2):
@@ -5082,7 +5095,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
50825095
else_map: TypeMap = {}
50835096
else:
50845097
pattern_map, else_map = conditional_types_to_typemaps(
5085-
s.subject, pattern_type.type, pattern_type.rest_type
5098+
named_subject, pattern_type.type, pattern_type.rest_type
50865099
)
50875100
self.remove_capture_conflicts(pattern_type.captures, inferred_types)
50885101
self.push_type_map(pattern_map)
@@ -5110,7 +5123,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
51105123
and expr.fullname == case_target.fullname
51115124
):
51125125
continue
5113-
type_map[s.subject] = type_map[expr]
5126+
type_map[named_subject] = type_map[expr]
51145127

51155128
self.push_type_map(guard_map)
51165129
self.accept(b)

test-data/unit/check-python310.test

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1139,6 +1139,21 @@ match m:
11391139

11401140
reveal_type(a) # N: Revealed type is "builtins.str"
11411141

1142+
[case testMatchCapturePatternFromFunctionReturningUnion]
1143+
def func1(arg: bool) -> str | int: ...
1144+
def func2(arg: bool) -> bytes | int: ...
1145+
1146+
def main() -> None:
1147+
match func1(True):
1148+
case str(a):
1149+
match func2(True):
1150+
case c:
1151+
reveal_type(a) # N: Revealed type is "builtins.str"
1152+
reveal_type(c) # N: Revealed type is "Union[builtins.bytes, builtins.int]"
1153+
reveal_type(a) # N: Revealed type is "builtins.str"
1154+
case a:
1155+
reveal_type(a) # N: Revealed type is "builtins.int"
1156+
11421157
-- Guards --
11431158

11441159
[case testMatchSimplePatternGuard]

0 commit comments

Comments
 (0)