From 6cb80300747e27b8f3a538375872e6af73911323 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sun, 6 Oct 2024 17:52:33 +0200 Subject: [PATCH 1/5] Cache collected types in JSON files --- src/docstub/_analysis.py | 24 +++++++- src/docstub/_cli.py | 11 +++- src/docstub/_utils.py | 97 +++++++++++++++++++++++++++++- tests/test_utils.py | 124 +++++++++++++++++++++++++++++++++++++-- 4 files changed, 247 insertions(+), 9 deletions(-) diff --git a/src/docstub/_analysis.py b/src/docstub/_analysis.py index 7b85107..25bb476 100644 --- a/src/docstub/_analysis.py +++ b/src/docstub/_analysis.py @@ -2,15 +2,16 @@ import builtins import collections.abc +import json import logging import re import typing -from dataclasses import dataclass +from dataclasses import asdict, dataclass from pathlib import Path import libcst as cst -from ._utils import accumulate_qualname, module_name_from_path +from ._utils import accumulate_qualname, module_name_from_path, pyfile_checksum logger = logging.getLogger(__name__) @@ -260,6 +261,25 @@ def common_known_imports(): class TypeCollector(cst.CSTVisitor): + + class ImportSerializer: + """Implements the FileCacheIO protocol to cache `TypeCollector.collect`""" + + def hash(self, path: Path) -> str: + key = pyfile_checksum(path) + return key + + def serialize(self, path: Path, data: dict[str, KnownImport]) -> None: + raw_data = {qualname: asdict(imp) for qualname, imp in data.items()} + with open(path, "w") as fp: + json.dump(raw_data, fp) + + def deserialize(self, path: Path) -> dict[str, KnownImport]: + with open(path) as fp: + raw_data = json.load(fp) + data = {qualname: KnownImport(**kw) for qualname, kw in raw_data.items()} + return data + @classmethod def collect(cls, file): """Collect importable type annotations in given file. 
diff --git a/src/docstub/_cli.py b/src/docstub/_cli.py index 848ccd3..12c0fd3 100644 --- a/src/docstub/_cli.py +++ b/src/docstub/_cli.py @@ -12,6 +12,7 @@ ) from ._config import Config from ._stubs import Py2StubTransformer, walk_source, walk_source_and_targets +from ._utils import FileCache from ._version import __version__ logger = logging.getLogger(__name__) @@ -90,9 +91,17 @@ def main(source_dir, out_dir, config_path, verbose): # Build map of known imports known_imports = common_known_imports() + + collect_cached_types = FileCache( + cached_func=TypeCollector.collect, + serializer=TypeCollector.ImportSerializer(), + cache_dir=Path.cwd() / ".docstub_cache", + name="collected_types", + ) + for source_path in walk_source(source_dir): logger.info("collecting types in %s", source_path) - known_imports_in_source = TypeCollector.collect(source_path) + known_imports_in_source = collect_cached_types(source_path) known_imports.update(known_imports_in_source) known_imports.update(KnownImport.many_from_config(config.known_imports)) diff --git a/src/docstub/_utils.py b/src/docstub/_utils.py index eeb6779..14de1e3 100644 --- a/src/docstub/_utils.py +++ b/src/docstub/_utils.py @@ -1,9 +1,11 @@ import dataclasses import itertools import re -from functools import lru_cache +from functools import cached_property, lru_cache from pathlib import Path from textwrap import indent +from typing import Protocol +from zlib import crc32 import click @@ -105,6 +107,55 @@ def module_name_from_path(path): return name +def pyfile_checksum(path): + """Compute a unique key for a Python file. + + The key takes into account the given `path`, the relative position if the + file is part of a Python package and the file's content. 
+ + Parameters + ---------- + path : Path + + Returns + ------- + key : str + """ + module_name = module_name_from_path(path).encode() + absolute_path = str(path.resolve()).encode() + with open(path, "rb") as fp: + content = fp.read() + key = crc32(content + module_name + absolute_path) + return key + + +def create_cachedir(path): + """Create a cache directory + + Parameters + ---------- + path : Path + """ + path.mkdir(parents=True, exist_ok=True) + cachdir_tag_path = path / "CACHEDIR.TAG" + cachdir_tag_content = ( + "Signature: 8a477f597d28d172789f06886806bc55\n" + "# This file is a cache directory tag automatically created by docstub.\n" + "# For information about cache directory tags see https://bford.info/cachedir/\n" + ) + if not cachdir_tag_path.is_file(): + with open(cachdir_tag_path, "w") as fp: + fp.write(cachdir_tag_content) + + gitignore_path = path / ".gitignore" + gitignore_content = ( + "# This file is a cache directory tag automatically created by docstub.\n" "*\n" + ) + if not gitignore_path.is_file(): + with open(gitignore_path, "w") as fp: + fp.write(gitignore_content) + + @dataclasses.dataclass(kw_only=True, slots=True, frozen=True) class ContextFormatter: """Format messages in context of a location in a file. @@ -238,3 +289,47 @@ def __post_init__(self): if self.path is not None and not isinstance(self.path, Path): msg = f"expected `path` to be of type `Path`, got {type(self.path)!r}" raise TypeError(msg) + + +class FileCacheIO[T](Protocol): + """Defines an interface to serialize and deserialize data in `FileCache`.""" + + def hash(self, *args, **kwargs) -> str: ... + def serialize(self, path: Path, data: T) -> None: ... + def deserialize(self, path: Path) -> T: ... 
+ + +class FileCache: + """Cache results from a function call on disk.""" + + def __init__(self, *, cached_func, serializer, cache_dir, name): + """ + Parameters + ---------- + cached_func : callable + serializer : FileCacheIO + An interface that + cache_dir : Path, optional + """ + self.cached_func = cached_func + self.serializer = serializer + self._cache_dir = cache_dir + self.name = name + + @cached_property + def named_cache_dir(self): + cache_dir = self._cache_dir + create_cachedir(cache_dir) + _named_cache_dir = cache_dir / self.name + _named_cache_dir.mkdir(parents=True, exist_ok=True) + return _named_cache_dir + + def __call__(self, *args, **kwargs): + key = self.serializer.hash(*args, **kwargs) + entry_path = self.named_cache_dir / f"{key}" + if entry_path.is_file(): + entry = self.serializer.deserialize(entry_path) + else: + entry = self.cached_func(*args, **kwargs) + self.serializer.serialize(entry_path, entry) + return entry diff --git a/tests/test_utils.py b/tests/test_utils.py index 0d2937c..d419a37 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,6 @@ -from docstub._utils import module_name_from_path +from collections import defaultdict + +from docstub import _utils class Test_module_name_from_path: @@ -19,7 +21,119 @@ def test_basic(self, tmp_path): else: path.mkdir() - assert module_name_from_path(tmp_path / "foo/__init__.py") == "foo" - assert module_name_from_path(tmp_path / "foo/bar.py") == "foo.bar" - assert module_name_from_path(tmp_path / "foo/baz/__init__.py") == "foo.baz" - assert module_name_from_path(tmp_path / "foo/baz/qux.py") == "foo.baz.qux" + assert _utils.module_name_from_path(tmp_path / "foo/__init__.py") == "foo" + assert _utils.module_name_from_path(tmp_path / "foo/bar.py") == "foo.bar" + assert ( + _utils.module_name_from_path(tmp_path / "foo/baz/__init__.py") == "foo.baz" + ) + assert ( + _utils.module_name_from_path(tmp_path / "foo/baz/qux.py") == "foo.baz.qux" + ) + + +def test_pyfile_checksum(tmp_path): 
+ # Create package + package_dir = tmp_path / "mypackage" + package_dir.mkdir() + package_init = package_dir / "__init__.py" + package_init.touch() + + # Create submodule to be checked + submodule_name = "submodule.py" + submodule_path = package_dir / submodule_name + with submodule_path.open("w") as fp: + fp.write("# First line\n") + + original_key = _utils.pyfile_checksum(submodule_path) + # Check that the key is stable + assert original_key == _utils.pyfile_checksum(submodule_path) + + # Key changes if content changes + with submodule_path.open("a") as fp: + fp.write("# Second line\n") + changed_content_key = _utils.pyfile_checksum(submodule_path) + assert original_key != changed_content_key + + # Key changes if qualname / path of module changes + new_package_dir = package_dir.rename(tmp_path / "newpackage") + qualname_changed_key = _utils.pyfile_checksum(new_package_dir / submodule_name) + assert qualname_changed_key != changed_content_key + + +def test_create_cachedir(tmp_path): + cache_dir = tmp_path / ".test_cache_dir" + assert not cache_dir.exists() + + _utils.create_cachedir(cache_dir) + assert cache_dir.is_dir() + + # Check CACHEDIR.TAG file + cache_tag_path = cache_dir / "CACHEDIR.TAG" + assert cache_tag_path.is_file() + with cache_tag_path.open("r") as fp: + cache_tag_content = fp.read() + assert cache_tag_content.startswith("Signature: 8a477f597d28d172789f06886806bc55\n") + + # Check. 
gitignore + gitignore_path = cache_dir / ".gitignore" + assert gitignore_path.is_file() + with gitignore_path.open("r") as fp: + gitignore_content = fp.read() + assert "\n*\n" in gitignore_content + + # Check that calling it a second time doesn't raise an error + _utils.create_cachedir(cache_dir) + + +class Test_FileCache: + def test_basic(self, tmp_path): + + class Serializer: + def hash(self, arg): + return str(hash(arg)) + + def serialize(self, path, data): + with path.open("x") as fp: + fp.write(str(data)) + + def deserialize(self, path): + with path.open("r") as fp: + return int(fp.read()) + + counter = defaultdict(lambda: 0) + + def square(x): + counter[x] += 1 + return x * x + + cached_square = _utils.FileCache( + cached_func=square, serializer=Serializer(), cache_dir=tmp_path, name="test" + ) + + assert cached_square(3) == 9 + assert counter[3] == 1 + + # Result was cached + cached_file = tmp_path / "test" / str(Serializer().hash(3)) + assert cached_file.is_file() + + # With the square(3) cached, the counter no longer increases + assert cached_square(3) == 9 + assert counter[3] == 1 + + # Using another FileCache will use the existing cache + cached_square_2 = _utils.FileCache( + cached_func=square, serializer=Serializer(), cache_dir=tmp_path, name="test" + ) + assert cached_square_2(3) == 9 + assert counter[3] == 1 + + # But using another FileCache with a different name will not hit existing cache + cached_square_3 = _utils.FileCache( + cached_func=square, + serializer=Serializer(), + cache_dir=tmp_path, + name="test2", + ) + assert cached_square_3(3) == 9 + assert counter[3] == 2 From d289c24ce5476c927f62c48778ea4d4209014c74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Mon, 7 Oct 2024 10:07:19 +0200 Subject: [PATCH 2/5] Use serializer protocol for serialization only and not for file IO. 
--- src/docstub/_analysis.py | 30 ++++++++++------ src/docstub/_cli.py | 74 +++++++++++++++++++++++++++++++--------- src/docstub/_utils.py | 52 ++++++++++++++++++---------- tests/test_utils.py | 23 +++++++------ 4 files changed, 123 insertions(+), 56 deletions(-) diff --git a/src/docstub/_analysis.py b/src/docstub/_analysis.py index 25bb476..49b9708 100644 --- a/src/docstub/_analysis.py +++ b/src/docstub/_analysis.py @@ -261,23 +261,33 @@ def common_known_imports(): class TypeCollector(cst.CSTVisitor): + """Collect types from a given Python file. + + Examples + -------- + >>> types = TypeCollector.collect(__file__) + >>> types[f"{__name__}.TypeCollector"] + + """ class ImportSerializer: - """Implements the FileCacheIO protocol to cache `TypeCollector.collect`""" + """Implements the `FuncSerializer` protocol to cache calls to `collect`.""" + + suffix = ".json" + encoding = "utf-8" - def hash(self, path: Path) -> str: + def hash_args(self, path: Path) -> str: key = pyfile_checksum(path) return key - def serialize(self, path: Path, data: dict[str, KnownImport]) -> None: - raw_data = {qualname: asdict(imp) for qualname, imp in data.items()} - with open(path, "w") as fp: - json.dump(raw_data, fp) + def serialize(self, data: dict[str, KnownImport]) -> bytes: + primitives = {qualname: asdict(imp) for qualname, imp in data.items()} + raw = json.dumps(primitives).encode(self.encoding) + return raw - def deserialize(self, path: Path) -> dict[str, KnownImport]: - with open(path) as fp: - raw_data = json.load(fp) - data = {qualname: KnownImport(**kw) for qualname, kw in raw_data.items()} + def deserialize(self, raw: bytes) -> dict[str, KnownImport]: + primitives = json.loads(raw.decode(self.encoding)) + data = {qualname: KnownImport(**kw) for qualname, kw in primitives.items()} return data @classmethod diff --git a/src/docstub/_cli.py b/src/docstub/_cli.py index 12c0fd3..ed40428 100644 --- a/src/docstub/_cli.py +++ b/src/docstub/_cli.py @@ -1,5 +1,7 @@ import logging import 
sys +import time +from contextlib import contextmanager from pathlib import Path import click @@ -27,7 +29,7 @@ def _load_configuration(config_path=None): Returns ------- - config : dict[str, Any] + config : ~.Config """ config = Config.from_toml(Config.DEFAULT_CONFIG_PATH) @@ -66,6 +68,58 @@ def _setup_logging(*, verbose): ) +def _build_import_map(config, source_dir): + """Build a map of known imports. + + Parameters + ---------- + config : ~.Config + source_dir : Path + + Returns + ------- + imports : dict[str, ~.KnownImport] + """ + known_imports = common_known_imports() + + collect_cached_types = FileCache( + func=TypeCollector.collect, + serializer=TypeCollector.ImportSerializer(), + cache_dir=Path.cwd() / ".docstub_cache", + name="collected_types", + ) + for source_path in walk_source(source_dir): + logger.info("collecting types in %s", source_path) + known_imports_in_source = collect_cached_types(source_path) + known_imports.update(known_imports_in_source) + + known_imports.update(KnownImport.many_from_config(config.known_imports)) + + return known_imports + + +@contextmanager +def report_execution_time(): + start = time.time() + try: + yield + finally: + stop = time.time() + total_seconds = stop - start + + hours, remainder = divmod(total_seconds, 3600) + minutes, seconds = divmod(remainder, 60) + + formated_duration = f"{seconds:.3f} s" + if minutes: + formated_duration = f"{minutes} min {formated_duration}" + if hours: + formated_duration = f"{hours} h {formated_duration}" + + click.echo() + click.echo(f"Finished in {formated_duration}") + + @click.command() @click.version_option(__version__) @click.argument("source_dir", type=click.Path(exists=True, file_okay=False)) @@ -83,27 +137,13 @@ def _setup_logging(*, verbose): ) @click.option("-v", "--verbose", count=True, help="Log more details.") @click.help_option("-h", "--help") +@report_execution_time() def main(source_dir, out_dir, config_path, verbose): _setup_logging(verbose=verbose) source_dir = 
Path(source_dir) config = _load_configuration(config_path) - - # Build map of known imports - known_imports = common_known_imports() - - collect_cached_types = FileCache( - cached_func=TypeCollector.collect, - serializer=TypeCollector.ImportSerializer(), - cache_dir=Path.cwd() / ".docstub_cache", - name="collected_types", - ) - - for source_path in walk_source(source_dir): - logger.info("collecting types in %s", source_path) - known_imports_in_source = collect_cached_types(source_path) - known_imports.update(known_imports_in_source) - known_imports.update(KnownImport.many_from_config(config.known_imports)) + known_imports = _build_import_map(config, source_dir) inspector = StaticInspector( source_pkgs=[source_dir.parent.resolve()], known_imports=known_imports diff --git a/src/docstub/_utils.py b/src/docstub/_utils.py index 14de1e3..14935d7 100644 --- a/src/docstub/_utils.py +++ b/src/docstub/_utils.py @@ -291,27 +291,39 @@ def __post_init__(self): raise TypeError(msg) -class FileCacheIO[T](Protocol): - """Defines an interface to serialize and deserialize data in `FileCache`.""" +class FuncSerializer[T](Protocol): + """Defines an interface to serialize and deserialize results of a function.""" - def hash(self, *args, **kwargs) -> str: ... - def serialize(self, path: Path, data: T) -> None: ... - def deserialize(self, path: Path) -> T: ... + suffix: str + + def hash_args(self, *args, **kwargs) -> str: ... + def serialize(self, data: T) -> bytes: ... + def deserialize(self, raw: bytes) -> T: ... class FileCache: - """Cache results from a function call on disk.""" + """Cache results from a function call as a files on disk. + + This class can cache results of a function to the disk. A unique key is + generated from the arguments to the function, and the result is cached + inside a file named after this key. 
+ """ - def __init__(self, *, cached_func, serializer, cache_dir, name): + def __init__(self, *, func, serializer, cache_dir, name): """ Parameters ---------- - cached_func : callable - serializer : FileCacheIO - An interface that - cache_dir : Path, optional + func : callable + The function whose output shall be cached. + serializer : FuncSerializer + An interface that matches the given `func`. It must implement the + `FileCachIO` protocol. + cache_dir : Path + The directory of the cache. + name : str + A unique name to separate parallel caches inside `cache_dir`. """ - self.cached_func = cached_func + self.func = func self.serializer = serializer self._cache_dir = cache_dir self.name = name @@ -325,11 +337,15 @@ def named_cache_dir(self): return _named_cache_dir def __call__(self, *args, **kwargs): - key = self.serializer.hash(*args, **kwargs) - entry_path = self.named_cache_dir / f"{key}" + key = self.serializer.hash_args(*args, **kwargs) + entry_path = self.named_cache_dir / f"{key}{self.serializer.suffix}" if entry_path.is_file(): - entry = self.serializer.deserialize(entry_path) + with entry_path.open("rb") as fp: + raw = fp.read() + data = self.serializer.deserialize(raw) else: - entry = self.cached_func(*args, **kwargs) - self.serializer.serialize(entry_path, entry) - return entry + data = self.func(*args, **kwargs) + raw = self.serializer.serialize(data) + with entry_path.open("xb") as fp: + fp.write(raw) + return data diff --git a/tests/test_utils.py b/tests/test_utils.py index d419a37..12b033e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -89,16 +89,16 @@ class Test_FileCache: def test_basic(self, tmp_path): class Serializer: - def hash(self, arg): + suffix = ".txt" + + def hash_args(self, arg): return str(hash(arg)) - def serialize(self, path, data): - with path.open("x") as fp: - fp.write(str(data)) + def serialize(self, data): + return str(data).encode() - def deserialize(self, path): - with path.open("r") as fp: - return 
int(fp.read()) + def deserialize(self, raw): + return int(raw.decode()) counter = defaultdict(lambda: 0) @@ -107,14 +107,15 @@ def square(x): return x * x cached_square = _utils.FileCache( - cached_func=square, serializer=Serializer(), cache_dir=tmp_path, name="test" + func=square, serializer=Serializer(), cache_dir=tmp_path, name="test" ) assert cached_square(3) == 9 assert counter[3] == 1 # Result was cached - cached_file = tmp_path / "test" / str(Serializer().hash(3)) + entry_name = f"{Serializer().hash_args(3)!s}{Serializer.suffix}" + cached_file = tmp_path / "test" / entry_name assert cached_file.is_file() # With the square(3) cached, the counter no longer increases @@ -123,14 +124,14 @@ def square(x): # Using another FileCache will use the existing cache cached_square_2 = _utils.FileCache( - cached_func=square, serializer=Serializer(), cache_dir=tmp_path, name="test" + func=square, serializer=Serializer(), cache_dir=tmp_path, name="test" ) assert cached_square_2(3) == 9 assert counter[3] == 1 # But using another FileCache with a different name will not hit existing cache cached_square_3 = _utils.FileCache( - cached_func=square, + func=square, serializer=Serializer(), cache_dir=tmp_path, name="test2", From 7f4ab158eabde59370ec287ef62681977bca3dda Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Mon, 7 Oct 2024 13:32:55 +0200 Subject: [PATCH 3/5] Include version in cache path This will be useful in case of incompatibilities between different versions. 
---
 src/docstub/_cli.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/docstub/_cli.py b/src/docstub/_cli.py
index ed40428..01e5aab 100644
--- a/src/docstub/_cli.py
+++ b/src/docstub/_cli.py
@@ -86,7 +86,7 @@ def _build_import_map(config, source_dir):
         func=TypeCollector.collect,
         serializer=TypeCollector.ImportSerializer(),
         cache_dir=Path.cwd() / ".docstub_cache",
-        name="collected_types",
+        name=f"{__version__}/collected_types",
     )
     for source_path in walk_source(source_dir):
         logger.info("collecting types in %s", source_path)

From 5a9f8a8e501046c7945c5660c4c658986578eaf6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lars=20Gr=C3=BCter?=
Date: Mon, 7 Oct 2024 13:33:51 +0200
Subject: [PATCH 4/5] Move FileCache into its own submodule

---
 src/docstub/_cache.py | 141 ++++++++++++++++++++++++++++++++++++++++++
 src/docstub/_cli.py | 2 +-
 src/docstub/_utils.py | 90 +--------------------------
 tests/test_cache.py | 92 +++++++++++++++++++++++++++
 tests/test_utils.py | 82 ------------------------
 5 files changed, 235 insertions(+), 172 deletions(-)
 create mode 100644 src/docstub/_cache.py
 create mode 100644 tests/test_cache.py

diff --git a/src/docstub/_cache.py b/src/docstub/_cache.py
new file mode 100644
index 0000000..4036564
--- /dev/null
+++ b/src/docstub/_cache.py
@@ -0,0 +1,141 @@
+import logging
+from functools import cached_property
+from typing import Protocol
+
+logger = logging.getLogger(__name__)
+
+
+CACHEDIR_TAG_CONTENT = """\
+Signature: 8a477f597d28d172789f06886806bc55
+# This file is a cache directory tag automatically created by docstub.
+# For information about cache directory tags see https://bford.info/cachedir/
+"""
+
+
+def _directory_size(path):
+    """Estimate total size of a directory's content in bytes.
+
+    Parameters
+    ----------
+    path : Path
+
+    Returns
+    -------
+    total_bytes : int
+        Total size of all objects in bytes.
+    """
+    if not path.is_dir():
+        msg = f"{path} doesn't exist, can't determine size"
+        raise FileNotFoundError(msg)
+    files = path.rglob("*")
+    total_bytes = sum(f.stat().st_size for f in files)
+    return total_bytes
+
+
+def create_cache(path):
+    """Create a cache directory.
+
+    Parameters
+    ----------
+    path : Path
+        Directory of the cache. The directory and its parents will be created if it
+        doesn't exist yet.
+    """
+    path.mkdir(parents=True, exist_ok=True)
+    cachdir_tag_path = path / "CACHEDIR.TAG"
+    cachdir_tag_content = (
+        "Signature: 8a477f597d28d172789f06886806bc55\n"
+        "# This file is a cache directory tag automatically created by docstub.\n"
+        "# For information about cache directory tags see https://bford.info/cachedir/\n"
+    )
+    if not cachdir_tag_path.is_file():
+        with open(cachdir_tag_path, "w") as fp:
+            fp.write(cachdir_tag_content)
+
+    gitignore_path = path / ".gitignore"
+    gitignore_content = (
+        "# This file is a cache directory tag automatically created by docstub.\n" "*\n"
+    )
+    if not gitignore_path.is_file():
+        with open(gitignore_path, "w") as fp:
+            fp.write(gitignore_content)
+
+
+class FuncSerializer[T](Protocol):
+    """Defines an interface to serialize and deserialize results of a function.
+
+    This interface is used by `FileCache` to cache results of a function.
+
+    Attributes
+    ----------
+    suffix :
+        A suffix corresponding to the format of the serialized data, e.g. ".json".
+    """
+
+    suffix: str
+
+    def hash_args(self, *args, **kwargs) -> str:
+        """Compute a unique hash from the arguments passed to a function."""
+
+    def serialize(self, data: T) -> bytes:
+        """Serialize results of a function from `T` to bytes."""
+
+    def deserialize(self, raw: bytes) -> T:
+        """Deserialize results of a function from bytes back to `T`."""
+
+
+class FileCache:
+    """Cache results from a function call as files on disk.
+
+    This class can cache results of a function to the disk.
A unique key is
+    generated from the arguments to the function, and the result is cached
+    inside a file named after this key.
+    """
+
+    def __init__(self, *, func, serializer, cache_dir, name):
+        """
+        Parameters
+        ----------
+        func : callable
+            The function whose output shall be cached.
+        serializer : FuncSerializer
+            An interface that matches the given `func`. It must implement the
+            `FuncSerializer` protocol.
+        cache_dir : Path
+            The directory of the cache.
+        name : str
+            A unique name to separate parallel caches inside `cache_dir`.
+        """
+        self.func = func
+        self.serializer = serializer
+        self._cache_dir = cache_dir
+        self.name = name
+
+    @cached_property
+    def named_cache_dir(self):
+        """Path to the named subdirectory inside the cache.
+
+        Warns when cache size exceeds 512 MiB.
+        """
+        cache_dir = self._cache_dir
+        create_cache(cache_dir)
+        if _directory_size(cache_dir) > 512 * 1024**2:
+            logger.warning("cache size at %r exceeds 512 MiB", cache_dir)
+        _named_cache_dir = cache_dir / self.name
+        _named_cache_dir.mkdir(parents=True, exist_ok=True)
+        return _named_cache_dir
+
+    def __call__(self, *args, **kwargs):
+        """Call the wrapped `func` and cache each result in a file."""
+        key = self.serializer.hash_args(*args, **kwargs)
+        entry_path = self.named_cache_dir / f"{key}{self.serializer.suffix}"
+        if entry_path.is_file():
+            with entry_path.open("rb") as fp:
+                raw = fp.read()
+            data = self.serializer.deserialize(raw)
+        else:
+            data = self.func(*args, **kwargs)
+            raw = self.serializer.serialize(data)
+            with entry_path.open("xb") as fp:
+                fp.write(raw)
+        return data

diff --git a/src/docstub/_cli.py b/src/docstub/_cli.py
index 01e5aab..4e98aad 100644
--- a/src/docstub/_cli.py
+++ b/src/docstub/_cli.py
@@ -12,9 +12,9 @@
     TypeCollector,
     common_known_imports,
 )
+from ._cache import FileCache
 from ._config import Config
 from ._stubs import Py2StubTransformer, walk_source, walk_source_and_targets
-from ._utils import FileCache
 from ._version import __version__

 logger =
logging.getLogger(__name__) diff --git a/src/docstub/_utils.py b/src/docstub/_utils.py index 14935d7..b85fb9a 100644 --- a/src/docstub/_utils.py +++ b/src/docstub/_utils.py @@ -1,10 +1,9 @@ import dataclasses import itertools import re -from functools import cached_property, lru_cache +from functools import lru_cache from pathlib import Path from textwrap import indent -from typing import Protocol from zlib import crc32 import click @@ -129,33 +128,6 @@ def pyfile_checksum(path): return key -def create_cachedir(path): - """Create a cache directory - - Parameters - ---------- - path : Path - """ - path.mkdir(parents=True, exist_ok=True) - cachdir_tag_path = path / "CACHEDIR.TAG" - cachdir_tag_content = ( - "Signature: 8a477f597d28d172789f06886806bc55\n" - "# This file is a cache directory tag automatically created by docstub.\n" - "# For information about cache directory tags see https://bford.info/cachedir/\n" - ) - if not cachdir_tag_path.is_file(): - with open(cachdir_tag_path, "w") as fp: - fp.write(cachdir_tag_content) - - gitignore_path = path / ".gitignore" - gitignore_content = ( - "# This file is a cache directory tag automatically created by docstub.\n" "*\n" - ) - if not gitignore_path.is_file(): - with open(gitignore_path, "w") as fp: - fp.write(gitignore_content) - - @dataclasses.dataclass(kw_only=True, slots=True, frozen=True) class ContextFormatter: """Format messages in context of a location in a file. @@ -289,63 +261,3 @@ def __post_init__(self): if self.path is not None and not isinstance(self.path, Path): msg = f"expected `path` to be of type `Path`, got {type(self.path)!r}" raise TypeError(msg) - - -class FuncSerializer[T](Protocol): - """Defines an interface to serialize and deserialize results of a function.""" - - suffix: str - - def hash_args(self, *args, **kwargs) -> str: ... - def serialize(self, data: T) -> bytes: ... - def deserialize(self, raw: bytes) -> T: ... 
- - -class FileCache: - """Cache results from a function call as a files on disk. - - This class can cache results of a function to the disk. A unique key is - generated from the arguments to the function, and the result is cached - inside a file named after this key. - """ - - def __init__(self, *, func, serializer, cache_dir, name): - """ - Parameters - ---------- - func : callable - The function whose output shall be cached. - serializer : FuncSerializer - An interface that matches the given `func`. It must implement the - `FileCachIO` protocol. - cache_dir : Path - The directory of the cache. - name : str - A unique name to separate parallel caches inside `cache_dir`. - """ - self.func = func - self.serializer = serializer - self._cache_dir = cache_dir - self.name = name - - @cached_property - def named_cache_dir(self): - cache_dir = self._cache_dir - create_cachedir(cache_dir) - _named_cache_dir = cache_dir / self.name - _named_cache_dir.mkdir(parents=True, exist_ok=True) - return _named_cache_dir - - def __call__(self, *args, **kwargs): - key = self.serializer.hash_args(*args, **kwargs) - entry_path = self.named_cache_dir / f"{key}{self.serializer.suffix}" - if entry_path.is_file(): - with entry_path.open("rb") as fp: - raw = fp.read() - data = self.serializer.deserialize(raw) - else: - data = self.func(*args, **kwargs) - raw = self.serializer.serialize(data) - with entry_path.open("xb") as fp: - fp.write(raw) - return data diff --git a/tests/test_cache.py b/tests/test_cache.py new file mode 100644 index 0000000..16515d5 --- /dev/null +++ b/tests/test_cache.py @@ -0,0 +1,92 @@ +from collections import defaultdict +from pathlib import Path + +import pytest + +from docstub import _cache + + +def test_directory_size(): + assert _cache._directory_size(Path(__file__).parent) > 0 + with pytest.raises(FileNotFoundError, match="doesn't exist, can't determine size"): + _cache._directory_size(Path("i/don't/exist")) + + +def test_create_cache(tmp_path): + cache_dir = 
tmp_path / ".test_cache_dir" + assert not cache_dir.exists() + + _cache.create_cache(cache_dir) + assert cache_dir.is_dir() + + # Check CACHEDIR.TAG file + cache_tag_path = cache_dir / "CACHEDIR.TAG" + assert cache_tag_path.is_file() + with cache_tag_path.open("r") as fp: + cache_tag_content = fp.read() + assert cache_tag_content.startswith("Signature: 8a477f597d28d172789f06886806bc55\n") + + # Check. gitignore + gitignore_path = cache_dir / ".gitignore" + assert gitignore_path.is_file() + with gitignore_path.open("r") as fp: + gitignore_content = fp.read() + assert "\n*\n" in gitignore_content + + # Check that calling it a second time doesn't raise an error + _cache.create_cache(cache_dir) + + +class Test_FileCache: + def test_basic(self, tmp_path): + + class Serializer: + suffix = ".txt" + + def hash_args(self, arg): + return str(hash(arg)) + + def serialize(self, data): + return str(data).encode() + + def deserialize(self, raw): + return int(raw.decode()) + + counter = defaultdict(lambda: 0) + + def square(x): + counter[x] += 1 + return x * x + + cached_square = _cache.FileCache( + func=square, serializer=Serializer(), cache_dir=tmp_path, name="test" + ) + + assert cached_square(3) == 9 + assert counter[3] == 1 + + # Result was cached + entry_name = f"{Serializer().hash_args(3)!s}{Serializer.suffix}" + cached_file = tmp_path / "test" / entry_name + assert cached_file.is_file() + + # With the square(3) cached, the counter no longer increases + assert cached_square(3) == 9 + assert counter[3] == 1 + + # Using another FileCache will use the existing cache + cached_square_2 = _cache.FileCache( + func=square, serializer=Serializer(), cache_dir=tmp_path, name="test" + ) + assert cached_square_2(3) == 9 + assert counter[3] == 1 + + # But using another FileCache with a different name will not hit existing cache + cached_square_3 = _cache.FileCache( + func=square, + serializer=Serializer(), + cache_dir=tmp_path, + name="test2", + ) + assert cached_square_3(3) == 9 + 
assert counter[3] == 2 diff --git a/tests/test_utils.py b/tests/test_utils.py index 12b033e..bdd02c2 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,5 +1,3 @@ -from collections import defaultdict - from docstub import _utils @@ -58,83 +56,3 @@ def test_pyfile_checksum(tmp_path): new_package_dir = package_dir.rename(tmp_path / "newpackage") qualname_changed_key = _utils.pyfile_checksum(new_package_dir / submodule_name) assert qualname_changed_key != changed_content_key - - -def test_create_cachedir(tmp_path): - cache_dir = tmp_path / ".test_cache_dir" - assert not cache_dir.exists() - - _utils.create_cachedir(cache_dir) - assert cache_dir.is_dir() - - # Check CACHEDIR.TAG file - cache_tag_path = cache_dir / "CACHEDIR.TAG" - assert cache_tag_path.is_file() - with cache_tag_path.open("r") as fp: - cache_tag_content = fp.read() - assert cache_tag_content.startswith("Signature: 8a477f597d28d172789f06886806bc55\n") - - # Check. gitignore - gitignore_path = cache_dir / ".gitignore" - assert gitignore_path.is_file() - with gitignore_path.open("r") as fp: - gitignore_content = fp.read() - assert "\n*\n" in gitignore_content - - # Check that calling it a second time doesn't raise an error - _utils.create_cachedir(cache_dir) - - -class Test_FileCache: - def test_basic(self, tmp_path): - - class Serializer: - suffix = ".txt" - - def hash_args(self, arg): - return str(hash(arg)) - - def serialize(self, data): - return str(data).encode() - - def deserialize(self, raw): - return int(raw.decode()) - - counter = defaultdict(lambda: 0) - - def square(x): - counter[x] += 1 - return x * x - - cached_square = _utils.FileCache( - func=square, serializer=Serializer(), cache_dir=tmp_path, name="test" - ) - - assert cached_square(3) == 9 - assert counter[3] == 1 - - # Result was cached - entry_name = f"{Serializer().hash_args(3)!s}{Serializer.suffix}" - cached_file = tmp_path / "test" / entry_name - assert cached_file.is_file() - - # With the square(3) cached, the counter 
no longer increases - assert cached_square(3) == 9 - assert counter[3] == 1 - - # Using another FileCache will use the existing cache - cached_square_2 = _utils.FileCache( - func=square, serializer=Serializer(), cache_dir=tmp_path, name="test" - ) - assert cached_square_2(3) == 9 - assert counter[3] == 1 - - # But using another FileCache with a different name will not hit existing cache - cached_square_3 = _utils.FileCache( - func=square, - serializer=Serializer(), - cache_dir=tmp_path, - name="test2", - ) - assert cached_square_3(3) == 9 - assert counter[3] == 2 From 52caa1668bec96867edb6f32af3f3f0b64ec0a6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Mon, 7 Oct 2024 13:34:34 +0200 Subject: [PATCH 5/5] Improve docstrings of `ImportSerializer` --- src/docstub/_analysis.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/docstub/_analysis.py b/src/docstub/_analysis.py index 49b9708..4fc471a 100644 --- a/src/docstub/_analysis.py +++ b/src/docstub/_analysis.py @@ -271,21 +271,24 @@ class TypeCollector(cst.CSTVisitor): """ class ImportSerializer: - """Implements the `FuncSerializer` protocol to cache calls to `collect`.""" + """Implements the `FuncSerializer` protocol to cache `TypeCollector.collect`.""" suffix = ".json" encoding = "utf-8" def hash_args(self, path: Path) -> str: + """Compute a unique hash from the path passed to `TypeCollector.collect`.""" key = pyfile_checksum(path) return key def serialize(self, data: dict[str, KnownImport]) -> bytes: + """Serialize results from `TypeCollector.collect`.""" primitives = {qualname: asdict(imp) for qualname, imp in data.items()} - raw = json.dumps(primitives).encode(self.encoding) + raw = json.dumps(primitives, separators=(",", ":")).encode(self.encoding) return raw def deserialize(self, raw: bytes) -> dict[str, KnownImport]: + """Deserialize results from `TypeCollector.collect`.""" primitives = json.loads(raw.decode(self.encoding)) data = {qualname: KnownImport(**kw) 
for qualname, kw in primitives.items()} return data