Skip to content

Commit ff2d0b3

Browse files
authored
Merge pull request #17 from jakkdl/7_async_iterable_checkpoints
2 parents dc2cc8a + 71a66b4 commit ff2d0b3

File tree

5 files changed

+1002
-152
lines changed

5 files changed

+1002
-152
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
# Changelog
22
*[CalVer, YY.month.patch](https://calver.org/)*
33

4+
## 22.8.2
5+
- Merged TRIO108 into TRIO107
6+
- TRIO108 now handles checkpointing in async iterators
7+
48
## 22.8.1
59
- Added TRIO109: Async definitions should not have a `timeout` parameter. Use `trio.[fail/move_on]_[at/after]`
610
- Added TRIO110: `while <condition>: await trio.sleep()` should be replaced by a `trio.Event`.

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ pip install flake8-trio
2828
- **TRIO104**: `Cancelled` and `BaseException` must be re-raised - when a user tries to `return` or `raise` a different exception.
2929
- **TRIO105**: Calling a trio async function without immediately `await`ing it.
3030
- **TRIO106**: trio must be imported with `import trio` for the linter to work.
31-
- **TRIO107**: Async functions must have at least one checkpoint on every code path, unless an exception is raised.
32-
- **TRIO108**: Early return from async function must have at least one checkpoint on every code path before it, unless an exception is raised.
33-
Checkpoints are `await`, `async with` `async for`.
31+
- **TRIO107**: exit or `return` from async function with no guaranteed checkpoint or exception since function definition.
32+
- **TRIO108**: exit, `yield` or `return` from async iterable with no guaranteed checkpoint since possible function entry (`yield` or function definition).
33+
Checkpoints are `await`, `async for`, and `async with` (on one of enter/exit).
3434
- **TRIO109**: Async function definition with a `timeout` parameter - use `trio.[fail/move_on]_[after/at]` instead
3535
- **TRIO110**: `while <condition>: await trio.sleep()` should be replaced by a `trio.Event`.

flake8_trio.py

Lines changed: 211 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Type, Union
1515

1616
# CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1"
17-
__version__ = "22.8.1"
17+
__version__ = "22.8.2"
1818

1919

2020
class Statement(NamedTuple):
@@ -35,8 +35,8 @@ class Statement(NamedTuple):
3535
"TRIO104": "Cancelled (and therefore BaseException) must be re-raised",
3636
"TRIO105": "trio async function {} must be immediately awaited",
3737
"TRIO106": "trio must be imported with `import trio` for the linter to work",
38-
"TRIO107": "Async functions must have at least one checkpoint on every code path, unless an exception is raised",
39-
"TRIO108": "Early return from async function must have at least one checkpoint on every code path before it.",
38+
"TRIO107": "{0} from async function with no guaranteed checkpoint or exception since function definition on line {1.lineno}",
39+
"TRIO108": "{0} from async iterable with no guaranteed checkpoint since {1.name} on line {1.lineno}",
4040
"TRIO109": "Async function definition with a `timeout` parameter - use `trio.[fail/move_on]_[after/at]` instead",
4141
"TRIO110": "`while <condition>: await trio.sleep()` should be replaced by a `trio.Event`.",
4242
}
@@ -68,6 +68,7 @@ class Flake8TrioVisitor(ast.NodeVisitor):
6868
def __init__(self):
6969
super().__init__()
7070
self._problems: List[Error] = []
71+
self.suppress_errors = False
7172

