Skip to content

Commit 7acf4d6

Browse files
committed
Added details to TRIO108, which also made me clean up the logic a bit.
1 parent 6cde589 commit 7acf4d6

File tree

2 files changed

+196
-91
lines changed

2 files changed

+196
-91
lines changed

flake8_trio.py

Lines changed: 82 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -507,21 +507,25 @@ def visit_Call(self, node: ast.Call):
507507
self.generic_visit(node)
508508

509509

510+
Checkpoint_Type = Optional[Tuple[str, int]]
511+
512+
510513
class Visitor107_108(Flake8TrioVisitor):
511514
def __init__(self):
512515
super().__init__()
513-
self.always_checkpoint = True
514516
self.yield_count = 0
515-
self.checkpoint_continue = True
516-
self.checkpoint_break: Optional[bool] = None
517+
518+
self.always_checkpoint: Checkpoint_Type = None
519+
self.checkpoint_continue: Checkpoint_Type = None
520+
self.checkpoint_break: Checkpoint_Type = None
517521

518522
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
519523
if has_decorator(node.decorator_list, "overload"):
520524
return
521525

522526
outer = self.get_state()
523527

524-
self.always_checkpoint = False
528+
self.always_checkpoint = ("function definition", node.lineno)
525529
self.safe_decorator = has_decorator(node.decorator_list, *context_manager_names)
526530

527531
self.generic_visit(node)
@@ -531,20 +535,36 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
531535
# we ignore contextmanager decorators with yields since the decorator may
532536
# checkpoint on entry, which the yield then unsets. (and if there's no explicit
533537
# checkpoint before yield, it will error)
534-
if not self.always_checkpoint and self.yield_count:
535-
self.error(TRIO107, node.lineno, node.col_offset, "iterable")
536-
elif not self.always_checkpoint and not self.safe_decorator:
537-
self.error(TRIO107, node.lineno, node.col_offset, "function")
538+
if self.always_checkpoint is not None and self.yield_count:
539+
self.error(
540+
TRIO107,
541+
node.lineno,
542+
node.col_offset,
543+
"iterable",
544+
*self.always_checkpoint,
545+
)
546+
elif self.always_checkpoint is not None and not self.safe_decorator:
547+
self.error(
548+
TRIO107,
549+
node.lineno,
550+
node.col_offset,
551+
"function",
552+
*self.always_checkpoint,
553+
)
538554

539555
self.set_state(outer)
540556

541557
def visit_Return(self, node: ast.Return):
542558
self.generic_visit(node)
543-
if not self.always_checkpoint and (not self.safe_decorator or self.yield_count):
544-
self.error(TRIO108, node.lineno, node.col_offset, "return")
559+
if self.always_checkpoint is not None and (
560+
not self.safe_decorator or self.yield_count
561+
):
562+
self.error(
563+
TRIO108, node.lineno, node.col_offset, "return", *self.always_checkpoint
564+
)
545565

546566
# avoid duplicate error messages
547-
self.always_checkpoint = True
567+
self.always_checkpoint = None
548568

549569
# disregard checkpoints in nested function definitions
550570
def visit_FunctionDef(self, node: ast.FunctionDef):
@@ -557,7 +577,7 @@ def visit_Await(self, node: Union[ast.Await, ast.Raise]):
557577
# the expression being awaited is not checkpointed
558578
# so only set checkpoint after the await node
559579
self.generic_visit(node)
560-
self.always_checkpoint = True
580+
self.always_checkpoint = None
561581

562582
# raising exception means we don't need to checkpoint so we can treat it as one
563583
visit_Raise = visit_Await
@@ -569,71 +589,81 @@ def visit_AsyncWith(self, node: ast.AsyncWith):
569589
self.visit_nodes(node.body)
570590
# if there was no yield in body that may unset a checkpoint on exit,
571591
# treat this as a checkpoint
572-
self.always_checkpoint |= prebody_yield_count == self.yield_count
592+
if prebody_yield_count == self.yield_count:
593+
self.always_checkpoint = None
573594

