Skip to content

Commit 38df5df

Browse files
committed
various small fixes and improved test output
1 parent b53e5fa commit 38df5df

File tree

6 files changed

+123
-76
lines changed

6 files changed

+123
-76
lines changed

flake8_trio.py

Lines changed: 43 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@
1111

1212
import ast
1313
import tokenize
14-
from typing import Any, Generator, Iterable, List, Optional, Tuple, Type, Union
14+
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
1515

1616
# CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1"
1717
__version__ = "22.7.6"
1818

19-
2019
Error = Tuple[int, int, str, Type[Any]]
20+
21+
2122
checkpoint_node_types = (ast.Await, ast.AsyncFor, ast.AsyncWith)
2223
cancel_scope_names = (
2324
"fail_after",
@@ -39,13 +40,13 @@ def make_error(error: str, lineno: int, col: int, *args: Any, **kwargs: Any) ->
3940
class Flake8TrioVisitor(ast.NodeVisitor):
4041
def __init__(self):
4142
super().__init__()
42-
self.problems: List[Error] = []
43+
self._problems: List[Error] = []
4344

4445
@classmethod
45-
def run(cls, tree: ast.AST) -> Generator[Error, None, None]:
46+
def run(cls, tree: ast.AST) -> Iterable[Error]:
4647
visitor = cls()
4748
visitor.visit(tree)
48-
yield from visitor.problems
49+
yield from visitor._problems
4950

5051
def visit_nodes(
5152
self, *nodes: Union[ast.AST, Iterable[ast.AST]], generic: bool = False
@@ -62,7 +63,16 @@ def visit_nodes(
6263
visit(node)
6364

6465
def error(self, error: str, lineno: int, col: int, *args: Any, **kwargs: Any):
65-
self.problems.append(make_error(error, lineno, col, *args, **kwargs))
66+
self._problems.append(make_error(error, lineno, col, *args, **kwargs))
67+
68+
def get_state(self, *attrs: str) -> Dict[str, Any]:
69+
if not attrs:
70+
attrs = tuple(self.__dict__.keys())
71+
return {attr: getattr(self, attr) for attr in attrs if attr != "_problems"}
72+
73+
def set_state(self, attrs: Dict[str, Any]):
74+
for attr, value in attrs.items():
75+
setattr(self, attr, value)
6676

6777

6878
class TrioScope:
@@ -87,8 +97,6 @@ def __init__(self, node: ast.Call, funcname: str, packagename: str):
8797

8898
def __str__(self):
8999
# Not supporting other ways of importing trio
90-
# if self.packagename is None:
91-
# return self.funcname
92100
return f"{self.packagename}.{self.funcname}"
93101

94102

@@ -100,7 +108,6 @@ def get_trio_scope(node: ast.AST, *names: str) -> Optional[TrioScope]:
100108
and node.func.value.id == "trio"
101109
and node.func.attr in names
102110
):
103-
# return "trio." + node.func.attr
104111
return TrioScope(node, node.func.attr, node.func.value.id)
105112
return None
106113

@@ -124,7 +131,7 @@ def __init__(self):
124131
def visit_With(self, node: Union[ast.With, ast.AsyncWith]):
125132
self.check_for_trio100(node)
126133

127-
outer_yie = self._yield_is_error
134+
outer = self.get_state("_yield_is_error")
128135

129136
# Check for a `with trio.<scope_creater>`
130137
if not self._safe_decorator:
@@ -139,13 +146,13 @@ def visit_With(self, node: Union[ast.With, ast.AsyncWith]):
139146
self.generic_visit(node)
140147

141148
# reset yield_is_error
142-
self._yield_is_error = outer_yie
149+
self.set_state(outer)
143150

144151
def visit_AsyncWith(self, node: ast.AsyncWith):
145152
self.visit_With(node)
146153

147154
def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]):
148-
outer = self._safe_decorator, self._yield_is_error
155+
outer = self.get_state()
149156
self._yield_is_error = False
150157

151158
# check for @<context_manager_name> and @<library>.<context_manager_name>
@@ -154,14 +161,14 @@ def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]):
154161

