diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 8abd2cc1..cb86fc78 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -32,8 +32,10 @@ To check that all codes are tested and documented there's a test that error code ## Test generator Tests are automatically generated for files in the `tests/eval_files/` directory, with the code that it's testing interpreted from the file name. The file extension is split off, if there's a match for for `_py\d*` it strips that off and uses it to determine if there's a minimum python version for which the test should only run. -### autofix files -Checks that have autofixing can have a file in the `tests/autofix_files` directory matching the filename in `tests/eval_files`. The result of running the checker on the eval file with autofix enabled will then be compared to the content of the autofix file and will print a diff (if `-s` is on) and assert that the content is the same. `--generate-autofix` is added as a pytest flag to ease development, which will print a diff (with `-s`) and overwrite the content of the autofix file. Also see the magic line marker `pass # AUTOFIX_LINE ` below +### `# AUTOFIX` +Files in `tests/eval_files` with this marker will have two files in `tests/autofix_files/`. One with the same name containing the code after being autofixed, and a diff file between those two. +During tests the result of running the checker on the eval file with autofix enabled will be compared to the content of the autofix file and will print a diff (if `-s` is on) and assert that the content is the same. `--generate-autofix` is added as a pytest flag to ease development, which will print a diff (with `-s`) and overwrite the content of the autofix file. +Files without this marker will be checked that they *don't* modify the file content. ### `error:` Lines containing `error:` are parsed as expecting an error of the code matching the file name, with everything on the line after the colon `eval`'d and passed as arguments to `flake8_trio.Error_codes[].str_format`. The `globals` argument to `eval` contains a `lineno` variable assigned the current line number, and the `flake8_trio.Statement` namedtuple. The first element after `error:` *must* be an integer containing the column where the error on that line originates. diff --git a/flake8_trio/visitors/visitor91x.py b/flake8_trio/visitors/visitor91x.py index 64d24d15..a11ef79d 100644 --- a/flake8_trio/visitors/visitor91x.py +++ b/flake8_trio/visitors/visitor91x.py @@ -7,6 +7,7 @@ from __future__ import annotations +from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any @@ -32,7 +33,7 @@ def func_empty_body(node: cst.FunctionDef) -> bool: - # Does the function body consist solely of `pass`, `...`, and (doc)string literals? + """Check if function body consist of `pass`, `...`, and/or (doc)string literals.""" empty_statement = m.Pass() | m.Expr(m.Ellipsis() | m.SimpleString()) return m.matches( node.body, @@ -90,9 +91,125 @@ def checkpoint_statement(library: str) -> cst.SimpleStatementLine: ) +class CommonVisitors(cst.CSTTransformer, ABC): + """Base class for InsertCheckpointsInLoopBody and Visitor91X. + + Contains the transform methods used to actually insert the checkpoints, as well + as making sure that the library used will get imported. Adding the library import + is done in Visitor91X. + """ + + def __init__(self): + super().__init__() + self.noautofix: bool = False + self.add_statement: cst.SimpleStatementLine | None = None + + self.explicitly_imported_library: dict[str, bool] = { + "trio": False, + "anyio": False, + } + self.add_import: set[str] = set() + + self.__booldepth = 0 + + @property + @abstractmethod + def library(self) -> tuple[str, ...]: + ... + + # instead of trying to exclude yields found in all the weird places from + # setting self.add_statement, we instead clear it upon each new line. + # Several of them *could* be handled, e.g. `if ...: yield`, but + # that's uncommon enough we don't care about it. + def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine): + self.add_statement = None + + # insert checkpoint before yield with a flattensentinel, if indicated + def leave_SimpleStatementLine( + self, + original_node: cst.SimpleStatementLine, + updated_node: cst.SimpleStatementLine, + ) -> cst.SimpleStatementLine | cst.FlattenSentinel[cst.SimpleStatementLine]: + # possible TODO: generate an error if transforming+visiting is done in a + # single pass and emit-error-on-transform can be enabled/disabled. The error can't + # be generated in the yield/return since it doesn't know if it will be autofixed. + if self.add_statement is None: + return updated_node + curr_add_statement = self.add_statement + self.add_statement = None + + # multiple statements on a single line is not handled + if len(updated_node.body) > 1: + return updated_node + + self.ensure_imported_library() + return cst.FlattenSentinel([curr_add_statement, updated_node]) + + def visit_BooleanOperation(self, node: cst.BooleanOperation): + self.__booldepth += 1 + self.noautofix = True + + def leave_BooleanOperation( + self, original_node: cst.BooleanOperation, updated_node: cst.BooleanOperation + ): + assert self.__booldepth + self.__booldepth -= 1 + if not self.__booldepth: + self.noautofix = False + return updated_node + + def ensure_imported_library(self) -> None: + """Mark library for import. + + Check that the library we'd use to insert checkpoints + is imported - if not, mark it to be inserted later. + """ + assert self.library + if not self.explicitly_imported_library[self.library[0]]: + self.add_import.add(self.library[0]) + + +class InsertCheckpointsInLoopBody(CommonVisitors): + """Insert checkpoints in loop bodies. + + This inserts checkpoints that it was not known on the first pass whether a + checkpoint would be necessary, i.e. no uncheckpointed statements as we started to + parse the loop, but then there's uncheckpointed statements on continue or as loop + body finishes. + Called from `leave_While` and `leave_For` in Visitor91X. + """ + + def __init__( + self, + nodes_needing_checkpoint: Sequence[cst.Yield | cst.Return], + library: tuple[str, ...], + ): + super().__init__() + self.nodes_needing_checkpoint = nodes_needing_checkpoint + self.__library = library + + @property + def library(self) -> tuple[str, ...]: + return self.__library if self.__library else ("trio",) + + def leave_Yield( + self, + original_node: cst.Yield, + updated_node: cst.Yield, + ) -> cst.Yield: + # Needs to be passed *original* node, since updated node is a copy + # which loses identity equality + if original_node in self.nodes_needing_checkpoint and not self.noautofix: + self.add_statement = checkpoint_statement(self.library[0]) + return updated_node + + # returns handled same as yield, but ofc needs to ignore types + leave_Return = leave_Yield # type: ignore + + @error_class_cst @disabled_by_default -class Visitor91X(Flake8TrioVisitor_cst): +class Visitor91X(Flake8TrioVisitor_cst, CommonVisitors): error_codes = { "TRIO910": ( "{0} from async function with no guaranteed checkpoint or exception " @@ -112,69 +229,18 @@ def __init__(self, *args: Any, **kwargs: Any): self.uncheckpointed_statements: set[Statement] = set() self.comp_unknown = False - # these are file-wide, so intentionally not save-stated upon entry/exit - # of functions/loops/etc - self.explicitly_imported_library: dict[str, bool] = { - "trio": False, - "anyio": False, - } - self.add_import: set[str] = set() - - # this one is not save-stated, but I fail to come up with any scenario - # where that matters - self.add_statement: cst.SimpleStatementLine | None = None - self.loop_state = LoopState() self.try_state = TryState() - @property - def infinite_loop(self) -> bool: - return self.loop_state.infinite_loop - - @infinite_loop.setter - def infinite_loop(self, value: bool): - self.loop_state.infinite_loop = value - - @property - def body_guaranteed_once(self) -> bool: - return self.loop_state.body_guaranteed_once - - @body_guaranteed_once.setter - def body_guaranteed_once(self, value: bool): - self.loop_state.body_guaranteed_once = value - - @property - def has_break(self) -> bool: - return self.loop_state.has_break - - @has_break.setter - def has_break(self, value: bool): - self.loop_state.has_break = value - - @property - def uncheckpointed_before_continue(self) -> set[Statement]: - return self.loop_state.uncheckpointed_before_continue - - @uncheckpointed_before_continue.setter - def uncheckpointed_before_continue(self, value: set[Statement]): - self.loop_state.uncheckpointed_before_continue = value - - @property - def uncheckpointed_before_break(self) -> set[Statement]: - return self.loop_state.uncheckpointed_before_break - - @uncheckpointed_before_break.setter - def uncheckpointed_before_break(self, value: set[Statement]): - self.loop_state.uncheckpointed_before_break = value - def checkpoint_statement(self) -> cst.SimpleStatementLine: - self.ensure_imported_library() return checkpoint_statement(self.library[0]) - def visit_FunctionDef(self, node: cst.FunctionDef) -> None: + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: # don't lint functions whose bodies solely consist of pass or ellipsis + # @overload functions are also guaranteed to be empty + # we also ignore pytest fixtures if func_has_decorator(node, "overload", "fixture") or func_empty_body(node): - return + return False # subnodes can be ignored self.save_state( node, @@ -197,12 +263,15 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None: ) ) if not self.async_function: - return + # only visit subnodes if there is an async function defined inside + # this should improve performance on codebases with many sync functions + return any(m.findall(node, m.FunctionDef(asynchronous=m.Asynchronous()))) pos = self.get_metadata(PositionProvider, node).start self.uncheckpointed_statements = { Statement("function definition", pos.line, pos.column) } + return True def leave_FunctionDef( self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef @@ -248,7 +317,7 @@ def check_function_exit( # Add this as a node potentially needing checkpoints only if it # missing checkpoints solely depends on whether the artificial statement is # "real" - if len(self.uncheckpointed_statements) == 1: + if len(self.uncheckpointed_statements) == 1 and not self.noautofix: self.loop_state.nodes_needing_checkpoints.append(original_node) return False @@ -274,38 +343,6 @@ def leave_Return( assert original_node.deep_equals(updated_node) return original_node - # TODO: generate an error in these two if transforming+visiting is done in a single - # pass and emit-error-on-transform can be enabled/disabled. The error can't be - # generated in the yield/return since it doesn't know if it will be autofixed. - - # SimpleStatementSuite and multi-statement lines can probably be autofixed, but - # for now just don't insert checkpoints in the wrong place. - def leave_SimpleStatementSuite( - self, - original_node: cst.SimpleStatementSuite, - updated_node: cst.SimpleStatementSuite, - ) -> cst.SimpleStatementSuite: - self.add_statement = None - return updated_node - - # insert checkpoint before yield with a flattensentinel, if indicated - def leave_SimpleStatementLine( - self, - original_node: cst.SimpleStatementLine, - updated_node: cst.SimpleStatementLine, - ) -> cst.SimpleStatementLine | cst.FlattenSentinel[cst.SimpleStatementLine]: - if self.add_statement is None: - return updated_node - - # multiple statements on a single line is not handled - if len(updated_node.body) > 1: - self.add_statement = None - return updated_node - - res = cst.FlattenSentinel([self.add_statement, updated_node]) - self.add_statement = None - return res # noqa: R504 - def error_91x( self, node: cst.Return | cst.FunctionDef | cst.Yield, @@ -356,7 +393,7 @@ def leave_Yield( return updated_node self.has_yield = True - if self.check_function_exit(original_node): + if self.check_function_exit(original_node) and not self.noautofix: self.add_statement = self.checkpoint_statement() # mark as requiring checkpoint after @@ -482,7 +519,7 @@ def visit_IfExp(self, node: cst.IfExp) -> bool: self.leave_If_body(node) _ = node.orelse.visit(self) self.leave_If(node, node) # type: ignore - return False + return False # libcst shouldn't visit subnodes again def visit_While(self, node: cst.While | cst.For): self.save_state( @@ -491,7 +528,7 @@ def visit_While(self, node: cst.While | cst.For): copy=True, ) self.loop_state = LoopState() - self.infinite_loop = self.body_guaranteed_once = False + self.loop_state.infinite_loop = self.loop_state.body_guaranteed_once = False visit_For = visit_While @@ -505,10 +542,10 @@ def visit_While_test(self, node: cst.While): if (m.matches(node.test, m.Name("True"))) or ( getattr(node.test, "evaluated_value", False) ): - self.infinite_loop = self.body_guaranteed_once = True + self.loop_state.infinite_loop = self.loop_state.body_guaranteed_once = True def visit_For_iter(self, node: cst.For): - self.body_guaranteed_once = iter_guaranteed_once_cst(node.iter) + self.loop_state.body_guaranteed_once = iter_guaranteed_once_cst(node.iter) def visit_While_body(self, node: cst.For | cst.While): if not self.async_function: @@ -528,8 +565,8 @@ def visit_While_body(self, node: cst.For | cst.While): else: self.uncheckpointed_statements = {ARTIFICIAL_STATEMENT} - self.uncheckpointed_before_continue = set() - self.uncheckpointed_before_break = set() + self.loop_state.uncheckpointed_before_continue = set() + self.loop_state.uncheckpointed_before_break = set() visit_For_body = visit_While_body @@ -545,7 +582,7 @@ def leave_While_body(self, node: cst.For | cst.While): for stmt in ( self.outer[node]["uncheckpointed_statements"] | self.uncheckpointed_statements - | self.uncheckpointed_before_continue + | self.loop_state.uncheckpointed_before_continue ): if stmt == ARTIFICIAL_STATEMENT: continue @@ -560,8 +597,8 @@ def leave_While_body(self, node: cst.For | cst.While): # replace artificial statements in else with prebody uncheckpointed statements # non-artificial stmts before continue/break/at body end will already be in them for stmts in ( - self.uncheckpointed_before_continue, - self.uncheckpointed_before_break, + self.loop_state.uncheckpointed_before_continue, + self.loop_state.uncheckpointed_before_break, self.uncheckpointed_statements, ): if ARTIFICIAL_STATEMENT in stmts: @@ -577,14 +614,14 @@ def leave_While_body(self, node: cst.For | cst.While): # loop body might execute fully before entering orelse # (current state of self.uncheckpointed_statements) # or not at all - if not self.body_guaranteed_once: + if not self.loop_state.body_guaranteed_once: self.uncheckpointed_statements.update( self.outer[node]["uncheckpointed_statements"] ) # or at a continue, unless it's an infinite loop - if not self.infinite_loop: + if not self.loop_state.infinite_loop: self.uncheckpointed_statements.update( - self.uncheckpointed_before_continue + self.loop_state.uncheckpointed_before_continue ) leave_For_body = leave_While_body @@ -594,13 +631,15 @@ def leave_While_orelse(self, node: cst.For | cst.While): return # if this is an infinite loop, with no break in it, don't raise # alarms about the state after it. - if self.infinite_loop and not self.has_break: + if self.loop_state.infinite_loop and not self.loop_state.has_break: self.uncheckpointed_statements = set() else: # We may exit from: # orelse (covering: no body, body until continue, and all body) # break - self.uncheckpointed_statements.update(self.uncheckpointed_before_break) + self.uncheckpointed_statements.update( + self.loop_state.uncheckpointed_before_break + ) # reset break & continue in case of nested loops self.outer[node]["uncheckpointed_statements"] = self.uncheckpointed_statements @@ -616,9 +655,8 @@ def leave_While( | cst.RemovalSentinel ): if self.loop_state.nodes_needing_checkpoints: - self.ensure_imported_library() transformer = InsertCheckpointsInLoopBody( - self.loop_state.nodes_needing_checkpoints, self.library[0] + self.loop_state.nodes_needing_checkpoints, self.library ) # type of updated_node expanded to the return type updated_node = updated_node.visit(transformer) # type: ignore @@ -633,16 +671,20 @@ def leave_While( def visit_Continue(self, node: cst.Continue): if not self.async_function: return - self.uncheckpointed_before_continue.update(self.uncheckpointed_statements) + self.loop_state.uncheckpointed_before_continue.update( + self.uncheckpointed_statements + ) def visit_Break(self, node: cst.Break): - self.has_break = True + self.loop_state.has_break = True if not self.async_function: return - self.uncheckpointed_before_break.update(self.uncheckpointed_statements) + self.loop_state.uncheckpointed_before_break.update( + self.uncheckpointed_statements + ) - # first node in a condition is always evaluated, but may shortcut at any point - # after that so we track worst-case checkpoint (i.e. after yield) + # we visit BooleanOperation_left as usual, but ignore checkpoints in the + # right-hand side while still adding any yields in it. def visit_BooleanOperation_right(self, node: cst.BooleanOperation): if not self.async_function: return @@ -718,7 +760,7 @@ def visit_CompFor(self, node: cst.CompFor): return False # We don't have any logic on if generators are guaranteed to unroll, so always - # ignore their content + # ignore their content by not visiting subnodes. def visit_GeneratorExp(self, node: cst.GeneratorExp): return False @@ -735,15 +777,6 @@ def visit_Import(self, node: cst.Import): assert isinstance(alias.name.value, str) self.explicitly_imported_library[alias.name.value] = True - def ensure_imported_library(self) -> None: - """Mark library for import. - - Check that the library we'd use to insert checkpoints - is imported - if not, mark it to be inserted later. - """ - if not self.explicitly_imported_library[self.library[0]]: - self.add_import.add(self.library[0]) - def leave_Module(self, original_node: cst.Module, updated_node: cst.Module): """Add needed library import, if any, to the module.""" if not self.add_import: @@ -765,33 +798,3 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module): assert len(self.add_import) == 1 new_body.insert(index, cst.parse_statement(f"import {self.library[0]}")) return updated_node.with_changes(body=new_body) - - -# necessary as we don't know whether to insert checkpoints on the first pass of a loop -# so we transform the loop body afterwards -class InsertCheckpointsInLoopBody(cst.CSTTransformer): - def __init__( - self, nodes_needing_checkpoint: Sequence[cst.Yield | cst.Return], library: str - ): - super().__init__() - self.nodes_needing_checkpoint = nodes_needing_checkpoint - self.add_statement: cst.SimpleStatementLine | None = None - self.library = library - - # insert checkpoint before yield with a flattensentinel, if indicated - # type checkers don't like going across classes, esp as the method accesses - # and modifies self.add_statement, but #YOLO - leave_SimpleStatementLine = Visitor91X.leave_SimpleStatementLine # type: ignore - - def leave_Yield( - self, - original_node: cst.Yield, - updated_node: cst.Yield, - ) -> cst.Yield: - # we need to check *original* node here, since updated node is a copy - # which loses identity equality - if original_node in self.nodes_needing_checkpoint: - self.add_statement = checkpoint_statement(self.library) - return updated_node - - leave_Return = leave_Yield # type: ignore diff --git a/tests/autofix_files/trio100.py b/tests/autofix_files/trio100.py index df7ea155..2027adce 100644 --- a/tests/autofix_files/trio100.py +++ b/tests/autofix_files/trio100.py @@ -1,4 +1,5 @@ # type: ignore +# AUTOFIX import trio @@ -67,18 +68,3 @@ async def foo(): async with random_ignored_library.fail_after(10): ... - - -async def function_name2(): - with ( - open("") as _, - trio.fail_after(10), # error: 8, "trio", "fail_after" - ): - ... - - with ( - trio.fail_after(5), # error: 8, "trio", "fail_after" - open("") as _, - trio.move_on_after(5), # error: 8, "trio", "move_on_after" - ): - ... diff --git a/tests/autofix_files/trio100.py.diff b/tests/autofix_files/trio100.py.diff index ea3f2a9a..1b9ed2c1 100644 --- a/tests/autofix_files/trio100.py.diff +++ b/tests/autofix_files/trio100.py.diff @@ -1,6 +1,6 @@ --- +++ -@@ -2,24 +2,24 @@ +@@ -3,24 +3,24 @@ import trio @@ -38,7 +38,7 @@ with trio.move_on_after(10): await trio.sleep(1) -@@ -36,8 +36,8 @@ +@@ -37,8 +37,8 @@ with open("filename") as _: ... @@ -49,7 +49,7 @@ send_channel, receive_channel = trio.open_memory_channel(0) async with trio.fail_after(10): -@@ -48,22 +48,22 @@ +@@ -49,22 +49,22 @@ async for _ in receive_channel: ... diff --git a/tests/autofix_files/trio100_simple_autofix.py b/tests/autofix_files/trio100_simple_autofix.py index 27dd18b8..7516fe64 100644 --- a/tests/autofix_files/trio100_simple_autofix.py +++ b/tests/autofix_files/trio100_simple_autofix.py @@ -1,3 +1,4 @@ +# AUTOFIX import trio # a @@ -29,13 +30,6 @@ # c # d -# Doesn't autofix With's with multiple withitems -with ( - trio.move_on_after(10), # error: 4, "trio", "move_on_after" - open("") as f, -): - ... - # multiline with, despite only being one statement # a diff --git a/tests/autofix_files/trio100_simple_autofix.py.diff b/tests/autofix_files/trio100_simple_autofix.py.diff index a8096ba0..ac5e82b7 100644 --- a/tests/autofix_files/trio100_simple_autofix.py.diff +++ b/tests/autofix_files/trio100_simple_autofix.py.diff @@ -1,6 +1,6 @@ --- +++ -@@ -2,28 +2,29 @@ +@@ -3,50 +3,51 @@ # a # b @@ -45,7 +45,6 @@ # fmt: on # c # d -@@ -37,22 +38,22 @@ # multiline with, despite only being one statement diff --git a/tests/autofix_files/trio910.py b/tests/autofix_files/trio910.py index 9a122423..a4bd9a32 100644 --- a/tests/autofix_files/trio910.py +++ b/tests/autofix_files/trio910.py @@ -1,3 +1,4 @@ +# AUTOFIX # mypy: disable-error-code="unreachable" import typing from typing import Any, overload diff --git a/tests/autofix_files/trio910.py.diff b/tests/autofix_files/trio910.py.diff index 167684f1..e79b955e 100644 --- a/tests/autofix_files/trio910.py.diff +++ b/tests/autofix_files/trio910.py.diff @@ -1,6 +1,6 @@ --- +++ -@@ -44,12 +44,14 @@ +@@ -45,12 +45,14 @@ async def foo1(): # error: 0, "exit", Statement("function definition", lineno) bar() @@ -15,7 +15,7 @@ async def foo_if_2(): -@@ -80,6 +82,7 @@ +@@ -81,6 +83,7 @@ async def foo_ifexp_2(): # error: 0, "exit", Statement("function definition", lineno) print(_ if False and await foo() else await foo()) @@ -23,7 +23,7 @@ # nested function definition -@@ -88,6 +91,7 @@ +@@ -89,6 +92,7 @@ async def foo_func_2(): # error: 4, "exit", Statement("function definition", lineno) bar() @@ -31,7 +31,7 @@ # we don't get a newline after the nested function definition before the checkpoint -@@ -96,17 +100,21 @@ +@@ -97,17 +101,21 @@ async def foo_func_3(): # error: 0, "exit", Statement("function definition", lineno) async def foo_func_4(): await foo() @@ -53,7 +53,7 @@ # fmt: on -@@ -143,11 +151,13 @@ +@@ -144,11 +152,13 @@ async def foo_condition_2(): # error: 0, "exit", Statement("function definition", lineno) if False and await foo(): ... @@ -67,7 +67,7 @@ async def foo_condition_4(): # safe -@@ -169,6 +179,7 @@ +@@ -170,6 +180,7 @@ async def foo_while_1(): # error: 0, "exit", Statement("function definition", lineno) while _: await foo() @@ -75,7 +75,7 @@ async def foo_while_2(): # now safe -@@ -187,12 +198,14 @@ +@@ -188,12 +199,14 @@ async def foo_while_4(): # error: 0, "exit", Statement("function definition", lineno) while False: await foo() @@ -90,7 +90,7 @@ async def foo_for_2(): # now safe -@@ -215,6 +228,7 @@ +@@ -216,6 +229,7 @@ break else: await foo() @@ -98,7 +98,7 @@ async def foo_while_break_3(): # error: 0, "exit", Statement("function definition", lineno) -@@ -223,6 +237,7 @@ +@@ -224,6 +238,7 @@ break else: ... @@ -106,7 +106,7 @@ async def foo_while_break_4(): # error: 0, "exit", Statement("function definition", lineno) -@@ -230,6 +245,7 @@ +@@ -231,6 +246,7 @@ break else: ... @@ -114,7 +114,7 @@ async def foo_while_continue_1(): # safe -@@ -253,6 +269,7 @@ +@@ -254,6 +270,7 @@ continue else: ... @@ -122,7 +122,7 @@ async def foo_while_continue_4(): # error: 0, "exit", Statement("function definition", lineno) -@@ -260,6 +277,7 @@ +@@ -261,6 +278,7 @@ continue else: ... @@ -130,7 +130,7 @@ async def foo_async_for_1(): -@@ -298,6 +316,7 @@ +@@ -299,6 +317,7 @@ raise else: await foo() @@ -138,7 +138,7 @@ async def foo_try_2(): # safe -@@ -348,6 +367,7 @@ +@@ -349,6 +368,7 @@ pass else: pass @@ -146,7 +146,7 @@ async def foo_try_7(): # safe -@@ -389,6 +409,7 @@ +@@ -390,6 +410,7 @@ await trio.sleep(0) except: ... @@ -154,7 +154,7 @@ # safe -@@ -416,16 +437,19 @@ +@@ -417,16 +438,19 @@ except: ... finally: @@ -174,7 +174,7 @@ return # error: 8, "return", Statement("function definition", lineno-2) await foo() -@@ -434,6 +458,7 @@ +@@ -435,6 +459,7 @@ if _: await foo() return # safe @@ -182,7 +182,7 @@ # loop over non-empty static collection -@@ -461,12 +486,14 @@ +@@ -462,12 +487,14 @@ async def foo_range_4(): # error: 0, "exit", Statement("function definition", lineno) for i in range(10, 5): await foo() @@ -197,7 +197,7 @@ # https://github.com/Zac-HD/flake8-trio/issues/47 -@@ -550,6 +577,7 @@ +@@ -551,6 +578,7 @@ # should error async def foo_comprehension_2(): # error: 0, "exit", Statement("function definition", lineno) [await foo() for x in range(10) if bar()] diff --git a/tests/autofix_files/trio911.py b/tests/autofix_files/trio911.py index e167b67f..a839c5fa 100644 --- a/tests/autofix_files/trio911.py +++ b/tests/autofix_files/trio911.py @@ -1,3 +1,4 @@ +# AUTOFIX from typing import Any import pytest @@ -73,15 +74,6 @@ async def foo_async_with(): yield -# fmt: off -async def foo_async_with_2(): - # with'd expression evaluated before checkpoint - async with (yield): # error: 16, "yield", Statement("function definition", lineno-2) - await trio.lowlevel.checkpoint() - yield -# fmt: on - - async def foo_async_with_3(): async with trio.fail_after(5): yield @@ -90,11 +82,8 @@ async def foo_async_with_3(): # async for -async def foo_async_for(): # error: 0, "exit", Statement("yield", lineno+6) - async for i in ( - yield # error: 8, "yield", Statement("function definition", lineno-2) - ): - await trio.lowlevel.checkpoint() +async def foo_async_for(): # error: 0, "exit", Statement("yield", lineno+4) + async for i in bar(): yield # safe else: yield # safe @@ -713,32 +702,6 @@ async def foo_boolops_1(): # error: 0, "exit", Stmt("yield", line+1) await trio.lowlevel.checkpoint() -# may shortcut after any of the yields -async def foo_boolops_2(): # error: 0, "exit", Stmt("yield", line+4) # error: 0, "exit", Stmt("yield", line+6) - await trio.lowlevel.checkpoint() - # known false positive - but chained yields in bool should be rare - _ = ( - await foo() - and (yield) - and await foo() - and (yield) # error: 13, "yield", Stmt("yield", line-2, 13) - ) - await trio.lowlevel.checkpoint() - - -# fmt: off -async def foo_boolops_3(): # error: 0, "exit", Stmt("yield", line+1) # error: 0, "exit", Stmt("yield", line+4) # error: 0, "exit", Stmt("yield", line+5) - await trio.lowlevel.checkpoint() - _ = (await foo() or (yield) or await foo()) or ( - ... - or ( - (yield) # error: 13, "yield", Stmt("yield", line-3) - and (yield)) # error: 17, "yield", Stmt("yield", line-1) - ) - await trio.lowlevel.checkpoint() -# fmt: on - - # loop over non-empty static collection async def foo_loop_static(): # break/else behaviour on guaranteed body execution diff --git a/tests/autofix_files/trio911.py.diff b/tests/autofix_files/trio911.py.diff index c9b54841..d9c63d51 100644 --- a/tests/autofix_files/trio911.py.diff +++ b/tests/autofix_files/trio911.py.diff @@ -1,6 +1,6 @@ --- +++ -@@ -23,7 +23,9 @@ +@@ -24,7 +24,9 @@ async def foo_yield_2(): @@ -10,7 +10,7 @@ yield # error: 4, "yield", Statement("yield", lineno-1) await foo() -@@ -31,22 +33,29 @@ +@@ -32,22 +34,29 @@ async def foo_yield_3(): # error: 0, "exit", Statement("yield", lineno+2) await foo() yield @@ -41,14 +41,6 @@ @@ -68,6 +77,7 @@ - async def foo_async_with_2(): - # with'd expression evaluated before checkpoint - async with (yield): # error: 16, "yield", Statement("function definition", lineno-2) -+ await trio.lowlevel.checkpoint() - yield - # fmt: on - -@@ -75,6 +85,7 @@ async def foo_async_with_3(): async with trio.fail_after(5): yield @@ -56,11 +48,7 @@ yield # error: 8, "yield", Statement("yield", lineno-1) -@@ -83,9 +94,11 @@ - async for i in ( - yield # error: 8, "yield", Statement("function definition", lineno-2) - ): -+ await trio.lowlevel.checkpoint() +@@ -77,6 +87,7 @@ yield # safe else: yield # safe @@ -68,7 +56,7 @@ # await anext(iter) is not called on break -@@ -94,6 +107,7 @@ +@@ -85,6 +96,7 @@ yield if ...: break @@ -76,7 +64,7 @@ async def foo_async_for_3(): # safe -@@ -111,13 +125,16 @@ +@@ -102,13 +114,16 @@ async def foo_for(): # error: 0, "exit", Statement("yield", lineno+3) await foo() for i in "": @@ -93,7 +81,7 @@ # while -@@ -130,13 +147,16 @@ +@@ -121,13 +136,16 @@ else: await foo() # will always run yield # safe @@ -110,7 +98,7 @@ # no checkpoint after yield if else is entered -@@ -145,39 +165,52 @@ +@@ -136,39 +154,52 @@ await foo() yield else: @@ -163,7 +151,7 @@ await foo() -@@ -187,16 +220,19 @@ +@@ -178,16 +209,19 @@ async def foo_while_continue_1(): # error: 0, "exit", Statement("yield", lineno+3) await foo() while foo(): @@ -183,7 +171,7 @@ yield # error: 8, "yield", Statement("yield", lineno) if foo(): continue -@@ -206,6 +242,7 @@ +@@ -197,6 +231,7 @@ while foo(): yield # safe await foo() @@ -191,7 +179,7 @@ # --- while + break --- -@@ -216,7 +253,9 @@ +@@ -207,7 +242,9 @@ break else: await foo() @@ -201,7 +189,7 @@ # no checkpoint on break -@@ -227,6 +266,7 @@ +@@ -218,6 +255,7 @@ if ...: break await foo() @@ -209,7 +197,7 @@ # guaranteed if else and break -@@ -238,6 +278,7 @@ +@@ -229,6 +267,7 @@ else: await foo() # runs if 0-iter yield # safe @@ -217,7 +205,7 @@ # break at non-guaranteed checkpoint -@@ -248,7 +289,9 @@ +@@ -239,7 +278,9 @@ await foo() # might not run else: await foo() # might not run @@ -227,7 +215,7 @@ # check break is reset on nested -@@ -264,7 +307,9 @@ +@@ -255,7 +296,9 @@ await foo() yield # safe await foo() @@ -237,7 +225,7 @@ # check multiple breaks -@@ -279,7 +324,9 @@ +@@ -270,7 +313,9 @@ await foo() if ...: break @@ -247,7 +235,7 @@ async def foo_while_break_7(): # error: 0, "exit", Statement("function definition", lineno)# error: 0, "exit", Statement("yield", lineno+5) -@@ -289,6 +336,7 @@ +@@ -280,6 +325,7 @@ break yield break @@ -255,7 +243,7 @@ async def foo_while_endless_1(): -@@ -301,6 +349,7 @@ +@@ -292,6 +338,7 @@ while foo(): await foo() yield @@ -263,7 +251,7 @@ async def foo_while_endless_3(): -@@ -322,9 +371,11 @@ +@@ -313,9 +360,11 @@ # try async def foo_try_1(): # error: 0, "exit", Statement("function definition", lineno) # error: 0, "exit", Statement("yield", lineno+2) try: @@ -275,7 +263,7 @@ # no checkpoint after yield in ValueError -@@ -332,12 +383,14 @@ +@@ -323,12 +372,14 @@ try: await foo() except ValueError: @@ -290,7 +278,7 @@ async def foo_try_3(): # error: 0, "exit", Statement("yield", lineno+6) -@@ -346,13 +399,16 @@ +@@ -337,13 +388,16 @@ except: await foo() else: @@ -307,7 +295,7 @@ yield # error: 8, "yield", Statement("function definition", lineno-4) finally: await foo() -@@ -362,6 +418,7 @@ +@@ -353,6 +407,7 @@ try: await foo() finally: @@ -315,7 +303,7 @@ # try might crash before checkpoint yield # error: 8, "yield", Statement("function definition", lineno-5) await foo() -@@ -372,7 +429,9 @@ +@@ -363,7 +418,9 @@ await foo() except ValueError: pass @@ -325,7 +313,7 @@ async def foo_try_7(): # error: 0, "exit", Statement("yield", lineno+17) -@@ -385,6 +444,7 @@ +@@ -376,6 +433,7 @@ yield await foo() except SyntaxError: @@ -333,7 +321,7 @@ yield # error: 8, "yield", Statement("yield", lineno-7) await foo() finally: -@@ -393,6 +453,7 @@ +@@ -384,6 +442,7 @@ # by any of the excepts, jumping straight to the finally. # Then the error will be propagated upwards yield # safe @@ -341,7 +329,7 @@ ## safe only if (try or else) and all except bodies either await or raise -@@ -408,6 +469,7 @@ +@@ -399,6 +458,7 @@ raise else: await foo() @@ -349,7 +337,7 @@ # no checkpoint after yield in else -@@ -418,6 +480,7 @@ +@@ -409,6 +469,7 @@ await foo() else: yield @@ -357,7 +345,7 @@ # bare except means we'll jump to finally after full execution of either try or the except -@@ -448,6 +511,7 @@ +@@ -439,6 +500,7 @@ except ValueError: await foo() finally: @@ -365,7 +353,7 @@ yield # error: 8, "yield", Statement("function definition", lineno-6) await foo() -@@ -456,6 +520,7 @@ +@@ -447,6 +509,7 @@ try: await foo() finally: @@ -373,7 +361,7 @@ # try might crash before checkpoint yield # error: 8, "yield", Statement("function definition", lineno-5) await foo() -@@ -464,9 +529,11 @@ +@@ -455,9 +518,11 @@ # if async def foo_if_1(): if ...: @@ -385,7 +373,7 @@ yield # error: 8, "yield", Statement("function definition", lineno-5) await foo() -@@ -477,7 +544,9 @@ +@@ -468,7 +533,9 @@ ... else: yield @@ -395,7 +383,7 @@ async def foo_if_3(): # error: 0, "exit", Statement("yield", lineno+6) -@@ -486,7 +555,9 @@ +@@ -477,7 +544,9 @@ yield else: ... @@ -405,7 +393,7 @@ async def foo_if_4(): # error: 0, "exit", Statement("yield", lineno+7) -@@ -496,7 +567,9 @@ +@@ -487,7 +556,9 @@ await foo() else: ... @@ -415,7 +403,7 @@ async def foo_if_5(): # error: 0, "exit", Statement("yield", lineno+8) -@@ -507,7 +580,9 @@ +@@ -498,7 +569,9 @@ else: yield ... @@ -425,7 +413,7 @@ async def foo_if_6(): # error: 0, "exit", Statement("yield", lineno+8) -@@ -518,7 +593,9 @@ +@@ -509,7 +582,9 @@ yield await foo() ... @@ -435,7 +423,7 @@ async def foo_if_7(): # error: 0, "exit", Statement("function definition", lineno) -@@ -526,6 +603,7 @@ +@@ -517,6 +592,7 @@ await foo() yield await foo() @@ -443,7 +431,7 @@ async def foo_if_8(): # error: 0, "exit", Statement("function definition", lineno) -@@ -535,21 +613,25 @@ +@@ -526,21 +602,25 @@ await foo() yield await foo() @@ -469,7 +457,7 @@ # normal function -@@ -594,7 +676,9 @@ +@@ -585,7 +665,9 @@ await foo() async def foo_func_2(): # error: 4, "exit", Statement("yield", lineno+1) @@ -479,7 +467,7 @@ # autofix doesn't insert newline after nested function def and before checkpoint -@@ -606,6 +690,7 @@ +@@ -597,6 +679,7 @@ async def foo_func_4(): await foo() @@ -487,7 +475,7 @@ async def foo_func_5(): # error: 0, "exit", Statement("yield", lineno+2) -@@ -618,16 +703,19 @@ +@@ -609,12 +692,14 @@ async def foo_func_7(): await foo() ... @@ -501,33 +489,8 @@ + await trio.lowlevel.checkpoint() - # may shortcut after any of the yields - async def foo_boolops_2(): # error: 0, "exit", Stmt("yield", line+4) # error: 0, "exit", Stmt("yield", line+6) -+ await trio.lowlevel.checkpoint() - # known false positive - but chained yields in bool should be rare - _ = ( - await foo() -@@ -635,16 +723,19 @@ - and await foo() - and (yield) # error: 13, "yield", Stmt("yield", line-2, 13) - ) -+ await trio.lowlevel.checkpoint() - - - # fmt: off - async def foo_boolops_3(): # error: 0, "exit", Stmt("yield", line+1) # error: 0, "exit", Stmt("yield", line+4) # error: 0, "exit", Stmt("yield", line+5) -+ await trio.lowlevel.checkpoint() - _ = (await foo() or (yield) or await foo()) or ( - ... - or ( - (yield) # error: 13, "yield", Stmt("yield", line-3) - and (yield)) # error: 17, "yield", Stmt("yield", line-1) - ) -+ await trio.lowlevel.checkpoint() - # fmt: on - - -@@ -672,6 +763,7 @@ + # loop over non-empty static collection +@@ -641,6 +726,7 @@ if ...: continue await foo() @@ -535,7 +498,7 @@ yield # error: 4, "yield", Stmt("yield", line-7) # continue/else -@@ -680,6 +772,7 @@ +@@ -649,6 +735,7 @@ continue await foo() else: @@ -543,7 +506,7 @@ yield # error: 8, "yield", Stmt("yield", line-8) await foo() yield -@@ -724,6 +817,7 @@ +@@ -693,6 +780,7 @@ for _ in (): await foo() @@ -551,7 +514,7 @@ yield # error: 4, "yield", Stmt("yield", line-4) for _ in {1: 2, 3: 4}: -@@ -732,14 +826,17 @@ +@@ -701,14 +789,17 @@ for _ in " ".strip(): await foo() @@ -569,7 +532,7 @@ yield # error: 4, "yield", Stmt("yield", line-4) for _ in (*(1, 2),): -@@ -748,10 +845,12 @@ +@@ -717,10 +808,12 @@ for _ in {**{}}: await foo() @@ -582,7 +545,7 @@ yield # error: 4, "yield", Stmt("yield", line-4) for _ in {**{1: 2}}: -@@ -777,31 +876,38 @@ +@@ -746,31 +839,38 @@ for _ in {}: await foo() @@ -621,7 +584,7 @@ yield # error: 4, "yield", Stmt("yield", line-4) # while -@@ -815,6 +921,7 @@ +@@ -784,6 +884,7 @@ if ...: break await foo() @@ -629,7 +592,7 @@ yield # error: 4, "yield", Stmt("yield", line-6) while True: -@@ -827,6 +934,7 @@ +@@ -796,6 +897,7 @@ while False: await foo() # type: ignore[unreachable] @@ -637,7 +600,7 @@ yield # error: 4, "yield", Stmt("yield", line-4) while "hello": -@@ -836,11 +944,13 @@ +@@ -805,11 +907,13 @@ # false positive on containers while [1, 2]: await foo() @@ -651,7 +614,7 @@ yield # error: 4, "yield", Stmt("yield", line-5) # range with constant arguments also handled, see more extensive tests in 910 -@@ -858,10 +968,12 @@ +@@ -827,10 +931,12 @@ for i in range(1 + 1): # not handled await foo() @@ -664,7 +627,7 @@ yield # error: 4, "yield", Stmt("yield", line-4) for i in range(+3): -@@ -870,6 +982,7 @@ +@@ -839,6 +945,7 @@ for i in range(-3.5): # type: ignore await foo() @@ -672,7 +635,7 @@ yield # error: 4, "yield", Stmt("yield", line-4) # duplicated from 910 to have all range tests in one place -@@ -893,20 +1006,24 @@ +@@ -862,20 +969,24 @@ for i in range(10, 5): await foo() @@ -697,7 +660,7 @@ yield # error: 4, "yield", Stmt("yield", line-5) await foo() -@@ -930,6 +1047,7 @@ +@@ -899,6 +1010,7 @@ # guaranteed iteration and await in value, but test is not guaranteed [await foo() for x in range(10) if bar()] @@ -705,7 +668,7 @@ yield # error: 4, "yield", Stmt("yield", line-4) # guaranteed iteration and await in value -@@ -938,6 +1056,7 @@ +@@ -907,6 +1019,7 @@ # not guaranteed to iter [await foo() for x in bar()] @@ -713,7 +676,7 @@ yield # error: 4, "yield", Stmt("yield", line-4) # await statement in loop expression -@@ -949,10 +1068,12 @@ +@@ -918,10 +1031,12 @@ yield # safe {await foo() for x in bar()} @@ -726,7 +689,7 @@ yield # error: 4, "yield", Stmt("yield", line-4) # other than `await` can be in both key&val -@@ -964,9 +1085,11 @@ +@@ -933,9 +1048,11 @@ # generator expressions are never treated as safe (await foo() for x in range(10)) @@ -738,7 +701,7 @@ yield # error: 4, "yield", Stmt("yield", line-3) # async for always safe -@@ -979,27 +1102,33 @@ +@@ -948,27 +1065,33 @@ # other than in generator expression (... async for x in bar()) @@ -772,7 +735,7 @@ yield # error: 4, "yield", Stmt("yield", line-2) # multiple ifs -@@ -1007,6 +1136,7 @@ +@@ -976,6 +1099,7 @@ yield [... for x in range(10) for y in bar() if await foo() if await foo()] diff --git a/tests/autofix_files/trio91x_autofix.py b/tests/autofix_files/trio91x_autofix.py index cb4868ba..a0bc3d06 100644 --- a/tests/autofix_files/trio91x_autofix.py +++ b/tests/autofix_files/trio91x_autofix.py @@ -1,3 +1,4 @@ +# AUTOFIX from __future__ import annotations """Docstring for file @@ -107,25 +108,21 @@ async def bar(): await foo() -# not handled, but at least doesn't insert an unnecessary checkpoint -async def foo_singleline(): - await foo() - # fmt: off - yield; yield # TRIO911: 11, "yield", Statement("yield", lineno, 4) - # fmt: on - await foo() - - -# not autofixed -async def foo_singleline2(): - # fmt: off - yield; await foo() # TRIO911: 4, "yield", Statement("function definition", lineno-2) - # fmt: on +# Code coverage: visitors run when inside a sync function that has an async function. +# When sync funcs don't contain an async func the body is not visited. +def sync_func(): + async def async_func(): + ... - -# not autofixed -async def foo_singleline3(): - # fmt: off - if ...: yield # TRIO911: 12, "yield", Statement("function definition", lineno-2) - # fmt: on - await foo() + try: + ... + except: + ... + if ... and ...: + ... + while ...: + if ...: + continue + break + [... for i in range(5)] + return diff --git a/tests/autofix_files/trio91x_autofix.py.diff b/tests/autofix_files/trio91x_autofix.py.diff index 65d90259..2ab4724d 100644 --- a/tests/autofix_files/trio91x_autofix.py.diff +++ b/tests/autofix_files/trio91x_autofix.py.diff @@ -1,6 +1,6 @@ --- +++ -@@ -8,6 +8,7 @@ +@@ -9,6 +9,7 @@ # ARG --enable-visitor-codes-regex=(TRIO910)|(TRIO911) from typing import Any @@ -8,7 +8,7 @@ def bar() -> Any: -@@ -20,30 +21,38 @@ +@@ -21,30 +22,38 @@ async def foo1(): # TRIO910: 0, "exit", Statement("function definition", lineno) bar() @@ -47,7 +47,7 @@ yield # TRIO911: 8, "yield", Statement("yield", lineno) -@@ -66,8 +75,10 @@ +@@ -67,8 +76,10 @@ async def foo_while4(): while True: if ...: @@ -58,7 +58,7 @@ yield # TRIO911: 12, "yield", Statement("yield", lineno) # TRIO911: 12, "yield", Statement("yield", lineno-2) # TRIO911: 12, "yield", Statement("function definition", lineno-5) # TRIO911: 12, "yield", Statement("yield", lineno-2) # this warns about the yield on lineno-2 twice, since it can arrive here from it in two different ways -@@ -75,15 +86,19 @@ +@@ -76,15 +87,19 @@ # check state management of nested loops async def foo_nested_while(): while True: diff --git a/tests/eval_files/trio100.py b/tests/eval_files/trio100.py index b63cf5ee..d22fc9c7 100644 --- a/tests/eval_files/trio100.py +++ b/tests/eval_files/trio100.py @@ -1,4 +1,5 @@ # type: ignore +# AUTOFIX import trio @@ -67,18 +68,3 @@ async def foo(): async with random_ignored_library.fail_after(10): ... - - -async def function_name2(): - with ( - open("") as _, - trio.fail_after(10), # error: 8, "trio", "fail_after" - ): - ... - - with ( - trio.fail_after(5), # error: 8, "trio", "fail_after" - open("") as _, - trio.move_on_after(5), # error: 8, "trio", "move_on_after" - ): - ... diff --git a/tests/eval_files/trio100_noautofix.py b/tests/eval_files/trio100_noautofix.py new file mode 100644 index 00000000..f16d77b1 --- /dev/null +++ b/tests/eval_files/trio100_noautofix.py @@ -0,0 +1,24 @@ +import trio + + +# Doesn't autofix With's with multiple withitems +async def function_name2(): + with ( + open("") as _, + trio.fail_after(10), # error: 8, "trio", "fail_after" + ): + ... + + with ( + trio.fail_after(5), # error: 8, "trio", "fail_after" + open("") as _, + trio.move_on_after(5), # error: 8, "trio", "move_on_after" + ): + ... + + +with ( + trio.move_on_after(10), # error: 4, "trio", "move_on_after" + open("") as f, +): + ... diff --git a/tests/eval_files/trio100_simple_autofix.py b/tests/eval_files/trio100_simple_autofix.py index 46655c3d..2ddc609d 100644 --- a/tests/eval_files/trio100_simple_autofix.py +++ b/tests/eval_files/trio100_simple_autofix.py @@ -1,3 +1,4 @@ +# AUTOFIX import trio # a @@ -28,13 +29,6 @@ # c # d -# Doesn't autofix With's with multiple withitems -with ( - trio.move_on_after(10), # error: 4, "trio", "move_on_after" - open("") as f, -): - ... - # multiline with, despite only being one statement with ( # a diff --git a/tests/eval_files/trio910.py b/tests/eval_files/trio910.py index 10c8126d..32eed8ad 100644 --- a/tests/eval_files/trio910.py +++ b/tests/eval_files/trio910.py @@ -1,3 +1,4 @@ +# AUTOFIX # mypy: disable-error-code="unreachable" import typing from typing import Any, overload diff --git a/tests/eval_files/trio911.py b/tests/eval_files/trio911.py index 56a33ccd..746c5562 100644 --- a/tests/eval_files/trio911.py +++ b/tests/eval_files/trio911.py @@ -1,3 +1,4 @@ +# AUTOFIX from typing import Any import pytest @@ -64,14 +65,6 @@ async def foo_async_with(): yield -# fmt: off -async def foo_async_with_2(): - # with'd expression evaluated before checkpoint - async with (yield): # error: 16, "yield", Statement("function definition", lineno-2) - yield -# fmt: on - - async def foo_async_with_3(): async with trio.fail_after(5): yield @@ -79,10 +72,8 @@ async def foo_async_with_3(): # async for -async def foo_async_for(): # error: 0, "exit", Statement("yield", lineno+6) - async for i in ( - yield # error: 8, "yield", Statement("function definition", lineno-2) - ): +async def foo_async_for(): # error: 0, "exit", Statement("yield", lineno+4) + async for i in bar(): yield # safe else: yield # safe @@ -626,28 +617,6 @@ async def foo_boolops_1(): # error: 0, "exit", Stmt("yield", line+1) _ = await foo() and (yield) and await foo() -# may shortcut after any of the yields -async def foo_boolops_2(): # error: 0, "exit", Stmt("yield", line+4) # error: 0, "exit", Stmt("yield", line+6) - # known false positive - but chained yields in bool should be rare - _ = ( - await foo() - and (yield) - and await foo() - and (yield) # error: 13, "yield", Stmt("yield", line-2, 13) - ) - - -# fmt: off -async def foo_boolops_3(): # error: 0, "exit", Stmt("yield", line+1) # error: 0, "exit", Stmt("yield", line+4) # error: 0, "exit", Stmt("yield", line+5) - _ = (await foo() or (yield) or await foo()) or ( - ... - or ( - (yield) # error: 13, "yield", Stmt("yield", line-3) - and (yield)) # error: 17, "yield", Stmt("yield", line-1) - ) -# fmt: on - - # loop over non-empty static collection async def foo_loop_static(): # break/else behaviour on guaranteed body execution diff --git a/tests/eval_files/trio91x_autofix.py b/tests/eval_files/trio91x_autofix.py index c1f09607..ca46596f 100644 --- a/tests/eval_files/trio91x_autofix.py +++ b/tests/eval_files/trio91x_autofix.py @@ -1,3 +1,4 @@ +# AUTOFIX from __future__ import annotations """Docstring for file @@ -92,25 +93,21 @@ async def bar(): await foo() -# not handled, but at least doesn't insert an unnecessary checkpoint -async def foo_singleline(): - await foo() - # fmt: off - yield; yield # TRIO911: 11, "yield", Statement("yield", lineno, 4) - # fmt: on - await foo() - - -# not autofixed -async def foo_singleline2(): - # fmt: off - yield; await foo() # TRIO911: 4, "yield", Statement("function definition", lineno-2) - # fmt: on +# Code coverage: visitors run when inside a sync function that has an async function. +# When sync funcs don't contain an async func the body is not visited. +def sync_func(): + async def async_func(): + ... - -# not autofixed -async def foo_singleline3(): - # fmt: off - if ...: yield # TRIO911: 12, "yield", Statement("function definition", lineno-2) - # fmt: on - await foo() + try: + ... + except: + ... + if ... and ...: + ... + while ...: + if ...: + continue + break + [... for i in range(5)] + return diff --git a/tests/eval_files/trio91x_noautofix.py b/tests/eval_files/trio91x_noautofix.py new file mode 100644 index 00000000..1c7eee16 --- /dev/null +++ b/tests/eval_files/trio91x_noautofix.py @@ -0,0 +1,71 @@ +# ARG --enable-visitor-codes-regex=(TRIO910)|(TRIO911) +from typing import Any + + +async def foo() -> Any: + await foo() + + +# not handled, but at least doesn't insert an unnecessary checkpoint +async def foo_singleline(): + await foo() + # fmt: off + yield; yield # TRIO911: 11, "yield", Statement("yield", lineno, 4) + # fmt: on + await foo() + + +# not autofixed +async def foo_singleline2(): + # fmt: off + yield; await foo() # TRIO911: 4, "yield", Statement("function definition", lineno-2) + # fmt: on + + +# not autofixed +async def foo_singleline3(): + # fmt: off + if ...: yield # TRIO911: 12, "yield", Statement("function definition", lineno-2) + # fmt: on + await foo() + + +# fmt: off +async def foo_async_with_2(): + # with'd expression evaluated before checkpoint + async with (yield): # TRIO911: 16, "yield", Statement("function definition", lineno-2) + yield +# fmt: on + +# fmt: off +async def foo_boolops_3(): + _ = (await foo() or (yield) or await foo()) or ( + ... + or ( + (yield) # TRIO911: 13, "yield", Stmt("yield", line-3) + and (yield)) # TRIO911: 17, "yield", Stmt("yield", line-1) + ) + await foo() +# fmt: on + + +async def foo_async_for(): + async for i in ( + yield # TRIO911: 8, "yield", Statement("function definition", lineno-2) + ): + yield # safe + else: + yield # safe + await foo() + + +# may shortcut after any of the yields +async def foo_boolops_2(): + # known false positive - but chained yields in bool should be rare + _ = ( + await foo() + and (yield) + and await foo() + and (yield) # TRIO911: 13, "yield", Stmt("yield", line-2, 13) + ) + await foo() diff --git a/tests/test_flake8_trio.py b/tests/test_flake8_trio.py index 94aca2ff..35e51ecf 100644 --- a/tests/test_flake8_trio.py +++ b/tests/test_flake8_trio.py @@ -32,14 +32,13 @@ from flake8_trio.visitors.flake8triovisitor import Flake8TrioVisitor +AUTOFIX_DIR = Path(__file__).parent / "autofix_files" test_files: list[tuple[str, Path]] = sorted( (f.stem.upper(), f) for f in (Path(__file__).parent / "eval_files").iterdir() ) autofix_files: dict[str, Path] = { - f.stem.upper(): f - for f in (Path(__file__).parent / "autofix_files").iterdir() - if f.suffix == ".py" + f.stem.upper(): f for f in AUTOFIX_DIR.iterdir() if f.suffix == ".py" } # check that there's an eval file for each autofix file assert set(autofix_files.keys()) - {f[0] for f in test_files} == set() @@ -110,18 +109,22 @@ def check_autofix( generate_autofix: bool, anyio: bool = False, ): - if test not in autofix_files: - return # the source code after it's been visited by current transformers visited_code = plugin.module.code + + if "# AUTOFIX" not in unfixed_code: + assert unfixed_code == visited_code + return + # the full generated source code, saved from a previous run + if test not in autofix_files: + autofix_files[test] = AUTOFIX_DIR / (test.lower() + ".py") + autofix_files[test].write_text("") previous_autofixed = autofix_files[test].read_text() # file contains a previous diff showing what's added/removed by the autofixer # i.e. a diff between "eval_files/{test}.py" and "autofix_files/{test}.py" - autofix_diff_file = ( - Path(__file__).parent / "autofix_files" / f"{test.lower()}.py.diff" - ) + autofix_diff_file = AUTOFIX_DIR / f"{test.lower()}.py.diff" if not autofix_diff_file.exists(): assert generate_autofix, "autofix diff file doesn't exist" # if generate_autofix is set, the diff content isn't used and the file @@ -219,6 +222,26 @@ def test_eval_anyio(test: str, path: Path, generate_autofix: bool): check_autofix(test, plugin, content, generate_autofix, anyio=True) +# check that autofixed files raise no errors and doesn't get autofixed (again) +@pytest.mark.parametrize("test", autofix_files) +def test_autofix(test: str): + content = autofix_files[test].read_text() + if "# NOTRIO" in content: + pytest.skip("file marked with NOTRIO") + + _, parsed_args = _parse_eval_file(test, content) + parsed_args.append("--autofix") + + plugin = Plugin.from_source(content) + # not passing any expected errors + _ = assert_expected_errors(plugin, args=parsed_args) + + diff = diff_strings(plugin.module.code, content) + if diff: + print(diff) + assert plugin.module.code == content, "autofixed file changed when autofixed again" + + def _parse_eval_file(test: str, content: str) -> tuple[list[Error], list[str]]: # version check check_version(test) @@ -298,6 +321,7 @@ def _parse_eval_file(test: str, content: str) -> tuple[list[Error], list[str]]: ) raise ParseError(msg) from e + assert visitor_codes_regex, "no visitors enabled" for error in expected: assert re.match( visitor_codes_regex, error.code