diff --git a/src/docstub/_analysis.py b/src/docstub/_analysis.py index 7b85107..4fc471a 100644 --- a/src/docstub/_analysis.py +++ b/src/docstub/_analysis.py @@ -2,15 +2,16 @@ import builtins import collections.abc +import json import logging import re import typing -from dataclasses import dataclass +from dataclasses import asdict, dataclass from pathlib import Path import libcst as cst -from ._utils import accumulate_qualname, module_name_from_path +from ._utils import accumulate_qualname, module_name_from_path, pyfile_checksum logger = logging.getLogger(__name__) @@ -260,6 +261,38 @@ def common_known_imports(): class TypeCollector(cst.CSTVisitor): + """Collect types from a given Python file. + + Examples + -------- + >>> types = TypeCollector.collect(__file__) + >>> types[f"{__name__}.TypeCollector"] + + """ + + class ImportSerializer: + """Implements the `FuncSerializer` protocol to cache `TypeCollector.collect`.""" + + suffix = ".json" + encoding = "utf-8" + + def hash_args(self, path: Path) -> str: + """Compute a unique hash from the path passed to `TypeCollector.collect`.""" + key = pyfile_checksum(path) + return key + + def serialize(self, data: dict[str, KnownImport]) -> bytes: + """Serialize results from `TypeCollector.collect`.""" + primitives = {qualname: asdict(imp) for qualname, imp in data.items()} + raw = json.dumps(primitives, separators=(",", ":")).encode(self.encoding) + return raw + + def deserialize(self, raw: bytes) -> dict[str, KnownImport]: + """Deserialize results from `TypeCollector.collect`.""" + primitives = json.loads(raw.decode(self.encoding)) + data = {qualname: KnownImport(**kw) for qualname, kw in primitives.items()} + return data + @classmethod def collect(cls, file): """Collect importable type annotations in given file. 
diff --git a/src/docstub/_cache.py b/src/docstub/_cache.py new file mode 100644 index 0000000..4036564 --- /dev/null +++ b/src/docstub/_cache.py @@ -0,0 +1,141 @@ +import logging +from functools import cached_property +from typing import Protocol + +logger = logging.getLogger(__name__) + + +CACHEDIR_TAG_CONTENT = """\ +Signature: 8a477f597d28d172789f06886806bc55\ +# This file is a cache directory tag automatically created by docstub.\n" +# For information about cache directory tags see https://bford.info/cachedir/\n" +""" + + +def _directory_size(path): + """Estimate total size of a directory's content in bytes. + + Parameters + ---------- + path : Path + + Returns + ------- + total_bytes : int + Total size of all objects in bytes. + """ + if not path.is_dir(): + msg = f"{path} doesn't exist, can't determine size" + raise FileNotFoundError(msg) + files = path.rglob("*") + total_bytes = sum(f.stat().st_size for f in files) + return total_bytes + + +def create_cache(path): + """Create a cache directory. + + Parameters + ---------- + path : Path + Directory of the cache. The directory and it's parents will be created if it + doesn't exist yet. + """ + path.mkdir(parents=True, exist_ok=True) + cachdir_tag_path = path / "CACHEDIR.TAG" + cachdir_tag_content = ( + "Signature: 8a477f597d28d172789f06886806bc55\n" + "# This file is a cache directory tag automatically created by docstub.\n" + "# For information about cache directory tags see https://bford.info/cachedir/\n" + ) + if not cachdir_tag_path.is_file(): + with open(cachdir_tag_path, "w") as fp: + fp.write(cachdir_tag_content) + + gitignore_path = path / ".gitignore" + gitignore_content = ( + "# This file is a cache directory tag automatically created by docstub.\n" "*\n" + ) + if not gitignore_path.is_file(): + with open(gitignore_path, "w") as fp: + fp.write(gitignore_content) + + +class FuncSerializer[T](Protocol): + """Defines an interface to serialize and deserialize results of a function. 
+ + This interface is used by `FileCache` to cache results of a + + Attributes + ---------- + suffix : + A suffix corresponding to the format of the serialized data, e.g. ".json". + """ + + suffix: str + + def hash_args(self, *args, **kwargs) -> str: + """Compute a unique hash from the arguments passed to a function.""" + + def serialize(self, data: T) -> bytes: + """Serialize results of a function from `T` to bytes.""" + + def deserialize(self, raw: bytes) -> T: + """Deserialize results of a function from bytes back to `T`.""" + + +class FileCache: + """Cache results from a function call as a files on disk. + + This class can cache results of a function to the disk. A unique key is + generated from the arguments to the function, and the result is cached + inside a file named after this key. + """ + + def __init__(self, *, func, serializer, cache_dir, name): + """ + Parameters + ---------- + func : callable + The function whose output shall be cached. + serializer : FuncSerializer + An interface that matches the given `func`. It must implement the + `FileCachIO` protocol. + cache_dir : Path + The directory of the cache. + name : str + A unique name to separate parallel caches inside `cache_dir`. + """ + self.func = func + self.serializer = serializer + self._cache_dir = cache_dir + self.name = name + + @cached_property + def named_cache_dir(self): + """Path to the named subdirectory inside the cache. + + Warns when cache size exceeds 512 MiB. 
+ """ + cache_dir = self._cache_dir + create_cache(cache_dir) + if _directory_size(cache_dir) > 512 * 1024**2: + logger.warning("cache size at %r exceeds 512 MiB", cache_dir) + _named_cache_dir = cache_dir / self.name + _named_cache_dir.mkdir(parents=True, exist_ok=True) + return _named_cache_dir + + def __call__(self, *args, **kwargs): + """Call the wrapped `func` and cache each result in a file.""" + key = self.serializer.hash_args(*args, **kwargs) + entry_path = self.named_cache_dir / f"{key}{self.serializer.suffix}" + if entry_path.is_file(): + with entry_path.open("rb") as fp: + raw = fp.read() + data = self.serializer.deserialize(raw) + else: + data = self.func(*args, **kwargs) + raw = self.serializer.serialize(data) + with entry_path.open("xb") as fp: + fp.write(raw) + return data diff --git a/src/docstub/_cli.py b/src/docstub/_cli.py index 848ccd3..4e98aad 100644 --- a/src/docstub/_cli.py +++ b/src/docstub/_cli.py @@ -1,5 +1,7 @@ import logging import sys +import time +from contextlib import contextmanager from pathlib import Path import click @@ -10,6 +12,7 @@ TypeCollector, common_known_imports, ) +from ._cache import FileCache from ._config import Config from ._stubs import Py2StubTransformer, walk_source, walk_source_and_targets from ._version import __version__ @@ -26,7 +29,7 @@ def _load_configuration(config_path=None): Returns ------- - config : dict[str, Any] + config : ~.Config """ config = Config.from_toml(Config.DEFAULT_CONFIG_PATH) @@ -65,6 +68,58 @@ def _setup_logging(*, verbose): ) +def _build_import_map(config, source_dir): + """Build a map of known imports. 
@contextmanager
def report_execution_time():
    """Report the wall-clock duration of the wrapped code block.

    Prints a human-readable duration ("H h M min S.SSS s") via `click.echo`
    once the block exits — including when it exits with an exception.
    Can also be used as a decorator (`@report_execution_time()`).
    """
    start = time.time()
    try:
        yield
    finally:
        total_seconds = time.time() - start

        # `divmod` on floats yields float quotients; cast to int for display
        # so the output reads "1 min" instead of "1.0 min".
        hours, remainder = divmod(total_seconds, 3600)
        minutes, seconds = divmod(remainder, 60)

        formatted_duration = f"{seconds:.3f} s"
        if minutes:
            formatted_duration = f"{int(minutes)} min {formatted_duration}"
        if hours:
            formatted_duration = f"{int(hours)} h {formatted_duration}"

        click.echo()
        click.echo(f"Finished in {formatted_duration}")
def pyfile_checksum(path):
    """Compute a unique key for a Python file.

    The key takes into account the given `path`, the relative position if the
    file is part of a Python package and the file's content.

    Parameters
    ----------
    path : Path

    Returns
    -------
    key : str
        Decimal string form of a crc32 checksum. (`zlib.crc32` returns an
        `int`; callers such as `ImportSerializer.hash_args` declare `str`,
        so coerce here.)
    """
    module_name = module_name_from_path(path).encode()
    absolute_path = str(path.resolve()).encode()
    with open(path, "rb") as fp:
        content = fp.read()
    key = str(crc32(content + module_name + absolute_path))
    return key
def test_directory_size():
    # The tests directory itself contains files, so the size is positive.
    assert _cache._directory_size(Path(__file__).parent) > 0
    # A non-existing path raises with a descriptive message.
    with pytest.raises(FileNotFoundError, match="doesn't exist, can't determine size"):
        _cache._directory_size(Path("i/don't/exist"))


def test_create_cache(tmp_path):
    cache_dir = tmp_path / ".test_cache_dir"
    assert not cache_dir.exists()

    _cache.create_cache(cache_dir)
    assert cache_dir.is_dir()

    # CACHEDIR.TAG must start with the well-known signature line
    cache_tag_path = cache_dir / "CACHEDIR.TAG"
    assert cache_tag_path.is_file()
    with cache_tag_path.open("r") as stream:
        tag_content = stream.read()
    assert tag_content.startswith("Signature: 8a477f597d28d172789f06886806bc55\n")

    # .gitignore must ignore everything inside the cache directory
    gitignore_path = cache_dir / ".gitignore"
    assert gitignore_path.is_file()
    with gitignore_path.open("r") as stream:
        gitignore_content = stream.read()
    assert "\n*\n" in gitignore_content

    # Creating the cache a second time must be a no-op, not an error
    _cache.create_cache(cache_dir)
def test_pyfile_checksum(tmp_path):
    # Set up a tiny package containing the submodule under test
    package_dir = tmp_path / "mypackage"
    package_dir.mkdir()
    (package_dir / "__init__.py").touch()

    submodule_name = "submodule.py"
    submodule_path = package_dir / submodule_name
    with submodule_path.open("w") as stream:
        stream.write("# First line\n")

    # The key must be stable across repeated calls
    original_key = _utils.pyfile_checksum(submodule_path)
    assert original_key == _utils.pyfile_checksum(submodule_path)

    # Appending content must change the key
    with submodule_path.open("a") as stream:
        stream.write("# Second line\n")
    changed_content_key = _utils.pyfile_checksum(submodule_path)
    assert changed_content_key != original_key

    # Renaming the package (new qualname / path) must change the key too
    new_package_dir = package_dir.rename(tmp_path / "newpackage")
    moved_key = _utils.pyfile_checksum(new_package_dir / submodule_name)
    assert moved_key != changed_content_key