Skip to content

Commit e7e4a95

Browse files
committed
Implement basic autofix infrastructure and autofixer for TRIO100
1 parent 881b16a commit e7e4a95

File tree

13 files changed

+320
-31
lines changed

13 files changed

+320
-31
lines changed

flake8_trio/__init__.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,10 @@ def main():
101101
cwd=root,
102102
).stdout.splitlines()
103103
except (subprocess.SubprocessError, FileNotFoundError):
104-
print("Doesn't seem to be a git repo; pass filenames to format.")
104+
print(
105+
"Doesn't seem to be a git repo; pass filenames to format.",
106+
file=sys.stderr,
107+
)
105108
sys.exit(1)
106109
all_filenames = [
107110
os.path.join(root, f) for f in all_filenames if _should_format(f)
@@ -110,6 +113,9 @@ def main():
110113
plugin = Plugin.from_filename(file)
111114
for error in sorted(plugin.run()):
112115
print(f"{file}:{error}")
116+
if plugin.options.autofix:
117+
with open(file, "w") as file:
118+
file.write(plugin.module.code)
113119

114120

115121
class Plugin:
@@ -122,7 +128,7 @@ def __init__(self, tree: ast.AST, lines: Sequence[str]):
122128
self._tree = tree
123129
source = "".join(lines)
124130

125-
self._module: cst.Module = cst_parse_module_native(source)
131+
self.module: cst.Module = cst_parse_module_native(source)
126132

127133
@classmethod
128134
def from_filename(cls, filename: str | PathLike[str]) -> Plugin: # pragma: no cover
@@ -137,12 +143,14 @@ def from_source(cls, source: str) -> Plugin:
137143
plugin = Plugin.__new__(cls)
138144
super(Plugin, plugin).__init__()
139145
plugin._tree = ast.parse(source)
140-
plugin._module = cst_parse_module_native(source)
146+
plugin.module = cst_parse_module_native(source)
141147
return plugin
142148

143149
def run(self) -> Iterable[Error]:
144150
yield from Flake8TrioRunner.run(self._tree, self.options)
145-
yield from Flake8TrioRunner_cst(self.options).run(self._module)
151+
cst_runner = Flake8TrioRunner_cst(self.options, self.module)
152+
yield from cst_runner.run()
153+
self.module = cst_runner.module
146154

147155
@staticmethod
148156
def add_options(option_manager: OptionManager | ArgumentParser):
@@ -157,6 +165,7 @@ def add_options(option_manager: OptionManager | ArgumentParser):
157165
add_argument = functools.partial(
158166
option_manager.add_option, parse_from_config=True
159167
)
168+
add_argument("--autofix", action="store_true", required=False)
160169

161170
add_argument(
162171
"--no-checkpoint-warning-decorators",

flake8_trio/runner.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,20 +100,20 @@ def visit(self, node: ast.AST):
100100

101101

102102
class Flake8TrioRunner_cst:
103-
def __init__(self, options: Namespace):
103+
def __init__(self, options: Namespace, module: Module):
104104
super().__init__()
105105
self.state = SharedState(options)
106106
self.options = options
107107
self.visitors: tuple[Flake8TrioVisitor_cst, ...] = tuple(
108108
v(self.state) for v in ERROR_CLASSES_CST if self.selected(v.error_codes)
109109
)
110+
self.module = module
110111

111-
def run(self, module: Module) -> Iterable[Error]:
112+
def run(self) -> Iterable[Error]:
112113
if not self.visitors:
113114
return
114-
wrapper = cst.MetadataWrapper(module)
115115
for v in self.visitors:
116-
_ = wrapper.visit(v)
116+
self.module = cst.MetadataWrapper(self.module).visit(v)
117117
yield from self.state.problems
118118

119119
def selected(self, error_codes: dict[str, str]) -> bool:

flake8_trio/visitors/helpers.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,3 +341,38 @@ def func_has_decorator(func: cst.FunctionDef, *names: str) -> bool:
341341
),
342342
)
343343
)
344+
345+
346+
# used in TRIO100
347+
def flatten_preserving_comments(node: cst.BaseCompoundStatement):
348+
if isinstance(node.body, cst.SimpleStatementSuite):
349+
# `with ...: pass;pass;pass` -> pass;pass;pass
350+
return cst.SimpleStatementLine(node.body.body, leading_lines=node.leading_lines)
351+
352+
assert isinstance(node.body, cst.IndentedBlock)
353+
nodes = list(node.body.body)
354+
355+
# nodes[0] is a BaseStatement, whose subclasses are SimpleStatementLine
356+
# and BaseCompoundStatement - both of which has leading_lines
357+
assert isinstance(nodes[0], (cst.SimpleStatementLine, cst.BaseCompoundStatement))
358+
359+
# add leading lines of the original node to the leading lines
360+
# of the first statement in the body
361+
new_leading_lines = list(node.leading_lines)
362+
if node.body.header and node.body.header.comment:
363+
new_leading_lines.append(
364+
cst.EmptyLine(indent=True, comment=node.body.header.comment)
365+
)
366+
new_leading_lines.extend(nodes[0].leading_lines)
367+
nodes[0] = nodes[0].with_changes(leading_lines=new_leading_lines)
368+
369+
# if there's comments in the footer of the indented block, add a pass
370+
# statement with the comments as leading lines
371+
if node.body.footer:
372+
nodes.append(
373+
cst.SimpleStatementLine(
374+
[cst.Pass()],
375+
node.body.footer,
376+
)
377+
)
378+
return cst.FlattenSentinel(nodes)

