@@ -507,21 +507,25 @@ def visit_Call(self, node: ast.Call):
507
507
self .generic_visit (node )
508
508
509
509
510
+ Checkpoint_Type = Optional [Tuple [str , int ]]
511
+
512
+
510
513
class Visitor107_108 (Flake8TrioVisitor ):
511
514
def __init__ (self ):
512
515
super ().__init__ ()
513
- self .always_checkpoint = True
514
516
self .yield_count = 0
515
- self .checkpoint_continue = True
516
- self .checkpoint_break : Optional [bool ] = None
517
+
518
+ self .always_checkpoint : Checkpoint_Type = None
519
+ self .checkpoint_continue : Checkpoint_Type = None
520
+ self .checkpoint_break : Checkpoint_Type = None
517
521
518
522
def visit_AsyncFunctionDef (self , node : ast .AsyncFunctionDef ):
519
523
if has_decorator (node .decorator_list , "overload" ):
520
524
return
521
525
522
526
outer = self .get_state ()
523
527
524
- self .always_checkpoint = False
528
+ self .always_checkpoint = ( "function definition" , node . lineno )
525
529
self .safe_decorator = has_decorator (node .decorator_list , * context_manager_names )
526
530
527
531
self .generic_visit (node )
@@ -531,20 +535,36 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
531
535
# we ignore contextmanager decorators with yields since the decorator may
532
536
# checkpoint on entry, which the yield then unsets. (and if there's no explicit
533
537
# checkpoint before yield, it will error)
534
- if not self .always_checkpoint and self .yield_count :
535
- self .error (TRIO107 , node .lineno , node .col_offset , "iterable" )
536
- elif not self .always_checkpoint and not self .safe_decorator :
537
- self .error (TRIO107 , node .lineno , node .col_offset , "function" )
538
+ if self .always_checkpoint is not None and self .yield_count :
539
+ self .error (
540
+ TRIO107 ,
541
+ node .lineno ,
542
+ node .col_offset ,
543
+ "iterable" ,
544
+ * self .always_checkpoint ,
545
+ )
546
+ elif self .always_checkpoint is not None and not self .safe_decorator :
547
+ self .error (
548
+ TRIO107 ,
549
+ node .lineno ,
550
+ node .col_offset ,
551
+ "function" ,
552
+ * self .always_checkpoint ,
553
+ )
538
554
539
555
self .set_state (outer )
540
556
541
557
def visit_Return (self , node : ast .Return ):
542
558
self .generic_visit (node )
543
- if not self .always_checkpoint and (not self .safe_decorator or self .yield_count ):
544
- self .error (TRIO108 , node .lineno , node .col_offset , "return" )
559
+ if self .always_checkpoint is not None and (
560
+ not self .safe_decorator or self .yield_count
561
+ ):
562
+ self .error (
563
+ TRIO108 , node .lineno , node .col_offset , "return" , * self .always_checkpoint
564
+ )
545
565
546
566
# avoid duplicate error messages
547
- self .always_checkpoint = True
567
+ self .always_checkpoint = None
548
568
549
569
# disregard checkpoints in nested function definitions
550
570
def visit_FunctionDef (self , node : ast .FunctionDef ):
@@ -557,7 +577,7 @@ def visit_Await(self, node: Union[ast.Await, ast.Raise]):
557
577
# the expression being awaited is not checkpointed
558
578
# so only set checkpoint after the await node
559
579
self .generic_visit (node )
560
- self .always_checkpoint = True
580
+ self .always_checkpoint = None
561
581
562
582
# raising exception means we don't need to checkpoint so we can treat it as one
563
583
visit_Raise = visit_Await
@@ -569,71 +589,81 @@ def visit_AsyncWith(self, node: ast.AsyncWith):
569
589
self .visit_nodes (node .body )
570
590
# if there was no yield in body that may unset a checkpoint on exit,
571
591
# treat this as a checkpoint
572
- self .always_checkpoint |= prebody_yield_count == self .yield_count
592
+ if prebody_yield_count == self .yield_count :
593
+ self .always_checkpoint = None
573
594
574
595
def visit_Yield (self , node : ast .Yield ):
575
596
self .generic_visit (node )
576
597
self .yield_count += 1
577
- if not self .always_checkpoint :
578
- self .error (TRIO108 , node .lineno , node .col_offset , "yield" )
579
- self .always_checkpoint = False
598
+ if self .always_checkpoint is not None :
599
+ self .error (
600
+ TRIO108 , node .lineno , node .col_offset , "yield" , * self .always_checkpoint
601
+ )
602
+ self .always_checkpoint = ("yield" , node .lineno )
580
603
581
604
# valid checkpoint if there's valid checkpoints (or raise) in:
582
605
# (try or else) and all excepts, or in finally
583
606
584
607
# try can jump into any except or into the finally at any point during it's execution
585
608
# so we need to make sure except & finally can handle a worst-case exception
586
609
def visit_Try (self , node : ast .Try ):
587
- outer = self .get_state ("always_checkpoint" )
610
+ # check worst case try exception
611
+ body_always_checkpoint = self .always_checkpoint
612
+ for n in (b for body in node .body for b in ast .walk (body )):
613
+ if isinstance (n , ast .Yield ):
614
+ body_always_checkpoint = ("yield" , n .lineno )
615
+ break
588
616
589
617
# check try body
590
618
self .visit_nodes (node .body )
591
619
try_checkpoint = self .always_checkpoint
592
620
593
- # checkpoint before entering try body, and no yield in it
594
- body_always_checkpoint = outer ["always_checkpoint" ] and not any (
595
- isinstance (n , ast .Yield ) for body in node .body for n in ast .walk (body )
596
- )
597
-
598
621
# check that all except handlers checkpoint (await or most likely raise)
599
- all_except_checkpoint = True
622
+ all_except_checkpoint : Checkpoint_Type = None
600
623
for handler in node .handlers :
601
624
# if there's any `yield`s in try body, exception might be thrown there
602
625
self .always_checkpoint = body_always_checkpoint
603
626
604
627
self .visit_nodes (handler )
605
- all_except_checkpoint &= self .always_checkpoint
628
+ if self .always_checkpoint is not None :
629
+ all_except_checkpoint = self .always_checkpoint
630
+ break
606
631
607
632
# check else
608
633
# if else runs it's after all of try, so restore state to back then
609
634
self .always_checkpoint = try_checkpoint
610
635
self .visit_nodes (node .orelse )
611
636
612
637
# checkpoint if else checkpoints, and all excepts
613
- self .always_checkpoint &= all_except_checkpoint
638
+ if all_except_checkpoint is not None :
639
+ self .always_checkpoint = all_except_checkpoint
614
640
615
641
if node .finalbody :
616
642
# if there's a finally, it can get jumped to at the worst time
617
643
# from the try
618
- self .always_checkpoint &= body_always_checkpoint
644
+ if body_always_checkpoint is not None :
645
+ self .always_checkpoint = body_always_checkpoint
619
646
self .visit_nodes (node .finalbody )
620
647
621
648
# valid checkpoint if both body and orelse have checkpoints
622
649
def visit_If (self , node : Union [ast .If , ast .IfExp ]):
623
650
# visit condition
624
651
self .visit_nodes (node .test )
625
- cond_yield = self .always_checkpoint
652
+ outer = self .get_state ( " always_checkpoint" )
626
653
627
654
# visit body
628
655
self .visit_nodes (node .body )
629
- body_yield = self .always_checkpoint
656
+ body_outer = self .get_state ( " always_checkpoint" )
630
657
631
658
# reset to after condition and visit orelse
632
- self .always_checkpoint = cond_yield
659
+ self .set_state ( outer )
633
660
self .visit_nodes (node .orelse )
634
661
635
- # checkpoint if both body and else checkpoint
636
- self .always_checkpoint &= body_yield
662
+ # if body failed, reset to that state
663
+ if body_outer ["always_checkpoint" ] is not None :
664
+ self .set_state (body_outer )
665
+
666
+ # otherwise keep state (fail or not) as it was after orelse
637
667
638
668
# inline if
639
669
visit_IfExp = visit_If
@@ -649,52 +679,48 @@ def visit_loop(self, node: Union[ast.While, ast.For, ast.AsyncFor]):
649
679
self .visit_nodes (node .target )
650
680
self .visit_nodes (node .iter )
651
681
652
- self .checkpoint_continue = True
682
+ self .checkpoint_continue = None
683
+ pre_body_always_checkpoint = self .always_checkpoint
653
684
# Async for always enters and exit loop body with checkpoint (regardless of continue)
654
685
if isinstance (node , ast .AsyncFor ):
655
- pre_body_always_checkpoint = True
656
- self .always_checkpoint = True
686
+ pre_body_always_checkpoint = None
687
+ self .always_checkpoint = None
657
688
658
689
# check for worst-case start of loop due to `continue` or multiple iterations
659
- else :
660
- pre_body_always_checkpoint = self .always_checkpoint
661
-
690
+ elif self .always_checkpoint is None :
662
691
# silently check if body unsets yield
663
692
# so we later can check if body errors out on worst case of entering
664
693
self .suppress_errors = True
665
- self .always_checkpoint = True
666
694
667
695
# self.checkpoint_continue is set to False if loop body ever does
668
696
# continue with self.always_checkpoint == False
669
697
self .visit_nodes (node .body )
670
698
671
699
self .suppress_errors = outer ["suppress_errors" ]
672
- # enter with checkpoint only if all ways of entering are checkpointed
673
- # (first iter, continue, 2nd+ iter)
674
- self .always_checkpoint &= (
675
- pre_body_always_checkpoint and self .checkpoint_continue
676
- )
700
+
701
+ if self .checkpoint_continue is not None :
702
+ self .always_checkpoint = self .checkpoint_continue
677
703
678
704
self .checkpoint_break = None
679
705
self .visit_nodes (node .body )
680
706
681
707
if isinstance (node , ast .AsyncFor ):
682
- self .always_checkpoint = True
708
+ self .always_checkpoint = None
683
709
else :
684
710
# enter orelse with worst case: loop body might execute fully before
685
711
# entering orelse, or not at all, or at a continue
686
- self .always_checkpoint &= (
687
- pre_body_always_checkpoint and self .checkpoint_continue
688
- )
712
+ if pre_body_always_checkpoint is not None :
713
+ self .always_checkpoint = pre_body_always_checkpoint
714
+ elif self .checkpoint_continue is not None :
715
+ self .always_checkpoint = self .checkpoint_continue
689
716
690
717
self .visit_nodes (node .orelse )
691
718
692
719
# We may exit from:
693
720
# orelse (which covers no body, body until continue, and all body)
694
721
# break
695
- self .always_checkpoint = (
696
- self .always_checkpoint and self .checkpoint_break is not False
697
- )
722
+ if self .checkpoint_break is not None :
723
+ self .always_checkpoint = self .checkpoint_break
698
724
699
725
self .set_state (outer )
700
726
@@ -703,10 +729,11 @@ def visit_loop(self, node: Union[ast.While, ast.For, ast.AsyncFor]):
703
729
visit_AsyncFor = visit_loop
704
730
705
731
def visit_Continue (self , node : ast .Continue ):
706
- self .checkpoint_continue &= self .always_checkpoint
732
+ if self .always_checkpoint is not None :
733
+ self .checkpoint_continue = self .always_checkpoint
707
734
708
735
def visit_Break (self , node : ast .Break ):
709
- if not self .checkpoint_break :
736
+ if self .always_checkpoint is not None :
710
737
self .checkpoint_break = self .always_checkpoint
711
738
712
739
# first node in boolops can checkpoint, the others might not execute
@@ -744,4 +771,6 @@ def run(self) -> Generator[Tuple[int, int, str, Type[Any]], None, None]:
744
771
TRIO105 = "TRIO105: trio async function {} must be immediately awaited"
745
772
TRIO106 = "TRIO106: trio must be imported with `import trio` for the linter to work"
746
773
TRIO107 = "TRIO107: async {} must have at least one checkpoint on every code path, unless an exception is raised"
747
- TRIO108 = "TRIO108: {} from async function must have at least one checkpoint on every code path before it"
774
+ TRIO108 = (
775
+ "TRIO108: {} from async function with no guaranteed checkpoint since {} on line {}"
776
+ )
0 commit comments