diff --git a/src/pytest_mypy_testing/plugin.py b/src/pytest_mypy_testing/plugin.py index 13cc7d6..58fd4f8 100644 --- a/src/pytest_mypy_testing/plugin.py +++ b/src/pytest_mypy_testing/plugin.py @@ -1,10 +1,10 @@ # SPDX-FileCopyrightText: 2020 David Fritzsche # SPDX-License-Identifier: Apache-2.0 OR MIT +import importlib.util import os import pathlib -import tempfile -from typing import Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union +from typing import Dict, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union import mypy.api import pytest @@ -20,11 +20,10 @@ PYTEST_VERSION = pytest.__version__ PYTEST_VERSION_INFO = tuple(int(part) for part in PYTEST_VERSION.split(".")[:3]) +have_xdist = importlib.util.find_spec("xdist") is not None + class MypyResult(NamedTuple): - mypy_args: List[str] - returncode: int - output_lines: List[str] file_messages: List[Message] non_item_messages: List[Message] @@ -55,13 +54,15 @@ def __init__( self.mypy_item = mypy_item for mark in self.mypy_item.marks: self.add_marker(mark) + if have_xdist: + self.add_marker(pytest.mark.xdist_group("mypy")) @classmethod def from_parent(cls, parent, name, mypy_item): return super().from_parent(parent=parent, name=name, mypy_item=mypy_item) def runtest(self) -> None: - returncode, actual_messages = self.parent.run_mypy(self.mypy_item) + actual_messages = self.parent.run_mypy(self.mypy_item) errors = diff_message_sequences( actual_messages, self.mypy_item.expected_messages @@ -119,7 +120,7 @@ def __init__( ) self.add_marker("mypy") self.mypy_file = parse_file(self.path, config=config) - self._mypy_result: Optional[MypyResult] = None + COLLECTION.add(self) @classmethod def from_parent(cls, parent, **kwargs): @@ -131,71 +132,78 @@ def collect(self) -> Iterator[PytestMypyTestItem]: parent=self, name="[mypy]" + item.name, mypy_item=item ) - def run_mypy(self, item: MypyTestItem) -> Tuple[int, List[Message]]: - if self._mypy_result is None: - self._mypy_result = self._run_mypy(self.path) - return ( - self._mypy_result.returncode, - sorted( - item.actual_messages + self._mypy_result.non_item_messages, - key=lambda msg: msg.lineno, - ), + def run_mypy(self, item: MypyTestItem) -> List[Message]: + mypy_result = COLLECTION.run_mypy(self) + return sorted( + item.actual_messages + mypy_result.non_item_messages, + key=lambda msg: msg.lineno, ) - def _run_mypy(self, filename: Union[pathlib.Path, os.PathLike, str]) -> MypyResult: - filename = pathlib.Path(filename) - with tempfile.TemporaryDirectory(prefix="pytest-mypy-testing-") as tmp_dir_name: - mypy_cache_dir = os.path.join(tmp_dir_name, "mypy_cache") - os.makedirs(mypy_cache_dir) - - mypy_args = [ - "--cache-dir={}".format(mypy_cache_dir), - "--check-untyped-defs", - "--hide-error-context", - "--no-color-output", - "--no-error-summary", - "--no-pretty", - "--soft-error-limit=-1", - "--no-silence-site-packages", - "--no-warn-unused-configs", - "--show-column-numbers", - "--show-error-codes", - "--show-traceback", - str(filename), - ] - - out, err, returncode = mypy.api.run(mypy_args) - - lines = (out + err).splitlines() - - file_messages = [ - msg - for msg in map(Message.from_output, lines) - if (msg.filename == self.mypy_file.filename) - and not ( + +class MypyFileCollection: + def __init__(self): + self.files: List[PytestMypyFile] = [] + self._mypy_results: Optional[Dict[str, MypyResult]] = None + + def add(self, file: PytestMypyFile): + self.files.append(file) + + def run_mypy(self, file: PytestMypyFile) -> MypyResult: + if self._mypy_results is None: + self._mypy_results = self._run_mypy() + return self._mypy_results[str(file.path)] + + def _run_mypy(self) -> Dict[str, MypyResult]: + mypy_args = [ + "--cache-dir={}".format(self.files[0].config.cache.mkdir("mypy-cache")), + "--check-untyped-defs", + "--hide-error-context", + "--no-color-output", + "--no-error-summary", + "--no-pretty", + "--soft-error-limit=-1", + "--no-warn-unused-configs", + "--show-column-numbers", + "--show-error-codes", + "--show-traceback", + *(str(file.path) for file in self.files), + ] + + out, err, returncode = mypy.api.run(mypy_args) + + messages_by_file = {} + + for line in (out + err).splitlines(): + msg = Message.from_output(line) + if msg.filename and not ( msg.severity is Severity.NOTE - and msg.message - == "See https://mypy.readthedocs.io/en/stable/running_mypy.html#missing-imports" + and msg.message.endswith("#missing-imports") + ): + messages_by_file.setdefault(msg.filename, []).append(msg) + + ret = {} + for file in self.files: + file_messages = messages_by_file.get(str(file.path), []) + + non_item_messages = [] + + for msg in file_messages: + for item in file.mypy_file.items: + if item.lineno <= msg.lineno <= item.end_lineno: + item.actual_messages.append(msg) + break + else: + non_item_messages.append(msg) + + ret[str(file.path)] = MypyResult( + file_messages=file_messages, + non_item_messages=non_item_messages, ) - ] - non_item_messages = [] - - for msg in file_messages: - for item in self.mypy_file.items: - if item.lineno <= msg.lineno <= item.end_lineno: - item.actual_messages.append(msg) - break - else: - non_item_messages.append(msg) - - return MypyResult( - mypy_args=mypy_args, - returncode=returncode, - output_lines=lines, - file_messages=file_messages, - non_item_messages=non_item_messages, - ) + return ret + + +COLLECTION = MypyFileCollection() def pytest_collect_file(file_path: pathlib.Path, parent):