574595
def visit_Yield(self, node: ast.Yield):
575596
self.generic_visit(node)
576597
self.yield_count += 1
577-
if not self.always_checkpoint:
578-
self.error(TRIO108, node.lineno, node.col_offset, "yield")
579-
self.always_checkpoint = False
598+
if self.always_checkpoint is not None:
599+
self.error(
600+
TRIO108, node.lineno, node.col_offset, "yield", *self.always_checkpoint
601+
)
602+
self.always_checkpoint = ("yield", node.lineno)
580603

581604
# valid checkpoint if there's valid checkpoints (or raise) in:
582605
# (try or else) and all excepts, or in finally
583606

584607
# try can jump into any except or into the finally at any point during it's execution
585608
# so we need to make sure except & finally can handle a worst-case exception
586609
def visit_Try(self, node: ast.Try):
587-
outer = self.get_state("always_checkpoint")
610+
# check worst case try exception
611+
body_always_checkpoint = self.always_checkpoint
612+
for n in (b for body in node.body for b in ast.walk(body)):
613+
if isinstance(n, ast.Yield):
614+
body_always_checkpoint = ("yield", n.lineno)
615+
break
588616

589617
# check try body
590618
self.visit_nodes(node.body)
591619
try_checkpoint = self.always_checkpoint
592620

593-
# checkpoint before entering try body, and no yield in it
594-
body_always_checkpoint = outer["always_checkpoint"] and not any(
595-
isinstance(n, ast.Yield) for body in node.body for n in ast.walk(body)
596-
)
597-
598621
# check that all except handlers checkpoint (await or most likely raise)
599-
all_except_checkpoint = True
622+
all_except_checkpoint: Checkpoint_Type = None
600623
for handler in node.handlers:
601624
# if there's any `yield`s in try body, exception might be thrown there
602625
self.always_checkpoint = body_always_checkpoint
603626

604627
self.visit_nodes(handler)
605-
all_except_checkpoint &= self.always_checkpoint
628+
if self.always_checkpoint is not None:
629+
all_except_checkpoint = self.always_checkpoint
630+
break
606631

607632
# check else
608633
# if else runs it's after all of try, so restore state to back then
609634
self.always_checkpoint = try_checkpoint
610635
self.visit_nodes(node.orelse)
611636

612637
# checkpoint if else checkpoints, and all excepts
613-
self.always_checkpoint &= all_except_checkpoint
638+
if all_except_checkpoint is not None:
639+
self.always_checkpoint = all_except_checkpoint
614640

615641
if node.finalbody:
616642
# if there's a finally, it can get jumped to at the worst time
617643
# from the try
618-
self.always_checkpoint &= body_always_checkpoint
644+
if body_always_checkpoint is not None:
645+
self.always_checkpoint = body_always_checkpoint
619646
self.visit_nodes(node.finalbody)
620647

621648
# valid checkpoint if both body and orelse have checkpoints
622649
def visit_If(self, node: Union[ast.If, ast.IfExp]):
623650
# visit condition
624651
self.visit_nodes(node.test)
625-
cond_yield = self.always_checkpoint
652+
outer = self.get_state("always_checkpoint")
626653

627654
# visit body
628655
self.visit_nodes(node.body)
629-
body_yield = self.always_checkpoint
656+
body_outer = self.get_state("always_checkpoint")
630657

631658
# reset to after condition and visit orelse
632-
self.always_checkpoint = cond_yield
659+
self.set_state(outer)
633660
self.visit_nodes(node.orelse)
634661

635-
# checkpoint if both body and else checkpoint
636-
self.always_checkpoint &= body_yield
662+
# if body failed, reset to that state
663+
if body_outer["always_checkpoint"] is not None:
664+
self.set_state(body_outer)
665+
666+
# otherwise keep state (fail or not) as it was after orelse
637667

638668
# inline if
639669
visit_IfExp = visit_If
@@ -649,52 +679,48 @@ def visit_loop(self, node: Union[ast.While, ast.For, ast.AsyncFor]):
649679
self.visit_nodes(node.target)
650680
self.visit_nodes(node.iter)
651681

