Skip to content

Commit 929e5f3

Browse files
committed
Merge remote-tracking branch 'origin/main' into 7_async_iterable_checkpoints
2 parents c2621a4 + b298954 commit 929e5f3

File tree

6 files changed

+213
-31
lines changed

6 files changed

+213
-31
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
## Future
55
- Extend TRIO107 and 108 to also handle yields
66

7+
## 22.8.1
8+
- Added TRIO109: Async definitions should not have a `timeout` parameter. Use `trio.[fail/move_on]_[at/after]`
9+
- Added TRIO110: `while <condition>: await trio.sleep()` should be replaced by a `trio.Event`.
10+
711
## 22.7.6
812
- Extend TRIO102 to also check inside `except BaseException` and `except trio.Cancelled`
913
- Extend TRIO104 to also check for `yield`

README.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ pip install flake8-trio
2828
- **TRIO104**: `Cancelled` and `BaseException` must be re-raised - when a user tries to `return` or `raise` a different exception.
2929
- **TRIO105**: Calling a trio async function without immediately `await`ing it.
3030
- **TRIO106**: trio must be imported with `import trio` for the linter to work.
31+
<!-- TODO: fix
3132
- **TRIO107**: Async functions and iterables must have at least one checkpoint on every code path, unless an exception is raised.
3233
- **TRIO108**: `return` or `yield` from async function must have at least one checkpoint on every code path before it, unless an exception is raised.
33-
Checkpoints are `await`, `async for`, and `async with` (on one of enter/exit).
34+
# - **TRIO107**: Async functions must have at least one checkpoint on every code path, unless an exception is raised.
35+
# - **TRIO108**: Early return from async function must have at least one checkpoint on every code path before it, unless an exception is raised.
36+
Checkpoints are `await`, `async for`, and `async with` (on one of enter/exit).
37+
-->
38+
- **TRIO109**: Async function definition with a `timeout` parameter - use `trio.[fail/move_on]_[after/at]` instead
39+
- **TRIO110**: `while <condition>: await trio.sleep()` should be replaced by a `trio.Event`.

flake8_trio.py

Lines changed: 51 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
1515

1616
# CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1"
17+
<<<<<<< HEAD
1718
__version__ = "22.7.6"
19+
=======
20+
__version__ = "22.8.1"
21+
>>>>>>> origin/main
1822

1923
Error = Tuple[int, int, str, Type[Any]]
2024

@@ -40,14 +44,14 @@ def make_error(error: str, lineno: int, col: int, *args: Any, **kwargs: Any) ->
4044
class Flake8TrioVisitor(ast.NodeVisitor):
4145
def __init__(self):
4246
super().__init__()
43-
self.problems: List[Error] = []
47+
self._problems: List[Error] = []
4448
self.suppress_errors = False
4549

4650
@classmethod
4751
def run(cls, tree: ast.AST) -> Iterable[Error]:
4852
visitor = cls()
4953
visitor.visit(tree)
50-
yield from visitor.problems
54+
yield from visitor._problems
5155

5256
def visit_nodes(
5357
self, *nodes: Union[ast.AST, Iterable[ast.AST]], generic: bool = False
@@ -65,12 +69,12 @@ def visit_nodes(
6569

6670
def error(self, error: str, lineno: int, col: int, *args: Any, **kwargs: Any):
6771
if not self.suppress_errors:
68-
self.problems.append(make_error(error, lineno, col, *args, **kwargs))
72+
self._problems.append(make_error(error, lineno, col, *args, **kwargs))
6973

7074
def get_state(self, *attrs: str) -> Dict[str, Any]:
7175
if not attrs:
7276
attrs = tuple(self.__dict__.keys())
73-
return {attr: getattr(self, attr) for attr in attrs if attr != "problems"}
77+
return {attr: getattr(self, attr) for attr in attrs if attr != "_problems"}
7478

7579
def set_state(self, attrs: Dict[str, Any]):
7680
for attr, value in attrs.items():
@@ -158,8 +162,9 @@ def visit_AsyncWith(self, node: ast.AsyncWith):
158162
self.visit_With(node)
159163

160164
def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]):
161-
outer = self.get_state("_safe_decorator", "_yield_is_error")
165+
outer = self.get_state()
162166
self._yield_is_error = False
167+
self._inside_loop = False
163168

164169
# check for @<context_manager_name> and @<library>.<context_manager_name>
165170
if has_decorator(node.decorator_list, *context_manager_names):
@@ -170,11 +175,12 @@ def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]):
170175
self.set_state(outer)
171176

172177
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
178+
self.check_109(node.args)
173179
self.visit_FunctionDef(node)
174180

175181
def visit_Yield(self, node: ast.Yield):
176182
if self._yield_is_error:
177-
self.problems.append(make_error(TRIO101, node.lineno, node.col_offset))
183+
self.error(TRIO101, node.lineno, node.col_offset)
178184

179185
self.generic_visit(node)
180186

