11
11
12
12
import ast
13
13
import tokenize
14
- from typing import Any , Generator , Iterable , List , Optional , Tuple , Type , Union
14
+ from typing import Any , Dict , Iterable , List , Optional , Tuple , Type , Union
15
15
16
16
# CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1"
17
17
__version__ = "22.7.6"
18
18
19
-
20
19
Error = Tuple [int , int , str , Type [Any ]]
20
+
21
+
21
22
checkpoint_node_types = (ast .Await , ast .AsyncFor , ast .AsyncWith )
22
23
cancel_scope_names = (
23
24
"fail_after" ,
@@ -39,13 +40,13 @@ def make_error(error: str, lineno: int, col: int, *args: Any, **kwargs: Any) ->
39
40
class Flake8TrioVisitor (ast .NodeVisitor ):
40
41
def __init__ (self ):
41
42
super ().__init__ ()
42
- self .problems : List [Error ] = []
43
+ self ._problems : List [Error ] = []
43
44
44
45
@classmethod
45
- def run (cls , tree : ast .AST ) -> Generator [Error , None , None ]:
46
+ def run (cls , tree : ast .AST ) -> Iterable [Error ]:
46
47
visitor = cls ()
47
48
visitor .visit (tree )
48
- yield from visitor .problems
49
+ yield from visitor ._problems
49
50
50
51
def visit_nodes (
51
52
self , * nodes : Union [ast .AST , Iterable [ast .AST ]], generic : bool = False
@@ -62,7 +63,16 @@ def visit_nodes(
62
63
visit (node )
63
64
64
65
def error (self , error : str , lineno : int , col : int , * args : Any , ** kwargs : Any ):
65
- self .problems .append (make_error (error , lineno , col , * args , ** kwargs ))
66
+ self ._problems .append (make_error (error , lineno , col , * args , ** kwargs ))
67
+
68
+ def get_state (self , * attrs : str ) -> Dict [str , Any ]:
69
+ if not attrs :
70
+ attrs = tuple (self .__dict__ .keys ())
71
+ return {attr : getattr (self , attr ) for attr in attrs if attr != "_problems" }
72
+
73
+ def set_state (self , attrs : Dict [str , Any ]):
74
+ for attr , value in attrs .items ():
75
+ setattr (self , attr , value )
66
76
67
77
68
78
class TrioScope :
@@ -87,8 +97,6 @@ def __init__(self, node: ast.Call, funcname: str, packagename: str):
87
97
88
98
def __str__ (self ):
89
99
# Not supporting other ways of importing trio
90
- # if self.packagename is None:
91
- # return self.funcname
92
100
return f"{ self .packagename } .{ self .funcname } "
93
101
94
102
@@ -100,7 +108,6 @@ def get_trio_scope(node: ast.AST, *names: str) -> Optional[TrioScope]:
100
108
and node .func .value .id == "trio"
101
109
and node .func .attr in names
102
110
):
103
- # return "trio." + node.func.attr
104
111
return TrioScope (node , node .func .attr , node .func .value .id )
105
112
return None
106
113
@@ -124,7 +131,7 @@ def __init__(self):
124
131
def visit_With (self , node : Union [ast .With , ast .AsyncWith ]):
125
132
self .check_for_trio100 (node )
126
133
127
- outer_yie = self ._yield_is_error
134
+ outer = self .get_state ( " _yield_is_error" )
128
135
129
136
# Check for a `with trio.<scope_creater>`
130
137
if not self ._safe_decorator :
@@ -139,13 +146,13 @@ def visit_With(self, node: Union[ast.With, ast.AsyncWith]):
139
146
self .generic_visit (node )
140
147
141
148
# reset yield_is_error
142
- self ._yield_is_error = outer_yie
149
+ self .set_state ( outer )
143
150
144
151
def visit_AsyncWith (self , node : ast .AsyncWith ):
145
152
self .visit_With (node )
146
153
147
154
def visit_FunctionDef (self , node : Union [ast .FunctionDef , ast .AsyncFunctionDef ]):
148
- outer = self ._safe_decorator , self . _yield_is_error
155
+ outer = self .get_state ()
149
156
self ._yield_is_error = False
150
157
151
158
# check for @<context_manager_name> and @<library>.<context_manager_name>
@@ -154,14 +161,14 @@ def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]):
154
161
155
162
self .generic_visit (node )
156
163
157
- self ._safe_decorator , self . _yield_is_error = outer
164
+ self .set_state ( outer )
158
165
159
166
def visit_AsyncFunctionDef (self , node : ast .AsyncFunctionDef ):
160
167
self .visit_FunctionDef (node )
161
168
162
169
def visit_Yield (self , node : ast .Yield ):
163
170
if self ._yield_is_error :
164
- self .problems . append ( make_error ( TRIO101 , node .lineno , node .col_offset ) )
171
+ self .error ( TRIO101 , node .lineno , node .col_offset )
165
172
166
173
self .generic_visit (node )
167
174
@@ -173,19 +180,17 @@ def check_for_trio100(self, node: Union[ast.With, ast.AsyncWith]):
173
180
isinstance (x , checkpoint_node_types ) and x != node
174
181
for x in ast .walk (node )
175
182
):
176
- self .problems .append (
177
- make_error (TRIO100 , item .lineno , item .col_offset , call )
178
- )
183
+ self .error (TRIO100 , item .lineno , item .col_offset , call )
179
184
180
185
def visit_ImportFrom (self , node : ast .ImportFrom ):
181
186
if node .module == "trio" :
182
- self .problems . append ( make_error ( TRIO106 , node .lineno , node .col_offset ) )
187
+ self .error ( TRIO106 , node .lineno , node .col_offset )
183
188
self .generic_visit (node )
184
189
185
190
def visit_Import (self , node : ast .Import ):
186
191
for name in node .names :
187
192
if name .name == "trio" and name .asname is not None :
188
- self .problems . append ( make_error ( TRIO106 , node .lineno , node .col_offset ) )
193
+ self .error ( TRIO106 , node .lineno , node .col_offset )
189
194
190
195
191
196
def critical_except (node : ast .ExceptHandler ) -> Optional [Tuple [int , int , str ]]:
@@ -239,9 +244,7 @@ def visit_Await(
239
244
cm .has_timeout and cm .shielded for cm in self ._trio_context_managers
240
245
)
241
246
):
242
- self .problems .append (
243
- make_error (TRIO102 , node .lineno , node .col_offset , * self ._critical_scope )
244
- )
247
+ self .error (TRIO102 , node .lineno , node .col_offset , * self ._critical_scope )
245
248
if visit_children :
246
249
self .generic_visit (node )
247
250
@@ -275,14 +278,15 @@ def visit_AsyncWith(self, node: ast.AsyncWith):
275
278
self .visit_With (node )
276
279
277
280
def visit_FunctionDef (self , node : Union [ast .FunctionDef , ast .AsyncFunctionDef ]):
278
- outer_cm = self ._safe_decorator
281
+ outer = self .get_state ( " _safe_decorator" )
279
282
280
283
# check for @<context_manager_name> and @<library>.<context_manager_name>
281
284
if has_decorator (node .decorator_list , * context_manager_names ):
282
285
self ._safe_decorator = True
283
286
284
287
self .generic_visit (node )
285
- self ._safe_decorator = outer_cm
288
+
289
+ self .set_state (outer )
286
290
287
291
visit_AsyncFunctionDef = visit_FunctionDef
288
292
@@ -292,13 +296,13 @@ def critical_visit(
292
296
block : Tuple [int , int , str ],
293
297
generic : bool = False ,
294
298
):
295
- outer = self ._critical_scope , self . _trio_context_managers
299
+ outer = self .get_state ( " _critical_scope" , " _trio_context_managers" )
296
300
297
301
self ._trio_context_managers = []
298
302
self ._critical_scope = block
299
303
300
304
self .visit_nodes (node , generic = generic )
301
- self ._critical_scope , self . _trio_context_managers = outer
305
+ self .set_state ( outer )
302
306
303
307
def visit_Try (self , node : ast .Try ):
304
308
# There's no visit_Finally, so we need to manually visit the Try fields.
@@ -345,7 +349,7 @@ def __init__(self):
345
349
# then there might be a code path that doesn't re-raise.
346
350
def visit_ExceptHandler (self , node : ast .ExceptHandler ):
347
351
348
- outer = ( self .unraised , self . except_name , self . loop_depth )
352
+ outer = self .get_state ( )
349
353
marker = critical_except (node )
350
354
351
355
# we need to *not* unset self.unraised if this is non-critical, to still
@@ -362,10 +366,9 @@ def visit_ExceptHandler(self, node: ast.ExceptHandler):
362
366
self .generic_visit (node )
363
367
364
368
if self .unraised and marker is not None :
365
- # print(marker)
366
- self .problems .append (make_error (TRIO103 , * marker ))
369
+ self .error (TRIO103 , * marker )
367
370
368
- ( self .unraised , self . except_name , self . loop_depth ) = outer
371
+ self .set_state ( outer )
369
372
370
373
def visit_Raise (self , node : ast .Raise ):
371
374
# if there's an unraised critical exception, the raise isn't bare,
@@ -375,7 +378,7 @@ def visit_Raise(self, node: ast.Raise):
375
378
and node .exc is not None
376
379
and not (isinstance (node .exc , ast .Name ) and node .exc .id == self .except_name )
377
380
):
378
- self .problems . append ( make_error ( TRIO104 , node .lineno , node .col_offset ) )
381
+ self .error ( TRIO104 , node .lineno , node .col_offset )
379
382
380
383
# treat it as safe regardless, to avoid unnecessary error messages.
381
384
self .unraised = False
@@ -385,7 +388,7 @@ def visit_Raise(self, node: ast.Raise):
385
388
def visit_Return (self , node : Union [ast .Return , ast .Yield ]):
386
389
if self .unraised :
387
390
# Error: must re-raise
388
- self .problems . append ( make_error ( TRIO104 , node .lineno , node .col_offset ) )
391
+ self .error ( TRIO104 , node .lineno , node .col_offset )
389
392
self .generic_visit (node )
390
393
391
394
visit_Yield = visit_Return
@@ -434,20 +437,22 @@ def visit_If(self, node: ast.If):
434
437
# we completely disregard them when checking coverage by resetting the
435
438
# effects of them afterwards
436
439
def visit_For (self , node : Union [ast .For , ast .While ]):
437
- outer_unraised = self .unraised
440
+ outer = self .get_state ("unraised" )
441
+
438
442
self .loop_depth += 1
439
443
for n in node .body :
440
444
self .visit (n )
441
445
self .loop_depth -= 1
442
446
for n in node .orelse :
443
447
self .visit (n )
444
- self .unraised = outer_unraised
448
+
449
+ self .set_state (outer )
445
450
446
451
visit_While = visit_For
447
452
448
453
def visit_Break (self , node : Union [ast .Break , ast .Continue ]):
449
454
if self .unraised and self .loop_depth == 0 :
450
- self .problems . append ( make_error ( TRIO104 , node .lineno , node .col_offset ) )
455
+ self .error ( TRIO104 , node .lineno , node .col_offset )
451
456
self .generic_visit (node )
452
457
453
458
visit_Continue = visit_Break
@@ -492,9 +497,7 @@ def visit_Call(self, node: ast.Call):
492
497
or not isinstance (self .node_stack [- 2 ], ast .Await )
493
498
)
494
499
):
495
- self .problems .append (
496
- make_error (TRIO105 , node .lineno , node .col_offset , node .func .attr )
497
- )
500
+ self .error (TRIO105 , node .lineno , node .col_offset , node .func .attr )
498
501
self .generic_visit (node )
499
502
500
503
@@ -615,7 +618,7 @@ def from_filename(cls, filename: str) -> "Plugin":
615
618
source = f .read ()
616
619
return cls (ast .parse (source ))
617
620
618
- def run (self ) -> Generator [ Tuple [ int , int , str , Type [ Any ]], None , None ]:
621
+ def run (self ) -> Iterable [ Error ]:
619
622
for v in Flake8TrioVisitor .__subclasses__ ():
620
623
yield from v .run (self ._tree )
621
624
@@ -625,7 +628,7 @@ def run(self) -> Generator[Tuple[int, int, str, Type[Any]], None, None]:
625
628
TRIO102 = "TRIO102: await inside {2} on line {0} must have shielded cancel scope with a timeout"
626
629
TRIO103 = "TRIO103: {} block with a code path that doesn't re-raise the error"
627
630
TRIO104 = "TRIO104: Cancelled (and therefore BaseException) must be re-raised"
628
- TRIO105 = "TRIO105: Trio async function {} must be immediately awaited"
631
+ TRIO105 = "TRIO105: trio async function {} must be immediately awaited"
629
632
TRIO106 = "TRIO106: trio must be imported with `import trio` for the linter to work"
630
633
TRIO107 = "TRIO107: Async functions must have at least one checkpoint on every code path, unless an exception is raised"
631
634
TRIO108 = "TRIO108: Early return from async function must have at least one checkpoint on every code path before it."
0 commit comments