Skip to content

Support match statement in partially defined check #13860

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions mypy/partially_defined.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@
IfStmt,
ListExpr,
Lvalue,
MatchStmt,
NameExpr,
RaiseStmt,
ReturnStmt,
TupleExpr,
WhileStmt,
WithStmt,
)
from mypy.patterns import AsPattern, StarredPattern
from mypy.reachability import ALWAYS_TRUE, infer_pattern_value
from mypy.traverser import ExtendedTraverserVisitor
from mypy.types import Type, UninhabitedType

Expand Down Expand Up @@ -57,11 +60,19 @@ class BranchStatement:
def __init__(self, initial_state: BranchState) -> None:
self.initial_state = initial_state
self.branches: list[BranchState] = [
BranchState(must_be_defined=self.initial_state.must_be_defined)
BranchState(
must_be_defined=self.initial_state.must_be_defined,
may_be_defined=self.initial_state.may_be_defined,
)
]

def next_branch(self) -> None:
self.branches.append(BranchState(must_be_defined=self.initial_state.must_be_defined))
self.branches.append(
BranchState(
must_be_defined=self.initial_state.must_be_defined,
may_be_defined=self.initial_state.may_be_defined,
)
)

def record_definition(self, name: str) -> None:
assert len(self.branches) > 0
Expand Down Expand Up @@ -198,6 +209,21 @@ def visit_if_stmt(self, o: IfStmt) -> None:
o.else_body.accept(self)
self.tracker.end_branch_statement()

def visit_match_stmt(self, o: MatchStmt) -> None:
self.tracker.start_branch_statement()
o.subject.accept(self)
for i in range(len(o.patterns)):
pattern = o.patterns[i]
pattern.accept(self)
guard = o.guards[i]
if guard is not None:
guard.accept(self)
o.bodies[i].accept(self)
is_catchall = infer_pattern_value(pattern) == ALWAYS_TRUE
if not is_catchall:
self.tracker.next_branch()
self.tracker.end_branch_statement()

def visit_func_def(self, o: FuncDef) -> None:
self.tracker.enter_scope()
super().visit_func_def(o)
Expand Down Expand Up @@ -270,6 +296,16 @@ def visit_while_stmt(self, o: WhileStmt) -> None:
o.else_body.accept(self)
self.tracker.end_branch_statement()

def visit_as_pattern(self, o: AsPattern) -> None:
if o.name is not None:
self.process_lvalue(o.name)
super().visit_as_pattern(o)

def visit_starred_pattern(self, o: StarredPattern) -> None:
if o.capture is not None:
self.process_lvalue(o.capture)
super().visit_starred_pattern(o)

def visit_name_expr(self, o: NameExpr) -> None:
if self.tracker.is_possibly_undefined(o.name):
self.msg.variable_may_be_undefined(o.name, o)
Expand Down
7 changes: 7 additions & 0 deletions test-data/unit/check-partially-defined.test
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ else:

z = a + 1 # E: Name "a" may be undefined

[case testUsedInIf]
# flags: --enable-error-code partially-defined
if int():
y = 1
if int():
x = y # E: Name "y" may be undefined

[case testDefinedInAllBranches]
# flags: --enable-error-code partially-defined
if int():
Expand Down
63 changes: 63 additions & 0 deletions test-data/unit/check-python310.test
Original file line number Diff line number Diff line change
Expand Up @@ -1725,3 +1725,66 @@ def my_func(pairs: Iterable[tuple[S, S]]) -> None:
reveal_type(pair) # N: Revealed type is "Tuple[builtins.int, builtins.int]" \
# N: Revealed type is "Tuple[builtins.str, builtins.str]"
[builtins fixtures/tuple.pyi]

[case testPartiallyDefinedMatch]
# flags: --enable-error-code partially-defined
def f0(x: int | str) -> int:
match x:
case int():
y = 1
return y # E: Name "y" may be undefined

def f1(a: object) -> None:
match a:
case [y]: pass
case _:
y = 1
x = 2
z = y
z = x # E: Name "x" may be undefined

def f2(a: object) -> None:
match a:
case [[y] as x]: pass
case {"k1": 1, "k2": x, "k3": y}: pass
case [0, *x]:
y = 2
case _:
y = 1
x = [2]
z = x
z = y

def f3(a: object) -> None:
y = 1
match a:
case [x]:
y = 2
# Note the missing `case _:`
z = x # E: Name "x" may be undefined
z = y

def f4(a: object) -> None:
y = 1
match a:
case [x]:
y = 2
case _:
assert False, "unsupported"
z = x
z = y

def f5(a: object) -> None:
match a:
case tuple(x): pass
case _:
return
y = x

def f6(a: object) -> None:
if int():
y = 1
match a:
case _ if y is not None: # E: Name "y" may be undefined
pass
[builtins fixtures/tuple.pyi]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't generate an error, though I think it should:

def f(x: int | str) -> int:
    match x:
        case int():
            y = 1
    return y  # no error

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The exact example you provided generates the error for me. Can you double check how you ran this?

I've added this exact example into the test.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I was running the wrong code. The example works just fine. However, this related case seems to not work as expected:

def f(x: int | str) -> int:
    match x:
        case int():
            y = 1
        case str():
            y = 2
    return y  # error: Name "y" may be undefined

Since the match statement covers all the union items, there shouldn't be an error. Similarly, this generates a false positive:

def f2(x: int | str) -> int:
    if isinstance(x, int):
        y = 1
    elif isinstance(x, str):
        y = 2
    return y  # error: Name "y" may be undefined

However, mypy doesn't complain about a missing return statement in these related examples, so mypy can already detect if all union items are covered:

def f3(x: int | str) -> int:
    if isinstance(x, int):
        return 1
    elif isinstance(x, str):
        return 2

def f4(x: int | str) -> int:
    match x:
        case int():
            return 1
        case str():
            return 2

Since this also affects isinstance checks, this may not be directly related to this PR, and fixed in a separate PR.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good find!

Agreed, separate PR is probably a better place to do this. I created #13926.