flake8_trio/visitors/visitor100.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,21 @@
77
"""
88
from __future__ import annotations
99

10-
from typing import Any
10+
from typing import TYPE_CHECKING, Any
1111

1212
import libcst as cst
1313
import libcst.matchers as m
1414

1515
from .flake8triovisitor import Flake8TrioVisitor_cst
16-
from .helpers import AttributeCall, error_class_cst, with_has_call
16+
from .helpers import (
17+
AttributeCall,
18+
error_class_cst,
19+
flatten_preserving_comments,
20+
with_has_call,
21+
)
22+
23+
if TYPE_CHECKING:
24+
pass
1725

1826

1927
@error_class_cst
@@ -46,12 +54,16 @@ def visit_With(self, node: cst.With) -> None:
4654
else:
4755
self.has_checkpoint_stack.append(True)
4856

49-
def leave_With(self, original_node: cst.With, updated_node: cst.With) -> cst.With:
57+
def leave_With(
58+
self, original_node: cst.With, updated_node: cst.With
59+
) -> cst.BaseStatement | cst.FlattenSentinel[cst.BaseStatement]:
5060
if not self.has_checkpoint_stack.pop():
5161
for res in self.node_dict[original_node]:
5262
self.error(res.node, res.base, res.function)
53-
# if: autofixing is enabled for this code
54-
# then: remove the with and pop out it's body
63+
64+
if self.options.autofix and len(updated_node.items) == 1:
65+
return flatten_preserving_comments(updated_node)
66+
5567
return updated_node
5668

5769
def visit_For(self, node: cst.For):

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.pyright]
22
strict = ["*.py", "tests/*.py", "flake8_trio/**/*.py"]
3-
exclude = ["**/node_modules", "**/__pycache__", "**/.*"]
3+
exclude = ["**/node_modules", "**/__pycache__", "**/.*", "tests/autofix_files/*"]
44
reportUnusedCallResult=false
55
reportUninitializedInstanceVariable=true
66
reportPropertyTypeMismatch=true

tests/autofix_files/trio100.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# type: ignore
2+
3+
import trio
4+
5+
# error: 5, "trio", "move_on_after"
6+
...
7+
8+
9+
async def function_name():
10+
# fmt: off
11+
...; ...; ...
12+
# fmt: on
13+
# error: 15, "trio", "fail_after"
14+
...
15+
# error: 15, "trio", "fail_at"
16+
...
17+
# error: 15, "trio", "move_on_after"
18+
...
19+
# error: 15, "trio", "move_on_at"
20+
...
21+
# error: 15, "trio", "CancelScope"
22+
...
23+
24+
with trio.move_on_after(10):
25+
await trio.sleep(1)
26+
27+
with trio.move_on_after(10):
28+
await trio.sleep(1)
29+
print("hello")
30+
31+
with trio.move_on_after(10):
32+
while True:
33+
await trio.sleep(1)
34+
print("hello")
35+
36+
with open("filename") as _:
37+
...
38+
39+
# error: 9, "trio", "fail_after"
40+
...
41+
42+
send_channel, receive_channel = trio.open_memory_channel(0)
43+
async with trio.fail_after(10):
44+
async with send_channel:
45+
...
46+
47+
async with trio.fail_after(10):
48+
async for _ in receive_channel:
49+
...
50+
51+
# error: 15, "trio", "fail_after"
52+
for _ in receive_channel:
53+
...
54+
55+
# fix missed alarm when function is defined inside the with scope
56+
# error: 9, "trio", "move_on_after"
57+
58+
async def foo():
59+
await trio.sleep(1)
60+
61+
# error: 9, "trio", "move_on_after"
62+
if ...:
63+
64+
async def foo():
65+
if ...:
66+
await trio.sleep(1)
67+
68+
async with random_ignored_library.fail_after(10):
69+
...
70+
71+
72+
async def function_name2():
73+
with (
74+
open("") as _,
75+
trio.fail_after(10), # error: 8, "trio", "fail_after"
76+
):
77+
...
78+
79+
with (
80+
trio.fail_after(5), # error: 8, "trio", "fail_after"
81+
open("") as _,
82+
trio.move_on_after(5), # error: 8, "trio", "move_on_after"
83+
):
84+
...
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import trio
2+
3+
# a
4+
# b
5+
# error: 5, "trio", "move_on_after"
6+
# c
7+
# d
8+
print(1) # e
9+
# f
10+
# g
11+
print(2) # h
12+
# i
13+
# j
14+
print(3) # k
15+
# l
16+
# m
17+
pass
18+
# n
19+
20+
# error: 5, "trio", "move_on_after"
21+
...
22+
23+
24+
# a
25+
# b
26+
# fmt: off
27+
...;...;...
28+
# fmt: on
29+
# c
30+
# d
31+
32+
# Doesn't autofix With's with multiple withitems
33+
with (
34+
trio.move_on_after(10), # error: 4, "trio", "move_on_after"
35+
open("") as f,
36+
):
37+
...
38+
39+
40+
# extreme case I'm not gonna care about, i.e. one item in the with, but it's multiline.
41+
# Only these leading comments, and the last one, are kept, the rest are lost.
42+
# this comment is kept
43+
...

tests/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ def pytest_addoption(parser: pytest.Parser):
99
parser.addoption(
1010
"--runfuzz", action="store_true", default=False, help="run fuzz tests"
1111
)
12+
parser.addoption(
13+
"--generate-autofix",
14+
action="store_true",
15+
default=False,
16+
help="generate autofix file content",
17+
)
1218
parser.addoption(
1319
"--enable-visitor-codes-regex",
1420
default=".*",
@@ -32,6 +38,11 @@ def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item
3238
item.add_marker(skip_fuzz)
3339

3440

41+
@pytest.fixture()
42+
def generate_autofix(request: pytest.FixtureRequest):
43+
return request.config.getoption("generate_autofix")
44+
45+
3546
@pytest.fixture()
3647
def enable_visitor_codes_regex(request: pytest.FixtureRequest):
3748
return request.config.getoption("--enable-visitor-codes-regex")

0 commit comments

Comments
 (0)