155162
self.generic_visit(node)
156163

157-
self._safe_decorator, self._yield_is_error = outer
164+
self.set_state(outer)
158165

159166
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
160167
self.visit_FunctionDef(node)
161168

162169
def visit_Yield(self, node: ast.Yield):
163170
if self._yield_is_error:
164-
self.problems.append(make_error(TRIO101, node.lineno, node.col_offset))
171+
self.error(TRIO101, node.lineno, node.col_offset)
165172

166173
self.generic_visit(node)
167174

@@ -173,19 +180,17 @@ def check_for_trio100(self, node: Union[ast.With, ast.AsyncWith]):
173180
isinstance(x, checkpoint_node_types) and x != node
174181
for x in ast.walk(node)
175182
):
176-
self.problems.append(
177-
make_error(TRIO100, item.lineno, item.col_offset, call)
178-
)
183+
self.error(TRIO100, item.lineno, item.col_offset, call)
179184

180185
def visit_ImportFrom(self, node: ast.ImportFrom):
181186
if node.module == "trio":
182-
self.problems.append(make_error(TRIO106, node.lineno, node.col_offset))
187+
self.error(TRIO106, node.lineno, node.col_offset)
183188
self.generic_visit(node)
184189

185190
def visit_Import(self, node: ast.Import):
186191
for name in node.names:
187192
if name.name == "trio" and name.asname is not None:
188-
self.problems.append(make_error(TRIO106, node.lineno, node.col_offset))
193+
self.error(TRIO106, node.lineno, node.col_offset)
189194

190195

191196
def critical_except(node: ast.ExceptHandler) -> Optional[Tuple[int, int, str]]:
@@ -239,9 +244,7 @@ def visit_Await(
239244
cm.has_timeout and cm.shielded for cm in self._trio_context_managers
240245
)
241246
):
242-
self.problems.append(
243-
make_error(TRIO102, node.lineno, node.col_offset, *self._critical_scope)
244-
)
247+
self.error(TRIO102, node.lineno, node.col_offset, *self._critical_scope)
245248
if visit_children:
246249
self.generic_visit(node)
247250

@@ -275,14 +278,15 @@ def visit_AsyncWith(self, node: ast.AsyncWith):
275278
self.visit_With(node)
276279

277280
def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]):
278-
outer_cm = self._safe_decorator
281+
outer = self.get_state("_safe_decorator")
279282

280283
# check for @<context_manager_name> and @<library>.<context_manager_name>
281284
if has_decorator(node.decorator_list, *context_manager_names):
282285
self._safe_decorator = True
283286

284287
self.generic_visit(node)
285-
self._safe_decorator = outer_cm
288+
289+
self.set_state(outer)
286290

287291
visit_AsyncFunctionDef = visit_FunctionDef
288292

@@ -292,13 +296,13 @@ def critical_visit(
292296
block: Tuple[int, int, str],
293297
generic: bool = False,
294298
):
295-
outer = self._critical_scope, self._trio_context_managers
299+
outer = self.get_state("_critical_scope", "_trio_context_managers")
296300

297301
self._trio_context_managers = []
298302
self._critical_scope = block
299303

300304
self.visit_nodes(node, generic=generic)
301-
self._critical_scope, self._trio_context_managers = outer
305+
self.set_state(outer)
302306

303307
def visit_Try(self, node: ast.Try):
304308
# There's no visit_Finally, so we need to manually visit the Try fields.
@@ -345,7 +349,7 @@ def __init__(self):
345349
# then there might be a code path that doesn't re-raise.
346350
def visit_ExceptHandler(self, node: ast.ExceptHandler):
347351

348-
outer = (self.unraised, self.except_name, self.loop_depth)
352+
outer = self.get_state()
349353
marker = critical_except(node)
350354

351355
# we need to *not* unset self.unraised if this is non-critical, to still
@@ -362,10 +366,9 @@ def visit_ExceptHandler(self, node: ast.ExceptHandler):
362366
self.generic_visit(node)
363367

364368
if self.unraised and marker is not None:
365-
# print(marker)
366-
self.problems.append(make_error(TRIO103, *marker))
369+
self.error(TRIO103, *marker)
367370

