Skip to content

Commit 4eb714c

Browse files
committed
add args
1 parent 112b998 commit 4eb714c

28 files changed

+359
-171
lines changed

flake8_trio/__init__.py

Lines changed: 108 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import functools
1616
import keyword
1717
import os
18-
import re
1918
import subprocess
2019
import sys
2120
import tokenize
@@ -24,8 +23,9 @@
2423

2524
import libcst as cst
2625

26+
from .base import Options
2727
from .runner import Flake8TrioRunner, Flake8TrioRunner_cst
28-
from .visitors import default_disabled_error_codes
28+
from .visitors import ERROR_CLASSES, ERROR_CLASSES_CST, default_disabled_error_codes
2929

3030
if TYPE_CHECKING:
3131
from collections.abc import Iterable, Sequence
@@ -76,12 +76,6 @@ def cst_parse_module_native(source: str) -> cst.Module:
7676

7777
def main() -> int:
7878
parser = ArgumentParser(prog="flake8_trio")
79-
parser.add_argument(
80-
nargs="*",
81-
metavar="file",
82-
dest="files",
83-
help="Files(s) to format, instead of autodetection.",
84-
)
8579
Plugin.add_options(parser)
8680
args = parser.parse_args()
8781
Plugin.parse_options(args)
@@ -124,7 +118,13 @@ def main() -> int:
124118
class Plugin:
125119
name = __name__
126120
version = __version__
127-
options: Namespace = Namespace()
121+
standalone = True
122+
_options: Options | None = None
123+
124+
@property
125+
def options(self) -> Options:
126+
assert self._options is not None
127+
return self._options
128128

129129
def __init__(self, tree: ast.AST, lines: Sequence[str]):
130130
super().__init__()
@@ -158,18 +158,64 @@ def run(self) -> Iterable[Error]:
158158
@staticmethod
159159
def add_options(option_manager: OptionManager | ArgumentParser):
160160
if isinstance(option_manager, ArgumentParser):
161-
# TODO: disable TRIO9xx calls by default
162-
# if run as standalone
161+
Plugin.standalone = True
163162
add_argument = option_manager.add_argument
163+
add_argument(
164+
nargs="*",
165+
metavar="file",
166+
dest="files",
167+
help="Files(s) to format, instead of autodetection.",
168+
)
164169
else: # if run as a flake8 plugin
170+
Plugin.standalone = False
165171
# Disable TRIO9xx calls by default
166172
option_manager.extend_default_ignore(default_disabled_error_codes)
167173
# add parameter to parse from flake8 config
168174
add_argument = functools.partial( # type: ignore
169175
option_manager.add_option, parse_from_config=True
170176
)
171-
add_argument("--autofix", action="store_true", required=False)
172177

