diff --git a/mypy/partially_defined.py b/mypy/partially_defined.py index 4300626ecd9f..7d87315c23ad 100644 --- a/mypy/partially_defined.py +++ b/mypy/partially_defined.py @@ -18,6 +18,7 @@ IfStmt, ListExpr, Lvalue, + MatchStmt, NameExpr, RaiseStmt, ReturnStmt, @@ -25,6 +26,8 @@ WhileStmt, WithStmt, ) +from mypy.patterns import AsPattern, StarredPattern +from mypy.reachability import ALWAYS_TRUE, infer_pattern_value from mypy.traverser import ExtendedTraverserVisitor from mypy.types import Type, UninhabitedType @@ -57,11 +60,19 @@ class BranchStatement: def __init__(self, initial_state: BranchState) -> None: self.initial_state = initial_state self.branches: list[BranchState] = [ - BranchState(must_be_defined=self.initial_state.must_be_defined) + BranchState( + must_be_defined=self.initial_state.must_be_defined, + may_be_defined=self.initial_state.may_be_defined, + ) ] def next_branch(self) -> None: - self.branches.append(BranchState(must_be_defined=self.initial_state.must_be_defined)) + self.branches.append( + BranchState( + must_be_defined=self.initial_state.must_be_defined, + may_be_defined=self.initial_state.may_be_defined, + ) + ) def record_definition(self, name: str) -> None: assert len(self.branches) > 0 @@ -198,6 +209,21 @@ def visit_if_stmt(self, o: IfStmt) -> None: o.else_body.accept(self) self.tracker.end_branch_statement() + def visit_match_stmt(self, o: MatchStmt) -> None: + self.tracker.start_branch_statement() + o.subject.accept(self) + for i in range(len(o.patterns)): + pattern = o.patterns[i] + pattern.accept(self) + guard = o.guards[i] + if guard is not None: + guard.accept(self) + o.bodies[i].accept(self) + is_catchall = infer_pattern_value(pattern) == ALWAYS_TRUE + if not is_catchall: + self.tracker.next_branch() + self.tracker.end_branch_statement() + def visit_func_def(self, o: FuncDef) -> None: self.tracker.enter_scope() super().visit_func_def(o) @@ -270,6 +296,16 @@ def visit_while_stmt(self, o: WhileStmt) -> None: o.else_body.accept(self) self.tracker.end_branch_statement() + def visit_as_pattern(self, o: AsPattern) -> None: + if o.name is not None: + self.process_lvalue(o.name) + super().visit_as_pattern(o) + + def visit_starred_pattern(self, o: StarredPattern) -> None: + if o.capture is not None: + self.process_lvalue(o.capture) + super().visit_starred_pattern(o) + def visit_name_expr(self, o: NameExpr) -> None: if self.tracker.is_possibly_undefined(o.name): self.msg.variable_may_be_undefined(o.name, o) diff --git a/test-data/unit/check-partially-defined.test b/test-data/unit/check-partially-defined.test index 6bb5a65232eb..d456568c1131 100644 --- a/test-data/unit/check-partially-defined.test +++ b/test-data/unit/check-partially-defined.test @@ -18,6 +18,13 @@ else: z = a + 1 # E: Name "a" may be undefined +[case testUsedInIf] +# flags: --enable-error-code partially-defined +if int(): + y = 1 +if int(): + x = y # E: Name "y" may be undefined + [case testDefinedInAllBranches] # flags: --enable-error-code partially-defined if int(): diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 5ac34025384c..1548d5dadcfd 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1725,3 +1725,66 @@ def my_func(pairs: Iterable[tuple[S, S]]) -> None: reveal_type(pair) # N: Revealed type is "Tuple[builtins.int, builtins.int]" \ # N: Revealed type is "Tuple[builtins.str, builtins.str]" [builtins fixtures/tuple.pyi] + +[case testPartiallyDefinedMatch] +# flags: --enable-error-code partially-defined +def f0(x: int | str) -> int: + match x: + case int(): + y = 1 + return y # E: Name "y" may be undefined + +def f1(a: object) -> None: + match a: + case [y]: pass + case _: + y = 1 + x = 2 + z = y + z = x # E: Name "x" may be undefined + +def f2(a: object) -> None: + match a: + case [[y] as x]: pass + case {"k1": 1, "k2": x, "k3": y}: pass + case [0, *x]: + y = 2 + case _: + y = 1 + x = [2] + z = x + z = y + +def f3(a: object) -> None: + y = 1 + match a: + case [x]: + y = 2 + # Note the missing `case _:` + z = x # E: Name "x" may be undefined + z = y + +def f4(a: object) -> None: + y = 1 + match a: + case [x]: + y = 2 + case _: + assert False, "unsupported" + z = x + z = y + +def f5(a: object) -> None: + match a: + case tuple(x): pass + case _: + return + y = x + +def f6(a: object) -> None: + if int(): + y = 1 + match a: + case _ if y is not None: # E: Name "y" may be undefined + pass +[builtins fixtures/tuple.pyi]