@@ -186,19 +192,35 @@ def check_for_trio100(self, node: Union[ast.With, ast.AsyncWith]):
186192
isinstance(x, checkpoint_node_types) and x != node
187193
for x in ast.walk(node)
188194
):
189-
self.problems.append(
190-
make_error(TRIO100, item.lineno, item.col_offset, call)
191-
)
195+
self.error(TRIO100, item.lineno, item.col_offset, call)
192196

193197
def visit_ImportFrom(self, node: ast.ImportFrom):
194198
if node.module == "trio":
195-
self.problems.append(make_error(TRIO106, node.lineno, node.col_offset))
199+
self.error(TRIO106, node.lineno, node.col_offset)
196200
self.generic_visit(node)
197201

198202
def visit_Import(self, node: ast.Import):
199203
for name in node.names:
200204
if name.name == "trio" and name.asname is not None:
201-
self.problems.append(make_error(TRIO106, node.lineno, node.col_offset))
205+
self.error(TRIO106, node.lineno, node.col_offset)
206+
207+
def check_109(self, args: ast.arguments):
208+
for arg in (*args.posonlyargs, *args.args, *args.kwonlyargs):
209+
if arg.arg == "timeout":
210+
self.error(TRIO109, arg.lineno, arg.col_offset)
211+
212+
def visit_While(self, node: ast.While):
213+
self.check_for_110(node)
214+
self.generic_visit(node)
215+
216+
def check_for_110(self, node: ast.While):
217+
if (
218+
len(node.body) == 1
219+
and isinstance(node.body[0], ast.Expr)
220+
and isinstance(node.body[0].value, ast.Await)
221+
and get_trio_scope(node.body[0].value.value, "sleep", "sleep_until")
222+
):
223+
self.error(TRIO110, node.lineno, node.col_offset)
202224

203225

204226
def critical_except(node: ast.ExceptHandler) -> Optional[Tuple[int, int, str]]:
@@ -252,9 +274,7 @@ def visit_Await(
252274
cm.has_timeout and cm.shielded for cm in self._trio_context_managers
253275
)
254276
):
255-
self.problems.append(
256-
make_error(TRIO102, node.lineno, node.col_offset, *self._critical_scope)
257-
)
277+
self.error(TRIO102, node.lineno, node.col_offset, *self._critical_scope)
258278
if visit_children:
259279
self.generic_visit(node)
260280

@@ -288,14 +308,15 @@ def visit_AsyncWith(self, node: ast.AsyncWith):
288308
self.visit_With(node)
289309

290310
def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]):
291-
outer_cm = self._safe_decorator
311+
outer = self.get_state("_safe_decorator")
292312

293313
# check for @<context_manager_name> and @<library>.<context_manager_name>
294314
if has_decorator(node.decorator_list, *context_manager_names):
295315
self._safe_decorator = True
296316

297317
self.generic_visit(node)
298-
self._safe_decorator = outer_cm
318+
319+
self.set_state(outer)
299320

300321
visit_AsyncFunctionDef = visit_FunctionDef
301322

@@ -305,13 +326,13 @@ def critical_visit(
305326
block: Tuple[int, int, str],
306327
generic: bool = False,
307328
):
308-
outer = self._critical_scope, self._trio_context_managers
329+
outer = self.get_state("_critical_scope", "_trio_context_managers")
309330

310331
self._trio_context_managers = []
311332
self._critical_scope = block
312333

313334
self.visit_nodes(node, generic=generic)
314-
self._critical_scope, self._trio_context_managers = outer
335+
self.set_state(outer)
315336

316337
def visit_Try(self, node: ast.Try):
317338
# There's no visit_Finally, so we need to manually visit the Try fields.
@@ -358,7 +379,7 @@ def __init__(self):
358379
# then there might be a code path that doesn't re-raise.
359380
def visit_ExceptHandler(self, node: ast.ExceptHandler):
360381

361-
outer = self.get_state("unraised", "except_name", "loop_depth")
382+
outer = self.get_state()
362383
marker = critical_except(node)
363384

364385
# we need to *not* unset self.unraised if this is non-critical, to still
@@ -375,7 +396,7 @@ def visit_ExceptHandler(self, node: ast.ExceptHandler):
375396
self.generic_visit(node)
376397

377398
if self.unraised and marker is not None:
378-
self.problems.append(make_error(TRIO103, *marker))
399+
self.error(TRIO103, *marker)
379400

380401
self.set_state(outer)
381402

@@ -387,7 +408,7 @@ def visit_Raise(self, node: ast.Raise):
387408
and node.exc is not None
388409
and not (isinstance(node.exc, ast.Name) and node.exc.id == self.except_name)
389410
):
390-
self.problems.append(make_error(TRIO104, node.lineno, node.col_offset))
411+
self.error(TRIO104, node.lineno, node.col_offset)
391412