7273
@classmethod
7374
def run(cls, tree: ast.AST) -> Iterable[Error]:
@@ -90,9 +91,10 @@ def visit_nodes(
9091
visit(node)
9192

9293
def error(self, error: str, node: HasLineInfo, *args: Any, **kwargs: Any):
93-
self._problems.append(
94-
make_error(error, node.lineno, node.col_offset, *args, **kwargs)
95-
)
94+
if not self.suppress_errors:
95+
self._problems.append(
96+
make_error(error, node.lineno, node.col_offset, *args, **kwargs)
97+
)
9698

9799
def get_state(self, *attrs: str) -> Dict[str, Any]:
98100
if not attrs:
@@ -103,6 +105,10 @@ def set_state(self, attrs: Dict[str, Any]):
103105
for attr, value in attrs.items():
104106
setattr(self, attr, value)
105107

108+
def walk(self, *body: ast.AST) -> Iterable[ast.AST]:
109+
for b in body:
110+
yield from ast.walk(b)
111+
106112

107113
class TrioScope:
108114
def __init__(self, node: ast.Call, funcname: str, packagename: str):
@@ -561,105 +567,251 @@ def visit_Call(self, node: ast.Call):
561567
class Visitor107_108(Flake8TrioVisitor):
562568
def __init__(self):
563569
super().__init__()
564-
self.all_await = True
570+
self.yield_count = 0
571+
572+
self.always_checkpoint: Optional[Statement] = None
573+
self.checkpoint_continue: Optional[Statement] = None
574+
self.checkpoint_break: Optional[Statement] = None
575+
576+
self.default = self.get_state()
565577

566578
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
567-
outer = self.all_await
579+
if has_decorator(node.decorator_list, "overload"):
580+
return
581+
582+
outer = self.get_state()
583+
self.set_state(self.default)
584+
585+
self.always_checkpoint = Statement("function definition", node.lineno)
568586

569-
# do not require checkpointing if overloading
570-
self.all_await = has_decorator(node.decorator_list, "overload")
571587
self.generic_visit(node)
588+
self.check_function_exit(node)
572589

573-
if not self.all_await:
574-
self.error("TRIO107", node)
590+
self.set_state(outer)
575591

576-
self.all_await = outer
592+
def check_function_exit(self, node: Union[ast.Return, ast.AsyncFunctionDef]):
593+
# error if function exits w/o guaranteed checkpoint since function entry
594+
method = "return" if isinstance(node, ast.Return) else "exit"
595+
596+
if self.always_checkpoint is not None:
597+
if self.yield_count:
598+
self.error("TRIO108", node, method, self.always_checkpoint)
599+
else:
600+
self.error("TRIO107", node, method, self.always_checkpoint)
577601

578602
def visit_Return(self, node: ast.Return):
579603
self.generic_visit(node)
580-
if not self.all_await:
581-
self.error("TRIO108", node)
604+
self.check_function_exit(node)
605+
582606
# avoid duplicate error messages
583-
self.all_await = True
607+
self.always_checkpoint = None
584608

585-
# disregard raise's in nested functions
609+
# disregard checkpoints in nested function definitions
586610
def visit_FunctionDef(self, node: ast.FunctionDef):
587-
outer = self.all_await
611+
outer = self.get_state()
612+
self.set_state(self.default)
588613
self.generic_visit(node)
589-
self.all_await = outer
614+
self.set_state(outer)
590615

591616
# checkpoint functions
592-
def visit_Await(
593-
self, node: Union[ast.Await, ast.AsyncFor, ast.AsyncWith, ast.Raise]
594-
):
617+
def visit_Await(self, node: Union[ast.Await, ast.Raise]):
618+
# the expression being awaited is not checkpointed
619+
# so only set checkpoint after the await node
595620
self.generic_visit(node)
596-
self.all_await = True
597-
598-
visit_AsyncFor = visit_Await
599-
visit_AsyncWith = visit_Await
621+
self.always_checkpoint = None
600622

601623
# raising exception means we don't need to checkpoint so we can treat it as one
602624
visit_Raise = visit_Await
603625

604-
# valid checkpoint if there's valid checkpoints (or raise) in at least one of:
605-
# (try or else) and all excepts
606-
# finally
626+
# guaranteed to checkpoint on at least one of enter and exit
627+
# if it checkpoints on entry and there's a yield in it, we can't treat it as checkpoint
628+
# but it may not checkpoint on entry, so yields inside need to raise problem
629+
def visit_AsyncWith(self, node: ast.AsyncWith):
630+
self.visit_nodes(node.items)
631+
prebody_yield_count = self.yield_count
632+
633+
# there's no guarantee of checkpoint before entry
634+
self.visit_nodes(node.body)
635+
636+
# no yield in body, treat as checkpoint
637+
if prebody_yield_count == self.yield_count:
638+
self.always_checkpoint = None
639+
640+
# error if no checkpoint since earlier yield or function entry
641+
def visit_Yield(self, node: ast.Yield):
642+
self.generic_visit(node)
643+
self.yield_count += 1
644+
if self.always_checkpoint is not None:
645+
self.error("TRIO108", node, "yield", self.always_checkpoint)
646+
647+
# mark as requiring checkpoint after
648+
self.always_checkpoint = Statement("yield", node.lineno)
649+
650+
# valid checkpoint if there's valid checkpoints (or raise) in:
651+
# (try or else) and all excepts, or in finally
652+
#
653+
# try can jump into any except or into the finally* at any point during its
654+
# execution so we need to make sure except & finally can handle worst-case
655+
# * unless there's a bare except / except BaseException - not implemented.
607656
def visit_Try(self, node: ast.Try):
608-
if self.all_await:
609-
self.generic_visit(node)
610-
return
657+
# except & finally guaranteed to enter with checkpoint if checkpointed
658+
# before try and no yield in try body.
659+
body_always_checkpoint = self.always_checkpoint
660+
for inner_node in self.walk(*node.body):
661+
if isinstance(inner_node, ast.Yield):
662+
body_always_checkpoint = Statement("yield", inner_node.lineno)
663+
break
611664

612665
# check try body
613666
self.visit_nodes(node.body)
614-
body_await = self.all_await
615-
self.all_await = False
667+
668+
# save state at end of try for entering else
669+
try_checkpoint = self.always_checkpoint
616670

617671
# check that all except handlers checkpoint (await or most likely raise)
618-
all_except_await = True
672+
all_except_checkpoint: Optional[Statement] = None
619673
for handler in node.handlers:
674+
# enter with worst case of try
675+
self.always_checkpoint = body_always_checkpoint
676+
620677
self.visit_nodes(handler)
621-
all_except_await &= self.all_await
622-
self.all_await = False
678+
679+
if self.always_checkpoint is not None:
680+
all_except_checkpoint = self.always_checkpoint
623681

624682
# check else
683+
# if else runs it's after all of try, so restore state to back then
684+
self.always_checkpoint = try_checkpoint
625685
self.visit_nodes(node.orelse)
626686

627-
# (try or else) and all excepts
628-
self.all_await = (body_await or self.all_await) and all_except_await
687+
# checkpoint if else checkpoints, and all excepts checkpoint
688+
if all_except_checkpoint is not None:
689+
self.always_checkpoint = all_except_checkpoint
629690

630-
# finally can check on it's own
631-
self.visit_nodes(node.finalbody)
691+
# if there's no finally, don't restore state from try
692+
if node.finalbody:
693+
# can enter from try, else, or any except
694+
if body_always_checkpoint is not None:
695+
self.always_checkpoint = body_always_checkpoint
696+
self.visit_nodes(node.finalbody)
632697

633-
# valid checkpoint if both body and orelse have checkpoints
698+
# valid checkpoint if both body and orelse checkpoint
634699
def visit_If(self, node: Union[ast.If, ast.IfExp]):
635-
if self.all_await:
636-
self.generic_visit(node)
637-
return
638-
639-
# ignore checkpoints in condition
700+
# visit condition
640701
self.visit_nodes(node.test)
641-
self.all_await = False
702+
outer = self.get_state("always_checkpoint")
642703

643-
# check body
704+
# visit body
644705
self.visit_nodes(node.body)
645-
body_await = self.all_await
646-
self.all_await = False
706+
body_outer = self.get_state("always_checkpoint")
647707

708+
# reset to after condition and visit orelse
709+
self.set_state(outer)
648710
self.visit_nodes(node.orelse)
649711

650-
# checkpoint if both body and else
651-
self.all_await = body_await and self.all_await
712+
# if body failed, reset to that state
713+
if body_outer["always_checkpoint"] is not None:
714+
self.set_state(body_outer)
715+
716+
# otherwise keep state (fail or not) as it was after orelse
652717

653718
# inline if
654719
visit_IfExp = visit_If
655720

656-
# ignore checkpoints in loops due to continue/break shenanigans
657-
def visit_While(self, node: Union[ast.While, ast.For]):
658-
outer = self.all_await
659-
self.generic_visit(node)
660-
self.all_await = outer
721+
# Check for yields w/o checkpoint in between due to entering loop body the first time,
722+
# after completing all of loop body, and after any continues.
723+
# yield in else have same requirement
724+
# state after the loop same as above, and in addition the state at any break
725+
def visit_loop(self, node: Union[ast.While, ast.For, ast.AsyncFor]):
726+
# save state in case of nested loops
727+
outer = self.get_state(
728+
"checkpoint_continue", "checkpoint_break", "suppress_errors"
729+
)
730+
731+
# visit condition
732+
if isinstance(node, ast.While):
733+
self.visit_nodes(node.test)
734+
else:
735+
self.visit_nodes(node.target)
736+
self.visit_nodes(node.iter)
737+
738+
self.checkpoint_continue = None
739+
pre_body_always_checkpoint = self.always_checkpoint
740+
741+
# AsyncFor guaranteed checkpoint at every iteration
742+
if isinstance(node, ast.AsyncFor):
743+
pre_body_always_checkpoint = None
744+
self.always_checkpoint = None
745+
746+
# if we normally enter loop with checkpoint, check for worst-case start of loop
747+
# due to `continue` or multiple iterations
748+
elif self.always_checkpoint is None:
749+
# silently check if body unsets yield
750+
# so we later can check if body errors out on worst case of entering
751+
self.suppress_errors = True
752+
753+
# self.checkpoint_continue is set to the pending Statement if loop body ever does
754+
# continue while self.always_checkpoint is not None
755+
self.visit_nodes(node.body)
756+
757+
self.suppress_errors = outer["suppress_errors"]
758+
759+
if self.checkpoint_continue is not None:
760+
self.always_checkpoint = self.checkpoint_continue
761+
762+
self.checkpoint_break = None
763+
self.visit_nodes(node.body)
764+
765+
# AsyncFor guarantees checkpoint on running out of iterable
766+
# so reset checkpoint state at end of loop. (but not state at break)
767+
if isinstance(node, ast.AsyncFor):
768+
self.always_checkpoint = None
769+
else:
770+
# enter orelse with worst case:
771+
# loop body might execute fully before entering orelse
772+
# (current state of self.always_checkpoint)
773+
# or not at all
774+
if pre_body_always_checkpoint is not None:
775+
self.always_checkpoint = pre_body_always_checkpoint
776+
# or at a continue
777+
elif self.checkpoint_continue is not None:
778+
self.always_checkpoint = self.checkpoint_continue
779+
780+
# visit orelse
781+
self.visit_nodes(node.orelse)
782+
783+
# We may exit from:
784+
# orelse (which covers no body, body until continue, and all body)
785+
# break
786+
if self.checkpoint_break is not None:
787+
self.always_checkpoint = self.checkpoint_break
788+
789+
# reset state in case of nested loops
790+
self.set_state(outer)
661791

662-
visit_For = visit_While
792+
visit_While = visit_loop
793+
visit_For = visit_loop
794+
visit_AsyncFor = visit_loop
795+
796+
# save state in case of continue/break at a point not guaranteed to checkpoint
797+
def visit_Continue(self, node: ast.Continue):
798+
if self.always_checkpoint is not None:
799+
self.checkpoint_continue = self.always_checkpoint
800+
801+
def visit_Break(self, node: ast.Break):
802+
if self.always_checkpoint is not None:
803+
self.checkpoint_break = self.always_checkpoint
804+
805+
# first node in a condition is guaranteed to run, but it may short-circuit so checkpoints
806+
# in remaining nodes are not guaranteed
807+
# Not fully implemented: worst-case short-circuit with yields in condition
808+
def visit_BoolOp(self, node: ast.BoolOp):
809+
self.visit(node.op)
810+
self.visit_nodes(node.values[:1])
811+
outer = self.always_checkpoint
812+
self.visit_nodes(node.values[1:])
813+
814+
self.always_checkpoint = outer
663815

664816

665817
class Plugin:

0 commit comments

Comments
 (0)