14
14
from typing import Any , Dict , Iterable , List , NamedTuple , Optional , Tuple , Type , Union
15
15
16
16
# CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1"
17
- __version__ = "22.8.1 "
17
+ __version__ = "22.8.2 "
18
18
19
19
20
20
class Statement (NamedTuple ):
@@ -35,8 +35,8 @@ class Statement(NamedTuple):
35
35
"TRIO104" : "Cancelled (and therefore BaseException) must be re-raised" ,
36
36
"TRIO105" : "trio async function {} must be immediately awaited" ,
37
37
"TRIO106" : "trio must be imported with `import trio` for the linter to work" ,
38
- "TRIO107" : "Async functions must have at least one checkpoint on every code path, unless an exception is raised " ,
39
- "TRIO108" : "Early return from async function must have at least one checkpoint on every code path before it. " ,
38
+ "TRIO107" : "{0} from async function with no guaranteed checkpoint or exception since function definition on line {1.lineno} " ,
39
+ "TRIO108" : "{0} from async iterable with no guaranteed checkpoint since {1.name} on line {1.lineno} " ,
40
40
"TRIO109" : "Async function definition with a `timeout` parameter - use `trio.[fail/move_on]_[after/at]` instead" ,
41
41
"TRIO110" : "`while <condition>: await trio.sleep()` should be replaced by a `trio.Event`." ,
42
42
}
@@ -68,6 +68,7 @@ class Flake8TrioVisitor(ast.NodeVisitor):
68
68
def __init__ (self ):
69
69
super ().__init__ ()
70
70
self ._problems : List [Error ] = []
71
+ self .suppress_errors = False
71
72
72
73
@classmethod
73
74
def run (cls , tree : ast .AST ) -> Iterable [Error ]:
@@ -90,9 +91,10 @@ def visit_nodes(
90
91
visit (node )
91
92
92
93
def error (self , error : str , node : HasLineInfo , * args : Any , ** kwargs : Any ):
93
- self ._problems .append (
94
- make_error (error , node .lineno , node .col_offset , * args , ** kwargs )
95
- )
94
+ if not self .suppress_errors :
95
+ self ._problems .append (
96
+ make_error (error , node .lineno , node .col_offset , * args , ** kwargs )
97
+ )
96
98
97
99
def get_state (self , * attrs : str ) -> Dict [str , Any ]:
98
100
if not attrs :
@@ -103,6 +105,10 @@ def set_state(self, attrs: Dict[str, Any]):
103
105
for attr , value in attrs .items ():
104
106
setattr (self , attr , value )
105
107
108
+ def walk (self , * body : ast .AST ) -> Iterable [ast .AST ]:
109
+ for b in body :
110
+ yield from ast .walk (b )
111
+
106
112
107
113
class TrioScope :
108
114
def __init__ (self , node : ast .Call , funcname : str , packagename : str ):
@@ -561,105 +567,251 @@ def visit_Call(self, node: ast.Call):
561
567
class Visitor107_108 (Flake8TrioVisitor ):
562
568
def __init__ (self ):
563
569
super ().__init__ ()
564
- self .all_await = True
570
+ self .yield_count = 0
571
+
572
+ self .always_checkpoint : Optional [Statement ] = None
573
+ self .checkpoint_continue : Optional [Statement ] = None
574
+ self .checkpoint_break : Optional [Statement ] = None
575
+
576
+ self .default = self .get_state ()
565
577
566
578
def visit_AsyncFunctionDef (self , node : ast .AsyncFunctionDef ):
567
- outer = self .all_await
579
+ if has_decorator (node .decorator_list , "overload" ):
580
+ return
581
+
582
+ outer = self .get_state ()
583
+ self .set_state (self .default )
584
+
585
+ self .always_checkpoint = Statement ("function definition" , node .lineno )
568
586
569
- # do not require checkpointing if overloading
570
- self .all_await = has_decorator (node .decorator_list , "overload" )
571
587
self .generic_visit (node )
588
+ self .check_function_exit (node )
572
589
573
- if not self .all_await :
574
- self .error ("TRIO107" , node )
590
+ self .set_state (outer )
575
591
576
- self .all_await = outer
592
+ def check_function_exit (self , node : Union [ast .Return , ast .AsyncFunctionDef ]):
593
+ # error if function exits w/o guaranteed checkpoint since function entry
594
+ method = "return" if isinstance (node , ast .Return ) else "exit"
595
+
596
+ if self .always_checkpoint is not None :
597
+ if self .yield_count :
598
+ self .error ("TRIO108" , node , method , self .always_checkpoint )
599
+ else :
600
+ self .error ("TRIO107" , node , method , self .always_checkpoint )
577
601
578
602
def visit_Return (self , node : ast .Return ):
579
603
self .generic_visit (node )
580
- if not self .all_await :
581
- self . error ( "TRIO108" , node )
604
+ self .check_function_exit ( node )
605
+
582
606
# avoid duplicate error messages
583
- self .all_await = True
607
+ self .always_checkpoint = None
584
608
585
- # disregard raise's in nested functions
609
+ # disregard checkpoints in nested function definitions
586
610
def visit_FunctionDef (self , node : ast .FunctionDef ):
587
- outer = self .all_await
611
+ outer = self .get_state ()
612
+ self .set_state (self .default )
588
613
self .generic_visit (node )
589
- self .all_await = outer
614
+ self .set_state ( outer )
590
615
591
616
# checkpoint functions
592
- def visit_Await (
593
- self , node : Union [ ast . Await , ast . AsyncFor , ast . AsyncWith , ast . Raise ]
594
- ):
617
+ def visit_Await (self , node : Union [ ast . Await , ast . Raise ]):
618
+ # the expression being awaited is not checkpointed
619
+ # so only set checkpoint after the await node
595
620
self .generic_visit (node )
596
- self .all_await = True
597
-
598
- visit_AsyncFor = visit_Await
599
- visit_AsyncWith = visit_Await
621
+ self .always_checkpoint = None
600
622
601
623
# raising exception means we don't need to checkpoint so we can treat it as one
602
624
visit_Raise = visit_Await
603
625
604
- # valid checkpoint if there's valid checkpoints (or raise) in at least one of:
605
- # (try or else) and all excepts
606
- # finally
626
+ # guaranteed to checkpoint on at least one of enter and exit
627
+ # if it checkpoints on entry and there's a yield in it, we can't treat it as checkpoint
628
+ # but it may not checkpoint on entry, so yields inside need to raise problem
629
+ def visit_AsyncWith (self , node : ast .AsyncWith ):
630
+ self .visit_nodes (node .items )
631
+ prebody_yield_count = self .yield_count
632
+
633
+ # there's no guarantee of checkpoint before entry
634
+ self .visit_nodes (node .body )
635
+
636
+ # no yield in body, treat as checkpoint
637
+ if prebody_yield_count == self .yield_count :
638
+ self .always_checkpoint = None
639
+
640
+ # error if no checkpoint since earlier yield or function entry
641
+ def visit_Yield (self , node : ast .Yield ):
642
+ self .generic_visit (node )
643
+ self .yield_count += 1
644
+ if self .always_checkpoint is not None :
645
+ self .error ("TRIO108" , node , "yield" , self .always_checkpoint )
646
+
647
+ # mark as requiring checkpoint after
648
+ self .always_checkpoint = Statement ("yield" , node .lineno )
649
+
650
+ # valid checkpoint if there's valid checkpoints (or raise) in:
651
+ # (try or else) and all excepts, or in finally
652
+ #
653
+ # try can jump into any except or into the finally* at any point during it's
654
+ # execution so we need to make sure except & finally can handle worst-case
655
+ # * unless there's a bare except / except BaseException - not implemented.
607
656
def visit_Try (self , node : ast .Try ):
608
- if self .all_await :
609
- self .generic_visit (node )
610
- return
657
+ # except & finally guaranteed to enter with checkpoint if checkpointed
658
+ # before try and no yield in try body.
659
+ body_always_checkpoint = self .always_checkpoint
660
+ for inner_node in self .walk (* node .body ):
661
+ if isinstance (inner_node , ast .Yield ):
662
+ body_always_checkpoint = Statement ("yield" , inner_node .lineno )
663
+ break
611
664
612
665
# check try body
613
666
self .visit_nodes (node .body )
614
- body_await = self .all_await
615
- self .all_await = False
667
+
668
+ # save state at end of try for entering else
669
+ try_checkpoint = self .always_checkpoint
616
670
617
671
# check that all except handlers checkpoint (await or most likely raise)
618
- all_except_await = True
672
+ all_except_checkpoint : Optional [ Statement ] = None
619
673
for handler in node .handlers :
674
+ # enter with worst case of try
675
+ self .always_checkpoint = body_always_checkpoint
676
+
620
677
self .visit_nodes (handler )
621
- all_except_await &= self .all_await
622
- self .all_await = False
678
+
679
+ if self .always_checkpoint is not None :
680
+ all_except_checkpoint = self .always_checkpoint
623
681
624
682
# check else
683
+ # if else runs it's after all of try, so restore state to back then
684
+ self .always_checkpoint = try_checkpoint
625
685
self .visit_nodes (node .orelse )
626
686
627
- # (try or else) and all excepts
628
- self .all_await = (body_await or self .all_await ) and all_except_await
687
+ # checkpoint if else checkpoints, and all excepts checkpoint
688
+ if all_except_checkpoint is not None :
689
+ self .always_checkpoint = all_except_checkpoint
629
690
630
- # finally can check on it's own
631
- self .visit_nodes (node .finalbody )
691
+ # if there's no finally, don't restore state from try
692
+ if node .finalbody :
693
+ # can enter from try, else, or any except
694
+ if body_always_checkpoint is not None :
695
+ self .always_checkpoint = body_always_checkpoint
696
+ self .visit_nodes (node .finalbody )
632
697
633
- # valid checkpoint if both body and orelse have checkpoints
698
+ # valid checkpoint if both body and orelse checkpoint
634
699
def visit_If (self , node : Union [ast .If , ast .IfExp ]):
635
- if self .all_await :
636
- self .generic_visit (node )
637
- return
638
-
639
- # ignore checkpoints in condition
700
+ # visit condition
640
701
self .visit_nodes (node .test )
641
- self . all_await = False
702
+ outer = self . get_state ( "always_checkpoint" )
642
703
643
- # check body
704
+ # visit body
644
705
self .visit_nodes (node .body )
645
- body_await = self .all_await
646
- self .all_await = False
706
+ body_outer = self .get_state ("always_checkpoint" )
647
707
708
+ # reset to after condition and visit orelse
709
+ self .set_state (outer )
648
710
self .visit_nodes (node .orelse )
649
711
650
- # checkpoint if both body and else
651
- self .all_await = body_await and self .all_await
712
+ # if body failed, reset to that state
713
+ if body_outer ["always_checkpoint" ] is not None :
714
+ self .set_state (body_outer )
715
+
716
+ # otherwise keep state (fail or not) as it was after orelse
652
717
653
718
# inline if
654
719
visit_IfExp = visit_If
655
720
656
- # ignore checkpoints in loops due to continue/break shenanigans
657
- def visit_While (self , node : Union [ast .While , ast .For ]):
658
- outer = self .all_await
659
- self .generic_visit (node )
660
- self .all_await = outer
721
+ # Check for yields w/o checkpoint inbetween due to entering loop body the first time,
722
+ # after completing all of loop body, and after any continues.
723
+ # yield in else have same requirement
724
+ # state after the loop same as above, and in addition the state at any break
725
+ def visit_loop (self , node : Union [ast .While , ast .For , ast .AsyncFor ]):
726
+ # save state in case of nested loops
727
+ outer = self .get_state (
728
+ "checkpoint_continue" , "checkpoint_break" , "suppress_errors"
729
+ )
730
+
731
+ # visit condition
732
+ if isinstance (node , ast .While ):
733
+ self .visit_nodes (node .test )
734
+ else :
735
+ self .visit_nodes (node .target )
736
+ self .visit_nodes (node .iter )
737
+
738
+ self .checkpoint_continue = None
739
+ pre_body_always_checkpoint = self .always_checkpoint
740
+
741
+ # AsyncFor guaranteed checkpoint at every iteration
742
+ if isinstance (node , ast .AsyncFor ):
743
+ pre_body_always_checkpoint = None
744
+ self .always_checkpoint = None
745
+
746
+ # if we normally enter loop with checkpoint, check for worst-case start of loop
747
+ # due to `continue` or multiple iterations
748
+ elif self .always_checkpoint is None :
749
+ # silently check if body unsets yield
750
+ # so we later can check if body errors out on worst case of entering
751
+ self .suppress_errors = True
752
+
753
+ # self.checkpoint_continue is set to False if loop body ever does
754
+ # continue with self.always_checkpoint == False
755
+ self .visit_nodes (node .body )
756
+
757
+ self .suppress_errors = outer ["suppress_errors" ]
758
+
759
+ if self .checkpoint_continue is not None :
760
+ self .always_checkpoint = self .checkpoint_continue
761
+
762
+ self .checkpoint_break = None
763
+ self .visit_nodes (node .body )
764
+
765
+ # AsyncFor guarantees checkpoint on running out of iterable
766
+ # so reset checkpoint state at end of loop. (but not state at break)
767
+ if isinstance (node , ast .AsyncFor ):
768
+ self .always_checkpoint = None
769
+ else :
770
+ # enter orelse with worst case:
771
+ # loop body might execute fully before entering orelse
772
+ # (current state of self.always_checkpoint)
773
+ # or not at all
774
+ if pre_body_always_checkpoint is not None :
775
+ self .always_checkpoint = pre_body_always_checkpoint
776
+ # or at a continue
777
+ elif self .checkpoint_continue is not None :
778
+ self .always_checkpoint = self .checkpoint_continue
779
+
780
+ # visit orelse
781
+ self .visit_nodes (node .orelse )
782
+
783
+ # We may exit from:
784
+ # orelse (which covers no body, body until continue, and all body)
785
+ # break
786
+ if self .checkpoint_break is not None :
787
+ self .always_checkpoint = self .checkpoint_break
788
+
789
+ # reset state in case of nested loops
790
+ self .set_state (outer )
661
791
662
- visit_For = visit_While
792
+ visit_While = visit_loop
793
+ visit_For = visit_loop
794
+ visit_AsyncFor = visit_loop
795
+
796
+ # save state in case of continue/break at a point not guaranteed to checkpoint
797
+ def visit_Continue (self , node : ast .Continue ):
798
+ if self .always_checkpoint is not None :
799
+ self .checkpoint_continue = self .always_checkpoint
800
+
801
+ def visit_Break (self , node : ast .Break ):
802
+ if self .always_checkpoint is not None :
803
+ self .checkpoint_break = self .always_checkpoint
804
+
805
+ # first node in a condition is guaranteed to run, but may shortcut so checkpoints
806
+ # in remaining nodes are not guaranteed
807
+ # Not fully implemented: worst case shortcut with yields in condition
808
+ def visit_BoolOp (self , node : ast .BoolOp ):
809
+ self .visit (node .op )
810
+ self .visit_nodes (node .values [:1 ])
811
+ outer = self .always_checkpoint
812
+ self .visit_nodes (node .values [1 :])
813
+
814
+ self .always_checkpoint = outer
663
815
664
816
665
817
class Plugin :
0 commit comments