Skip to content

reimplement trio100 & trio101 with libcst #142

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 36 additions & 4 deletions flake8_trio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,19 @@

import ast
import keyword
import os
import re
import tokenize
from argparse import ArgumentTypeError, Namespace
from typing import TYPE_CHECKING

from .runner import Flake8TrioRunner
import libcst as cst

from .runner import Flake8TrioRunner, Flake8TrioRunner_cst
from .visitors import default_disabled_error_codes

if TYPE_CHECKING:
from collections.abc import Iterable
from collections.abc import Iterable, Sequence
from os import PathLike

from flake8.options.manager import OptionManager
Expand All @@ -34,24 +37,53 @@
__version__ = "23.2.5"


# Enable support in libcst for new grammar
# See e.g. https://github.com/Instagram/LibCST/issues/862
# wrapping the call and restoring old values in case there's other libcst parsers
# in the same environment, which we don't wanna mess up.
def cst_parse_module_native(source: str) -> cst.Module:
var = os.environ.get("LIBCST_PARSER_TYPE")
try:
os.environ["LIBCST_PARSER_TYPE"] = "native"
mod = cst.parse_module(source)
finally:
del os.environ["LIBCST_PARSER_TYPE"]
if var is not None: # pragma: no cover
os.environ["LIBCST_PARSER_TYPE"] = var
return mod


class Plugin:
name = __name__
version = __version__
options: Namespace = Namespace()

def __init__(self, tree: ast.AST):
def __init__(self, tree: ast.AST, lines: Sequence[str]):
super().__init__()
self._tree = tree
source = "".join(lines)

self._module: cst.Module = cst_parse_module_native(source)

@classmethod
def from_filename(cls, filename: str | PathLike[str]) -> Plugin: # pragma: no cover
# only used with --runslow
with tokenize.open(filename) as f:
source = f.read()
return cls(ast.parse(source))
return cls.from_source(source)

# alternative `__init__` to avoid re-splitting and/or re-joining lines
@classmethod
def from_source(cls, source: str) -> Plugin:
plugin = Plugin.__new__(cls)
super(Plugin, plugin).__init__()
plugin._tree = ast.parse(source)
plugin._module = cst_parse_module_native(source)
return plugin

def run(self) -> Iterable[Error]:
yield from Flake8TrioRunner.run(self._tree, self.options)
yield from Flake8TrioRunner_cst(self.options).run(self._module)

@staticmethod
def add_options(option_manager: OptionManager):
Expand Down
32 changes: 30 additions & 2 deletions flake8_trio/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,18 @@
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

from .visitors import ERROR_CLASSES, utility_visitors
import libcst as cst

from .visitors import ERROR_CLASSES, ERROR_CLASSES_CST, utility_visitors

if TYPE_CHECKING:
from argparse import Namespace
from collections.abc import Iterable

from libcst import Module

from .base import Error
from .visitors.flake8triovisitor import Flake8TrioVisitor
from .visitors.flake8triovisitor import Flake8TrioVisitor, Flake8TrioVisitor_cst


@dataclass
Expand Down Expand Up @@ -93,3 +97,27 @@ def visit(self, node: ast.AST):
# restore any outer state that was saved in the visitor method
for subclass in *self.utility_visitors, *self.visitors:
subclass.set_state(subclass.outer.pop(node, {}))


class Flake8TrioRunner_cst:
def __init__(self, options: Namespace):
super().__init__()
self.state = SharedState(options)
self.options = options
self.visitors: tuple[Flake8TrioVisitor_cst, ...] = tuple(
v(self.state) for v in ERROR_CLASSES_CST if self.selected(v.error_codes)
)

def run(self, module: Module) -> Iterable[Error]:
if not self.visitors:
return
wrapper = cst.MetadataWrapper(module)
for v in self.visitors:
_ = wrapper.visit(v)
yield from self.state.problems

def selected(self, error_codes: dict[str, str]) -> bool:
return any(
re.match(self.state.options.enable_visitor_codes_regex, code)
for code in error_codes
)
5 changes: 4 additions & 1 deletion flake8_trio/visitors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,19 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from .flake8triovisitor import Flake8TrioVisitor
from .flake8triovisitor import Flake8TrioVisitor, Flake8TrioVisitor_cst

__all__ = ["ERROR_CLASSES", "default_disabled_error_codes", "utility_visitors"]
ERROR_CLASSES: set[type[Flake8TrioVisitor]] = set()
ERROR_CLASSES_CST: set[type[Flake8TrioVisitor_cst]] = set()
utility_visitors: set[type[Flake8TrioVisitor]] = set()
default_disabled_error_codes: list[str] = []

# Import all visitors so their decorators run, filling the above containers
# This has to be done at the end to avoid circular imports
from . import visitor_utility # isort: skip
from . import visitor100 # isort: skip
from . import visitor101 # isort: skip
from . import visitor102 # isort: skip
from . import visitor103_104 # isort: skip
from . import visitor105 # isort: skip
Expand Down
73 changes: 73 additions & 0 deletions flake8_trio/visitors/flake8triovisitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from abc import ABC
from typing import TYPE_CHECKING, Any, Union

import libcst as cst
import libcst.matchers as m
from libcst.metadata import PositionProvider

from ..base import Error, Statement