392413
# treat it as safe regardless, to avoid unnecessary error messages.
393414
self.unraised = False
@@ -397,7 +418,7 @@ def visit_Raise(self, node: ast.Raise):
397418
def visit_Return(self, node: Union[ast.Return, ast.Yield]):
398419
if self.unraised:
399420
# Error: must re-raise
400-
self.problems.append(make_error(TRIO104, node.lineno, node.col_offset))
421+
self.error(TRIO104, node.lineno, node.col_offset)
401422
self.generic_visit(node)
402423

403424
visit_Yield = visit_Return
@@ -446,7 +467,7 @@ def visit_If(self, node: ast.If):
446467
# we completely disregard them when checking coverage by resetting the
447468
# effects of them afterwards
448469
def visit_For(self, node: Union[ast.For, ast.While]):
449-
outer_unraised = self.unraised
470+
outer = self.get_state("unraised")
450471

451472
self.loop_depth += 1
452473
for n in node.body:
@@ -455,13 +476,13 @@ def visit_For(self, node: Union[ast.For, ast.While]):
455476
for n in node.orelse:
456477
self.visit(n)
457478

458-
self.unraised = outer_unraised
479+
self.set_state(outer)
459480

460481
visit_While = visit_For
461482

462483
def visit_Break(self, node: Union[ast.Break, ast.Continue]):
463484
if self.unraised and self.loop_depth == 0:
464-
self.problems.append(make_error(TRIO104, node.lineno, node.col_offset))
485+
self.error(TRIO104, node.lineno, node.col_offset)
465486
self.generic_visit(node)
466487

467488
visit_Continue = visit_Break
@@ -506,9 +527,7 @@ def visit_Call(self, node: ast.Call):
506527
or not isinstance(self.node_stack[-2], ast.Await)
507528
)
508529
):
509-
self.problems.append(
510-
make_error(TRIO105, node.lineno, node.col_offset, node.func.attr)
511-
)
530+
self.error(TRIO105, node.lineno, node.col_offset, node.func.attr)
512531
self.generic_visit(node)
513532

514533

@@ -779,3 +798,7 @@ def run(self) -> Iterable[Error]:
779798
TRIO108 = (
780799
"TRIO108: {} from async function with no guaranteed checkpoint since {} on line {}"
781800
)
801+
#TRIO107 = "TRIO107: Async functions must have at least one checkpoint on every code path, unless an exception is raised"
802+
#TRIO108 = "TRIO108: Early return from async function must have at least one checkpoint on every code path before it."
803+
TRIO109 = "TRIO109: Async function definition with a `timeout` parameter - use `trio.[fail/move_on]_[after/at]` instead"
804+
TRIO110 = "TRIO110: `while <condition>: await trio.sleep()` should be replaced by a `trio.Event`."

tests/test_flake8_trio.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ def test_eval(test: str, path: str):
4646
expected: List[Error] = []
4747
with open(os.path.join("tests", path)) as file:
4848
for lineno, line in enumerate(file, start=1):
49-
# get text between `error: ` and newline
50-
k = re.search(r"(?<=error: ).*(?=\n)", line)
49+
# get text between `error:` and end of line
50+
k = re.search(r"(?<=error:).*$", line)
5151
if not k or line.strip()[0] == "#":
5252
continue
5353
# Append a bunch of empty strings so string formatting gives garbage instead
@@ -60,6 +60,7 @@ def test_eval(test: str, path: str):
6060
assert col.isdigit(), f'invalid column "{col}" @L{lineno}, in "{line}"'
6161
expected.append(make_error(error_msg, lineno, int(col), *args))
6262

63+
assert expected, "failed to parse any errors in file"
6364
assert_expected_errors(path, test, *expected)
6465

6566

tests/trio109.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
timeout = 10
2+
3+
4+
async def foo():
5+
...
6+
7+
8+
# args
9+
async def foo_1(timeout): # error: 16
10+
...
11+
12+
13+
# arg in args wih default & annotation
14+
async def foo_2(timeout: int = 3): # error: 16
15+
...
16+
17+
18+
# vararg
19+
async def foo_3(*timeout): # ignored
20+
...
21+
22+
23+
# kwarg
24+
async def foo_4(**timeout): # ignored
25+
...
26+
27+
28+
# correct line/col
29+
async def foo_5(
30+
bar,
31+
timeouts,
32+
my_timeout,
33+
timeout_,
34+
timeout, # error: 4
35+
):
36+
...
37+
38+
39+
# posonlyargs
40+
async def foo_6(
41+
timeout, # error: 4
42+
/,
43+
bar,
44+
):
45+
...
46+
47+
48+
# kwonlyargs
49+
async def foo_7(
50+
*,
51+
timeout, # error: 4
52+
):
53+
...
54+
55+
56+
# kwonlyargs (and kw_defaults)
57+
async def foo_8(
58+
*,
59+
timeout=5, # error: 4
60+
):
61+
...
62+
63+
64+
async def foo_9(k=timeout):
65+
...
66+
67+
68+
# normal functions are not checked
69+
def foo_10(timeout):
70+
...
71+
72+
73+
def foo_11(timeout, /):
74+
...
75+
76+
77+
def foo_12(*, timeout):
78+
...

0 commit comments

Comments
 (0)