diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a3bd2a340..29934dd11 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,8 +16,8 @@ repos: rev: 22.3.0 hooks: - id: black - - repo: https://gitlab.com/pycqa/flake8 - rev: 3.8.4 + - repo: https://github.com/PyCQA/flake8 + rev: 5.0.4 hooks: - id: flake8 - repo: https://github.com/asottile/seed-isort-config diff --git a/fsspec/implementations/cached.py b/fsspec/implementations/cached.py index 0619d1a6d..e9229b712 100644 --- a/fsspec/implementations/cached.py +++ b/fsspec/implementations/cached.py @@ -1,3 +1,4 @@ +import contextlib import hashlib import inspect import logging @@ -5,7 +6,7 @@ import pickle import tempfile import time -from shutil import move, rmtree +from shutil import rmtree from fsspec import AbstractFileSystem, filesystem from fsspec.callbacks import _DEFAULT_CALLBACK @@ -184,11 +185,9 @@ def save_cache(self): for c in cache.values(): if isinstance(c["blocks"], set): c["blocks"] = list(c["blocks"]) - fd2, fn2 = tempfile.mkstemp() - with open(fd2, "wb") as f: - pickle.dump(cache, f) self._mkcache() - move(fn2, fn) + with atomic_write(fn) as f: + pickle.dump(cache, f) self.cached_files[-1] = cached_files self.last_cache = time.time() @@ -264,7 +263,7 @@ def clear_expired_cache(self, expiry_time=None): if self.cached_files[-1]: cache_path = os.path.join(self.storage[-1], "cache") - with open(cache_path, "wb") as fc: + with atomic_write(cache_path) as fc: pickle.dump(self.cached_files[-1], fc) else: rmtree(self.storage[-1]) @@ -834,3 +833,24 @@ def hash_name(path, same_name): else: hash = hashlib.sha256(path.encode()).hexdigest() return hash + + +@contextlib.contextmanager +def atomic_write(path, mode="wb"): + """ + A context manager that opens a temporary file next to `path` and, on exit, + replaces `path` with the temporary file, thereby updating `path` + atomically. + """ + fd, fn = tempfile.mkstemp( + dir=os.path.dirname(path), prefix=os.path.basename(path) + "-" + ) + try: + with open(fd, mode) as fp: + yield fp + except BaseException: + with contextlib.suppress(FileNotFoundError): + os.unlink(fn) + raise + else: + os.replace(fn, path)