if TYPE_CHECKING:
Expand Down Expand Up @@ -157,3 +161,72 @@ def library_str(self) -> str:
def add_library(self, name: str) -> None:
if name not in self.__state.library:
self.__state.library = self.__state.library + (name,)


class Flake8TrioVisitor_cst(m.MatcherDecoratableTransformer, ABC):
# abstract attribute by not providing a value
error_codes: dict[str, str] # pyright: reportUninitializedInstanceVariable=false
METADATA_DEPENDENCIES = (PositionProvider,)

def __init__(self, shared_state: SharedState):
super().__init__()
self.outer: dict[cst.BaseStatement, dict[str, Any]] = {}
self.__state = shared_state

self.options = self.__state.options

def get_state(self, *attrs: str, copy: bool = False) -> dict[str, Any]:
# require attrs, since we inherit a *ton* of stuff which we don't want to copy
assert attrs
res: dict[str, Any] = {}
for attr in attrs:
value = getattr(self, attr)
if copy and hasattr(value, "copy"):
value = value.copy()
res[attr] = value
return res

def set_state(self, attrs: dict[str, Any], copy: bool = False):
for attr, value in attrs.items():
if copy and hasattr(value, "copy"): # pragma: no cover
# not used by visitors yet
value = value.copy()
setattr(self, attr, value)

def save_state(self, node: cst.BaseStatement, *attrs: str, copy: bool = False):
state = self.get_state(*attrs, copy=copy)
if node in self.outer:
# not currently used, and not gonna bother adding dedicated test
# visitors atm
self.outer[node].update(state) # pragma: no cover
else:
self.outer[node] = state

def error(
self,
node: cst.CSTNode,
*args: str | Statement | int,
error_code: str | None = None,
):
if error_code is None:
assert (
len(self.error_codes) == 1
), "No error code defined, but class has multiple codes"
error_code = next(iter(self.error_codes))
# don't emit an error if this code is disabled in a multi-code visitor
elif not re.match(
self.options.enable_visitor_codes_regex, error_code
): # pragma: no cover
return
pos = self.get_metadata(PositionProvider, node).start

self.__state.problems.append(
Error(
# 7 == len('TRIO...'), so alt messages raise the original code
error_code[:7],
pos.line,
pos.column,
self.error_codes[error_code],
*args,
)
)
86 changes: 82 additions & 4 deletions flake8_trio/visitors/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,26 @@

import ast
from fnmatch import fnmatch
from typing import TYPE_CHECKING, TypeVar
from typing import TYPE_CHECKING, NamedTuple, TypeVar

import libcst as cst
import libcst.matchers as m

from ..base import Statement
from . import ERROR_CLASSES, default_disabled_error_codes, utility_visitors
from . import (
ERROR_CLASSES,
ERROR_CLASSES_CST,
default_disabled_error_codes,
utility_visitors,
)

if TYPE_CHECKING:
from collections.abc import Iterable
from collections.abc import Iterable, Iterator, Sequence

from .flake8triovisitor import Flake8TrioVisitor, HasLineCol
from .flake8triovisitor import Flake8TrioVisitor, Flake8TrioVisitor_cst, HasLineCol

T = TypeVar("T", bound=Flake8TrioVisitor)
T_CST = TypeVar("T_CST", bound=Flake8TrioVisitor_cst)


def error_class(error_class: type[T]) -> type[T]:
Expand All @@ -26,6 +35,12 @@ def error_class(error_class: type[T]) -> type[T]:
return error_class


def error_class_cst(error_class: type[T_CST]) -> type[T_CST]:
assert error_class.error_codes
ERROR_CLASSES_CST.add(error_class)
return error_class


def disabled_by_default(error_class: type[T]) -> type[T]:
assert error_class.error_codes
default_disabled_error_codes.extend(error_class.error_codes)
Expand Down Expand Up @@ -174,3 +189,66 @@ def get_matching_call(
):
return node, node.func.attr, node.func.value.id
return None


# ___ CST helpers ___
def oneof_names(*names: str):
return m.OneOf(*map(m.Name, names))


CSTNode_T = TypeVar("CSTNode_T", bound=cst.CSTNode)


def list_contains(
seq: Sequence[CSTNode_T], matcher: m.BaseMatcherNode
) -> Iterator[CSTNode_T]:
yield from (item for item in seq if m.matches(item, matcher))


class AttributeCall(NamedTuple):
node: cst.Call
base: str
function: str


def with_has_call(
node: cst.With, *names: str, base: Iterable[str] = ("trio", "anyio")
) -> list[AttributeCall]:
res_list: list[AttributeCall] = []
for item in node.items:
if res := m.extract(
item.item,
m.Call(
func=m.Attribute(
value=m.SaveMatchedNode(m.Name(), name="library"),
attr=m.SaveMatchedNode(
oneof_names(*names),
name="function",
),
)
),
):
assert isinstance(item.item, cst.Call)
assert isinstance(res["library"], cst.Name)
assert isinstance(res["function"], cst.Name)
if res["library"].value not in base:
continue
res_list.append(
AttributeCall(item.item, res["library"].value, res["function"].value)
)
return res_list


def func_has_decorator(func: cst.FunctionDef, *names: str) -> bool:
return any(
list_contains(
func.decorators,
m.Decorator(
decorator=m.OneOf(
oneof_names(*names),
m.Attribute(attr=oneof_names(*names)),
m.Call(func=m.Attribute(attr=oneof_names(*names))),
)
),
)
)
Loading