@@ -79,7 +79,9 @@ def copy(self) -> BranchState:
79
79
80
80
81
81
class BranchStatement :
82
- def __init__ (self , initial_state : BranchState ) -> None :
82
+ def __init__ (self , initial_state : BranchState | None = None ) -> None :
83
+ if initial_state is None :
84
+ initial_state = BranchState ()
83
85
self .initial_state = initial_state
84
86
self .branches : list [BranchState ] = [
85
87
BranchState (
@@ -171,7 +173,7 @@ class ScopeType(Enum):
171
173
Global = 1
172
174
Class = 2
173
175
Func = 3
174
- Generator = 3
176
+ Generator = 4
175
177
176
178
177
179
class Scope :
@@ -199,7 +201,7 @@ class DefinedVariableTracker:
199
201
200
202
def __init__ (self ) -> None :
201
203
# There's always at least one scope. Within each scope, there's at least one "global" BranchingStatement.
202
- self .scopes : list [Scope ] = [Scope ([BranchStatement (BranchState () )], ScopeType .Global )]
204
+ self .scopes : list [Scope ] = [Scope ([BranchStatement ()], ScopeType .Global )]
203
205
# disable_branch_skip is used to disable skipping a branch due to a return/raise/etc. This is useful
204
206
# in things like try/except/finally statements.
205
207
self .disable_branch_skip = False
@@ -216,9 +218,11 @@ def _scope(self) -> Scope:
216
218
217
219
def enter_scope (self , scope_type : ScopeType ) -> None :
218
220
assert len (self ._scope ().branch_stmts ) > 0
219
- self .scopes .append (
220
- Scope ([BranchStatement (self ._scope ().branch_stmts [- 1 ].branches [- 1 ])], scope_type )
221
- )
221
+ initial_state = None
222
+ if scope_type == ScopeType .Generator :
223
+ # Generators are special because they inherit the outer scope.
224
+ initial_state = self ._scope ().branch_stmts [- 1 ].branches [- 1 ]
225
+ self .scopes .append (Scope ([BranchStatement (initial_state )], scope_type ))
222
226
223
227
def exit_scope (self ) -> None :
224
228
self .scopes .pop ()
@@ -342,13 +346,15 @@ def variable_may_be_undefined(self, name: str, context: Context) -> None:
342
346
def process_definition (self , name : str ) -> None :
343
347
# Was this name previously used? If yes, it's a used-before-definition error.
344
348
if not self .tracker .in_scope (ScopeType .Class ):
345
- # Errors in class scopes are caught by the semantic analyzer.
346
349
refs = self .tracker .pop_undefined_ref (name )
347
350
for ref in refs :
348
351
if self .loops :
349
352
self .variable_may_be_undefined (name , ref )
350
353
else :
351
354
self .var_used_before_def (name , ref )
355
+ else :
356
+ # Errors in class scopes are caught by the semantic analyzer.
357
+ pass
352
358
self .tracker .record_definition (name )
353
359
354
360
def visit_global_decl (self , o : GlobalDecl ) -> None :
@@ -415,17 +421,24 @@ def visit_match_stmt(self, o: MatchStmt) -> None:
415
421
416
422
def visit_func_def (self , o : FuncDef ) -> None :
417
423
self .process_definition (o .name )
418
- self .tracker .enter_scope (ScopeType .Func )
419
424
super ().visit_func_def (o )
420
- self .tracker .exit_scope ()
421
425
422
426
def visit_func (self , o : FuncItem ) -> None :
423
427
if o .is_dynamic () and not self .options .check_untyped_defs :
424
428
return
425
- if o .arguments is not None :
426
- for arg in o .arguments :
427
- self .tracker .record_definition (arg .variable .name )
428
- super ().visit_func (o )
429
+
430
+ args = o .arguments or []
431
+ # Process initializers (defaults) outside the function scope.
432
+ for arg in args :
433
+ if arg .initializer is not None :
434
+ arg .initializer .accept (self )
435
+
436
+ self .tracker .enter_scope (ScopeType .Func )
437
+ for arg in args :
438
+ self .process_definition (arg .variable .name )
439
+ super ().visit_var (arg .variable )
440
+ o .body .accept (self )
441
+ self .tracker .exit_scope ()
429
442
430
443
def visit_generator_expr (self , o : GeneratorExpr ) -> None :
431
444
self .tracker .enter_scope (ScopeType .Generator )
@@ -603,7 +616,7 @@ def visit_starred_pattern(self, o: StarredPattern) -> None:
603
616
super ().visit_starred_pattern (o )
604
617
605
618
def visit_name_expr (self , o : NameExpr ) -> None :
606
- if o .name in self .builtins :
619
+ if o .name in self .builtins and self . tracker . in_scope ( ScopeType . Global ) :
607
620
return
608
621
if self .tracker .is_possibly_undefined (o .name ):
609
622
# A variable is only defined in some branches.
0 commit comments