652-
self.checkpoint_continue = True
682+
self.checkpoint_continue = None
683+
pre_body_always_checkpoint = self.always_checkpoint
653684
# Async for always enters and exit loop body with checkpoint (regardless of continue)
654685
if isinstance(node, ast.AsyncFor):
655-
pre_body_always_checkpoint = True
656-
self.always_checkpoint = True
686+
pre_body_always_checkpoint = None
687+
self.always_checkpoint = None
657688

658689
# check for worst-case start of loop due to `continue` or multiple iterations
659-
else:
660-
pre_body_always_checkpoint = self.always_checkpoint
661-
690+
elif self.always_checkpoint is None:
662691
# silently check if body unsets yield
663692
# so we later can check if body errors out on worst case of entering
664693
self.suppress_errors = True
665-
self.always_checkpoint = True
666694

667695
# self.checkpoint_continue is set to False if loop body ever does
668696
# continue with self.always_checkpoint == False
669697
self.visit_nodes(node.body)
670698

671699
self.suppress_errors = outer["suppress_errors"]
672-
# enter with checkpoint only if all ways of entering are checkpointed
673-
# (first iter, continue, 2nd+ iter)
674-
self.always_checkpoint &= (
675-
pre_body_always_checkpoint and self.checkpoint_continue
676-
)
700+
701+
if self.checkpoint_continue is not None:
702+
self.always_checkpoint = self.checkpoint_continue
677703

678704
self.checkpoint_break = None
679705
self.visit_nodes(node.body)
680706

681707
if isinstance(node, ast.AsyncFor):
682-
self.always_checkpoint = True
708+
self.always_checkpoint = None
683709
else:
684710
# enter orelse with worst case: loop body might execute fully before
685711
# entering orelse, or not at all, or at a continue
686-
self.always_checkpoint &= (
687-
pre_body_always_checkpoint and self.checkpoint_continue
688-
)
712+
if pre_body_always_checkpoint is not None:
713+
self.always_checkpoint = pre_body_always_checkpoint
714+
elif self.checkpoint_continue is not None:
715+
self.always_checkpoint = self.checkpoint_continue
689716

690717
self.visit_nodes(node.orelse)
691718

692719
# We may exit from:
693720
# orelse (which covers no body, body until continue, and all body)
694721
# break
695-
self.always_checkpoint = (
696-
self.always_checkpoint and self.checkpoint_break is not False
697-
)
722+
if self.checkpoint_break is not None:
723+
self.always_checkpoint = self.checkpoint_break
698724

699725
self.set_state(outer)
700726

@@ -703,10 +729,11 @@ def visit_loop(self, node: Union[ast.While, ast.For, ast.AsyncFor]):
703729
visit_AsyncFor = visit_loop
704730

705731
def visit_Continue(self, node: ast.Continue):
706-
self.checkpoint_continue &= self.always_checkpoint
732+
if self.always_checkpoint is not None:
733+
self.checkpoint_continue = self.always_checkpoint
707734

708735
def visit_Break(self, node: ast.Break):
709-
if not self.checkpoint_break:
736+
if self.always_checkpoint is not None:
710737
self.checkpoint_break = self.always_checkpoint
711738

712739
# first node in boolops can checkpoint, the others might not execute
@@ -744,4 +771,6 @@ def run(self) -> Generator[Tuple[int, int, str, Type[Any]], None, None]:
744771
TRIO105 = "TRIO105: trio async function {} must be immediately awaited"
745772
TRIO106 = "TRIO106: trio must be imported with `import trio` for the linter to work"
746773
TRIO107 = "TRIO107: async {} must have at least one checkpoint on every code path, unless an exception is raised"
747-
TRIO108 = "TRIO108: {} from async function must have at least one checkpoint on every code path before it"
774+
TRIO108 = (
775+
"TRIO108: {} from async function with no guaranteed checkpoint since {} on line {}"
776+
)

0 commit comments

Comments
 (0)