178+
add_argument(
179+
"--enable",
180+
type=comma_separated_list,
181+
default="TRIO",
182+
required=False,
183+
help=(
184+
"Comma-separated list of error codes to enable, similar to flake8"
185+
" --select but is additionally more performant as it will disable"
186+
" non-enabled visitors from running instead of just silencing their"
187+
" errors."
188+
),
189+
)
190+
add_argument(
191+
"--disable",
192+
type=comma_separated_list,
193+
default="TRIO9" if Plugin.standalone else "",
194+
required=False,
195+
help=(
196+
"Comma-separated list of error codes to disable, similar to flake8"
197+
" --ignore but is additionally more performant as it will disable"
198+
" non-enabled visitors from running instead of just silencing their"
199+
" errors."
200+
),
201+
)
202+
add_argument(
203+
"--autofix",
204+
type=comma_separated_list,
205+
default="",
206+
required=False,
207+
help=(
208+
"Comma-separated list of error-codes to enable autofixing for"
209+
"if implemented. Requires running as a standalone program."
210+
),
211+
)
212+
add_argument(
213+
"--error-on-autofix",
214+
action="store_true",
215+
required=False,
216+
default=False,
217+
help="Whether to also print an error message for autofixed errors",
218+
)
173219
add_argument(
174220
"--no-checkpoint-warning-decorators",
175221
default="asynccontextmanager",
@@ -208,19 +254,6 @@ def add_options(option_manager: OptionManager | ArgumentParser):
208254
"suggesting it be replaced with {value}"
209255
),
210256
)
211-
add_argument(
212-
"--enable-visitor-codes-regex",
213-
type=re.compile, # type: ignore[arg-type]
214-
default=".*",
215-
required=False,
216-
help=(
217-
"Regex string of visitors to enable. Can be used to disable broken "
218-
"visitors, or instead of --select/--disable to select error codes "
219-
"in a way that is more performant. If a visitor raises multiple codes "
220-
"it will not be disabled unless all codes are disabled, but it will "
221-
"not report codes matching this regex."
222-
),
223-
)
224257
add_argument(
225258
"--anyio",
226259
# action=store_true + parse_from_config does seem to work here, despite
@@ -237,7 +270,56 @@ def add_options(option_manager: OptionManager | ArgumentParser):
237270

238271
@staticmethod
239272
def parse_options(options: Namespace):
240-
Plugin.options = options
273+
def get_matching_codes(
274+
patterns: Iterable[str], codes: Iterable[str], msg: str
275+
) -> Iterable[str]:
276+
for pattern in patterns:
277+
for code in codes:
278+
if code.lower().startswith(pattern.lower()):
279+
yield code
280+
281+
autofix_codes: set[str] = set()
282+
enabled_codes: set[str] = set()
283+
disabled_codes: set[str] = {
284+
err_code
285+
for err_class in (*ERROR_CLASSES, *ERROR_CLASSES_CST)
286+
for err_code in err_class.error_codes.keys() # type: ignore[attr-defined]
287+
if len(err_code) == 7 # exclude e.g. TRIO103_anyio_trio
288+
}
289+
290+
if options.autofix:
291+
if not Plugin.standalone:
292+
print("Cannot autofix when run as a flake8 plugin.", file=sys.stderr)
293+
sys.exit(1)
294+
autofix_codes.update(
295+
get_matching_codes(options.autofix, disabled_codes, "autofix")
296+
)
297+
298+
# enable codes
299+
tmp = set(get_matching_codes(options.enable, disabled_codes, "enable"))
300+
enabled_codes |= tmp
301+
disabled_codes -= tmp
302+
303+
# disable codes
304+
tmp = set(get_matching_codes(options.disable, enabled_codes, "disable"))
305+
306+
# if disable has default value, don't disable explicitly enabled codes
307+
if options.disable == ["TRIO9"]:
308+
tmp -= {code for code in options.enable if len(code) == 7}
309+
310+
disabled_codes |= tmp
311+
enabled_codes -= tmp
312+
313+
Plugin._options = Options(
314+
enable=enabled_codes,
315+
disable=disabled_codes,
316+
autofix=autofix_codes,
317+
error_on_autofix=options.error_on_autofix,
318+
no_checkpoint_warning_decorators=options.no_checkpoint_warning_decorators,
319+
startable_in_context_manager=options.startable_in_context_manager,
320+
trio200_blocking_calls=options.trio200_blocking_calls,
321+
anyio=options.anyio,
322+
)
241323

242324

243325
def comma_separated_list(raw_value: str) -> list[str]:

flake8_trio/base.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,24 @@
22

33
from __future__ import annotations
44

5-
from typing import Any, NamedTuple
5+
from dataclasses import dataclass
6+
from typing import TYPE_CHECKING, Any, NamedTuple
7+
8+
if TYPE_CHECKING:
9+
from collections.abc import Collection
10+
11+
12+
@dataclass
13+
class Options:
14+
# enable and disable have been expanded to contain full codes, and have no overlap
15+
enable: Collection[str]
16+
disable: Collection[str]
17+
autofix: Collection[str]
18+
error_on_autofix: bool
19+
no_checkpoint_warning_decorators: Collection[str]
20+
startable_in_context_manager: Collection[str]
21+
trio200_blocking_calls: dict[str, str]
22+
anyio: bool
623

724

825
class Statement(NamedTuple):

flake8_trio/runner.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from __future__ import annotations
88

99
import ast
10-
import re
10+
from abc import ABC
1111
from dataclasses import dataclass, field
1212
from typing import TYPE_CHECKING
1313

@@ -21,44 +21,48 @@
2121
)
2222

2323
if TYPE_CHECKING:
24-
from argparse import Namespace
2524
from collections.abc import Iterable
2625

2726
from libcst import Module
2827

29-
from .base import Error
28+
from .base import Error, Options
3029
from .visitors.flake8triovisitor import Flake8TrioVisitor, Flake8TrioVisitor_cst
3130

3231

3332
@dataclass
3433
class SharedState:
35-
options: Namespace
34+
options: Options
3635
problems: list[Error] = field(default_factory=list)
3736
library: tuple[str, ...] = ()
3837
typed_calls: dict[str, str] = field(default_factory=dict)
3938
variables: dict[str, str] = field(default_factory=dict)
4039

4140