368-
(self.unraised, self.except_name, self.loop_depth) = outer
371+
self.set_state(outer)
369372

370373
def visit_Raise(self, node: ast.Raise):
371374
# if there's an unraised critical exception, the raise isn't bare,
@@ -375,7 +378,7 @@ def visit_Raise(self, node: ast.Raise):
375378
and node.exc is not None
376379
and not (isinstance(node.exc, ast.Name) and node.exc.id == self.except_name)
377380
):
378-
self.problems.append(make_error(TRIO104, node.lineno, node.col_offset))
381+
self.error(TRIO104, node.lineno, node.col_offset)
379382

380383
# treat it as safe regardless, to avoid unnecessary error messages.
381384
self.unraised = False
@@ -385,7 +388,7 @@ def visit_Raise(self, node: ast.Raise):
385388
def visit_Return(self, node: Union[ast.Return, ast.Yield]):
386389
if self.unraised:
387390
# Error: must re-raise
388-
self.problems.append(make_error(TRIO104, node.lineno, node.col_offset))
391+
self.error(TRIO104, node.lineno, node.col_offset)
389392
self.generic_visit(node)
390393

391394
visit_Yield = visit_Return
@@ -434,20 +437,22 @@ def visit_If(self, node: ast.If):
434437
# we completely disregard them when checking coverage by resetting the
435438
# effects of them afterwards
436439
def visit_For(self, node: Union[ast.For, ast.While]):
437-
outer_unraised = self.unraised
440+
outer = self.get_state("unraised")
441+
438442
self.loop_depth += 1
439443
for n in node.body:
440444
self.visit(n)
441445
self.loop_depth -= 1
442446
for n in node.orelse:
443447
self.visit(n)
444-
self.unraised = outer_unraised
448+
449+
self.set_state(outer)
445450

446451
visit_While = visit_For
447452

448453
def visit_Break(self, node: Union[ast.Break, ast.Continue]):
449454
if self.unraised and self.loop_depth == 0:
450-
self.problems.append(make_error(TRIO104, node.lineno, node.col_offset))
455+
self.error(TRIO104, node.lineno, node.col_offset)
451456
self.generic_visit(node)
452457

453458
visit_Continue = visit_Break
@@ -492,9 +497,7 @@ def visit_Call(self, node: ast.Call):
492497
or not isinstance(self.node_stack[-2], ast.Await)
493498
)
494499
):
495-
self.problems.append(
496-
make_error(TRIO105, node.lineno, node.col_offset, node.func.attr)
497-
)
500+
self.error(TRIO105, node.lineno, node.col_offset, node.func.attr)
498501
self.generic_visit(node)
499502

500503

@@ -615,7 +618,7 @@ def from_filename(cls, filename: str) -> "Plugin":
615618
source = f.read()
616619
return cls(ast.parse(source))
617620

618-
def run(self) -> Generator[Tuple[int, int, str, Type[Any]], None, None]:
621+
def run(self) -> Iterable[Error]:
619622
for v in Flake8TrioVisitor.__subclasses__():
620623
yield from v.run(self._tree)
621624

@@ -625,7 +628,7 @@ def run(self) -> Generator[Tuple[int, int, str, Type[Any]], None, None]:
625628
TRIO102 = "TRIO102: await inside {2} on line {0} must have shielded cancel scope with a timeout"
626629
TRIO103 = "TRIO103: {} block with a code path that doesn't re-raise the error"
627630
TRIO104 = "TRIO104: Cancelled (and therefore BaseException) must be re-raised"
628-
TRIO105 = "TRIO105: Trio async function {} must be immediately awaited"
631+
TRIO105 = "TRIO105: trio async function {} must be immediately awaited"
629632
TRIO106 = "TRIO106: trio must be imported with `import trio` for the linter to work"
630633
TRIO107 = "TRIO107: Async functions must have at least one checkpoint on every code path, unless an exception is raised"
631634
TRIO108 = "TRIO108: Early return from async function must have at least one checkpoint on every code path before it."

0 commit comments

Comments
 (0)