Skip to content

Commit a618110

Browse files
authored
[used before def] improve handling of global definitions in local scopes (python#14517)
While working on python#14483, we discovered that variable inheritance didn't work quite right. In particular, functions would inherit variables from outer scope. On the surface, this is what you want but actually, they only inherit the scope if there isn't a colliding definition within that scope. Here's an example: ```python class c: pass def f0() -> None: s = c() # UnboundLocalError is raised when this code is executed. class c: pass def f1() -> None: s = c() # No error. ``` This PR also fixes issues with builtins (exactly the same example as above but instead of `c` we have a builtin). Fixes python#14213 (as much as is reasonable to do)
1 parent c245e91 commit a618110

File tree

4 files changed

+148
-78
lines changed

4 files changed

+148
-78
lines changed

mypy/partially_defined.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,9 @@ def copy(self) -> BranchState:
7979

8080

8181
class BranchStatement:
82-
def __init__(self, initial_state: BranchState) -> None:
82+
def __init__(self, initial_state: BranchState | None = None) -> None:
83+
if initial_state is None:
84+
initial_state = BranchState()
8385
self.initial_state = initial_state
8486
self.branches: list[BranchState] = [
8587
BranchState(
@@ -171,7 +173,7 @@ class ScopeType(Enum):
171173
Global = 1
172174
Class = 2
173175
Func = 3
174-
Generator = 3
176+
Generator = 4
175177

176178

177179
class Scope:
@@ -199,7 +201,7 @@ class DefinedVariableTracker:
199201

200202
def __init__(self) -> None:
201203
# There's always at least one scope. Within each scope, there's at least one "global" BranchingStatement.
202-
self.scopes: list[Scope] = [Scope([BranchStatement(BranchState())], ScopeType.Global)]
204+
self.scopes: list[Scope] = [Scope([BranchStatement()], ScopeType.Global)]
203205
# disable_branch_skip is used to disable skipping a branch due to a return/raise/etc. This is useful
204206
# in things like try/except/finally statements.
205207
self.disable_branch_skip = False
@@ -216,9 +218,11 @@ def _scope(self) -> Scope:
216218

217219
def enter_scope(self, scope_type: ScopeType) -> None:
218220
assert len(self._scope().branch_stmts) > 0
219-
self.scopes.append(
220-
Scope([BranchStatement(self._scope().branch_stmts[-1].branches[-1])], scope_type)
221-
)
221+
initial_state = None
222+
if scope_type == ScopeType.Generator:
223+
# Generators are special because they inherit the outer scope.
224+
initial_state = self._scope().branch_stmts[-1].branches[-1]
225+
self.scopes.append(Scope([BranchStatement(initial_state)], scope_type))
222226

223227
def exit_scope(self) -> None:
224228
self.scopes.pop()
@@ -342,13 +346,15 @@ def variable_may_be_undefined(self, name: str, context: Context) -> None:
342346
def process_definition(self, name: str) -> None:
343347
# Was this name previously used? If yes, it's a used-before-definition error.
344348
if not self.tracker.in_scope(ScopeType.Class):
345-
# Errors in class scopes are caught by the semantic analyzer.
346349
refs = self.tracker.pop_undefined_ref(name)
347350
for ref in refs:
348351
if self.loops:
349352
self.variable_may_be_undefined(name, ref)
350353
else:
351354
self.var_used_before_def(name, ref)
355+
else:
356+
# Errors in class scopes are caught by the semantic analyzer.
357+
pass
352358
self.tracker.record_definition(name)
353359

354360
def visit_global_decl(self, o: GlobalDecl) -> None:
@@ -415,17 +421,24 @@ def visit_match_stmt(self, o: MatchStmt) -> None:
415421

416422
def visit_func_def(self, o: FuncDef) -> None:
417423
self.process_definition(o.name)
418-
self.tracker.enter_scope(ScopeType.Func)
419424
super().visit_func_def(o)
420-
self.tracker.exit_scope()
421425

422426
def visit_func(self, o: FuncItem) -> None:
423427
if o.is_dynamic() and not self.options.check_untyped_defs:
424428
return
425-
if o.arguments is not None:
426-
for arg in o.arguments:
427-
self.tracker.record_definition(arg.variable.name)
428-
super().visit_func(o)
429+
430+
args = o.arguments or []
431+
# Process initializers (defaults) outside the function scope.
432+
for arg in args:
433+
if arg.initializer is not None:
434+
arg.initializer.accept(self)
435+
436+
self.tracker.enter_scope(ScopeType.Func)
437+
for arg in args:
438+
self.process_definition(arg.variable.name)
439+
super().visit_var(arg.variable)
440+
o.body.accept(self)
441+
self.tracker.exit_scope()
429442

430443
def visit_generator_expr(self, o: GeneratorExpr) -> None:
431444
self.tracker.enter_scope(ScopeType.Generator)
@@ -603,7 +616,7 @@ def visit_starred_pattern(self, o: StarredPattern) -> None:
603616
super().visit_starred_pattern(o)
604617

605618
def visit_name_expr(self, o: NameExpr) -> None:
606-
if o.name in self.builtins:
619+
if o.name in self.builtins and self.tracker.in_scope(ScopeType.Global):
607620
return
608621
if self.tracker.is_possibly_undefined(o.name):
609622
# A variable is only defined in some branches.

mypyc/test-data/run-sets.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,8 @@ def test_in_set() -> None:
141141
assert main_set(item), f"{item!r} should be in set_main"
142142
assert not main_negated_set(item), item
143143

144-
assert non_final_name_set(non_const)
145144
global non_const
145+
assert non_final_name_set(non_const)
146146
non_const = "updated"
147147
assert non_final_name_set("updated")
148148

test-data/unit/check-functions.test

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -491,62 +491,61 @@ if int():
491491

492492
[case testDefaultArgumentExpressions]
493493
import typing
494+
class B: pass
495+
class A: pass
496+
494497
def f(x: 'A' = A()) -> None:
495498
b = x # type: B # E: Incompatible types in assignment (expression has type "A", variable has type "B")
496499
a = x # type: A
497-
498-
class B: pass
499-
class A: pass
500500
[out]
501501

502502
[case testDefaultArgumentExpressions2]
503503
import typing
504-
def f(x: 'A' = B()) -> None: # E: Incompatible default for argument "x" (default has type "B", argument has type "A")
505-
b = x # type: B # E: Incompatible types in assignment (expression has type "A", variable has type "B")
506-
a = x # type: A
507-
508504
class B: pass
509505
class A: pass
510506

507+
def f(x: 'A' = B()) -> None: # E: Incompatible default for argument "x" (default has type "B", argument has type "A")
508+
b = x # type: B # E: Incompatible types in assignment (expression has type "A", variable has type "B")
509+
a = x # type: A
511510
[case testDefaultArgumentExpressionsGeneric]
512511
from typing import TypeVar
513512
T = TypeVar('T', bound='A')
514-
def f(x: T = B()) -> None: # E: Incompatible default for argument "x" (default has type "B", argument has type "T")
515-
b = x # type: B # E: Incompatible types in assignment (expression has type "T", variable has type "B")
516-
a = x # type: A
517513

518514
class B: pass
519515
class A: pass
520516

517+
def f(x: T = B()) -> None: # E: Incompatible default for argument "x" (default has type "B", argument has type "T")
518+
b = x # type: B # E: Incompatible types in assignment (expression has type "T", variable has type "B")
519+
a = x # type: A
521520
[case testDefaultArgumentsWithSubtypes]
522521
import typing
522+
class A: pass
523+
class B(A): pass
524+
523525
def f(x: 'B' = A()) -> None: # E: Incompatible default for argument "x" (default has type "A", argument has type "B")
524526
pass
525527
def g(x: 'A' = B()) -> None:
526528
pass
527-
528-
class A: pass
529-
class B(A): pass
530529
[out]
531530

532531
[case testMultipleDefaultArgumentExpressions]
533532
import typing
533+
class A: pass
534+
class B: pass
535+
534536
def f(x: 'A' = B(), y: 'B' = B()) -> None: # E: Incompatible default for argument "x" (default has type "B", argument has type "A")
535537
pass
536538
def h(x: 'A' = A(), y: 'B' = B()) -> None:
537539
pass
538-
539-
class A: pass
540-
class B: pass
541540
[out]
542541

543542
[case testMultipleDefaultArgumentExpressions2]
544543
import typing
545-
def g(x: 'A' = A(), y: 'B' = A()) -> None: # E: Incompatible default for argument "y" (default has type "A", argument has type "B")
546-
pass
547-
548544
class A: pass
549545
class B: pass
546+
547+
def g(x: 'A' = A(), y: 'B' = A()) -> None: # E: Incompatible default for argument "y" (default has type "A", argument has type "B")
548+
pass
550549
[out]
551550

552551
[case testDefaultArgumentsAndSignatureAsComment]
@@ -2612,7 +2611,7 @@ def f() -> int: ...
26122611
[case testLambdaDefaultTypeErrors]
26132612
lambda a=(1 + 'asdf'): a # E: Unsupported operand types for + ("int" and "str")
26142613
lambda a=nonsense: a # E: Name "nonsense" is not defined
2615-
def f(x: int = i): # E: Name "i" is not defined # E: Name "i" is used before definition
2614+
def f(x: int = i): # E: Name "i" is not defined
26162615
i = 42
26172616

26182617
[case testRevealTypeOfCallExpressionReturningNoneWorks]

0 commit comments

Comments
 (0)