42-
class Flake8TrioRunner(ast.NodeVisitor):
43-
def __init__(self, options: Namespace):
41+
class CommonRunner(ABC): # noqa: B024 # no abstract methods
42+
def __init__(self, options: Options):
4443
super().__init__()
4544
self.state = SharedState(options)
4645

46+
def selected(self, error_codes: dict[str, str]) -> bool:
47+
for code in error_codes:
48+
for enabled in self.state.options.enable, self.state.options.autofix:
49+
if code in enabled:
50+
return True
51+
return False
52+
53+
54+
class Flake8TrioRunner(ast.NodeVisitor, CommonRunner):
55+
def __init__(self, options: Options):
56+
super().__init__(options)
4757
# utility visitors that need to run before the error-checking visitors
4858
self.utility_visitors = {v(self.state) for v in utility_visitors}
4959

5060
self.visitors = {
5161
v(self.state) for v in ERROR_CLASSES if self.selected(v.error_codes)
5262
}
5363

54-
def selected(self, error_codes: dict[str, str]) -> bool:
55-
return any(
56-
re.match(self.state.options.enable_visitor_codes_regex, code)
57-
for code in error_codes
58-
)
59-
6064
@classmethod
61-
def run(cls, tree: ast.AST, options: Namespace) -> Iterable[Error]:
65+
def run(cls, tree: ast.AST, options: Options) -> Iterable[Error]:
6266
runner = cls(options)
6367
runner.visit(tree)
6468
yield from runner.state.problems
@@ -104,10 +108,9 @@ def visit(self, node: ast.AST):
104108
subclass.set_state(subclass.outer.pop(node, {}))
105109

106110

107-
class Flake8TrioRunner_cst:
108-
def __init__(self, options: Namespace, module: Module):
109-
super().__init__()
110-
self.state = SharedState(options)
111+
class Flake8TrioRunner_cst(CommonRunner):
112+
def __init__(self, options: Options, module: Module):
113+
super().__init__(options)
111114
self.options = options
112115

113116
# Could possibly enable/disable utility visitors here, if visitors declared
@@ -127,9 +130,3 @@ def run(self) -> Iterable[Error]:
127130
for v in (*self.utility_visitors, *self.visitors):
128131
self.module = cst.MetadataWrapper(self.module).visit(v)
129132
yield from self.state.problems
130-
131-
def selected(self, error_codes: dict[str, str]) -> bool:
132-
return any(
133-
re.match(self.state.options.enable_visitor_codes_regex, code)
134-
for code in error_codes
135-
)

flake8_trio/visitors/flake8triovisitor.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from __future__ import annotations
44

55
import ast
6-
import re
76
from abc import ABC
87
from typing import TYPE_CHECKING, Any, Union
98

@@ -99,7 +98,7 @@ def error(
9998
), "No error code defined, but class has multiple codes"
10099
error_code = next(iter(self.error_codes))
101100
# don't emit an error if this code is disabled in a multi-code visitor
102-
elif not re.match(self.options.enable_visitor_codes_regex, error_code):
101+
elif error_code in self.options.disable:
103102
return
104103

105104
self.__state.problems.append(
@@ -190,9 +189,7 @@ def set_state(self, attrs: dict[str, Any], copy: bool = False):
190189
def save_state(self, node: cst.CSTNode, *attrs: str, copy: bool = False):
191190
state = self.get_state(*attrs, copy=copy)
192191
if node in self.outer:
193-
# not currently used, and not gonna bother adding dedicated test
194-
# visitors atm
195-
self.outer[node].update(state) # pragma: no cover
192+
self.outer[node].update(state)
196193
else:
197194
self.outer[node] = state
198195

@@ -211,10 +208,9 @@ def error(
211208
), "No error code defined, but class has multiple codes"
212209
error_code = next(iter(self.error_codes))
213210
# don't emit an error if this code is disabled in a multi-code visitor
214-
elif not re.match(
215-
self.options.enable_visitor_codes_regex, error_code
216-
): # pragma: no cover
217-
return
211+
# TODO: write test for only one of 910/911 enabled/autofixed
212+
elif error_code in self.options.disable:
213+
return # pragma: no cover
218214
pos = self.get_metadata(PositionProvider, node).start
219215

220216
self.__state.problems.append(
@@ -228,6 +224,12 @@ def error(
228224
)
229225
)
230226

227+
def autofix(self, code: str | None = None):
228+
if code is None:
229+
assert len(self.error_codes) == 1
230+
code = next(iter(self.error_codes))
231+
return code in self.options.autofix
232+
231233
@property
232234
def library(self) -> tuple[str, ...]:
233235
return self.__state.library if self.__state.library else ("trio",)

0 commit comments

Comments
 (0)