Skip to content

Commit 8be2c28

Browse files
committed
simplify PID check
1 parent 2d8cd34 commit 8be2c28

File tree

2 files changed

+24
-58
lines changed

2 files changed

+24
-58
lines changed

src/zarr/storage/_memory.py

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -636,7 +636,6 @@ class ManagedMemoryStore(MemoryStore):
636636
_store_dict: _ManagedStoreDict
637637
_name: str
638638
path: str
639-
_created_pid: int
640639

641640
def __init__(self, name: str | None = None, *, path: str = "", read_only: bool = False) -> None:
642641
# Skip MemoryStore.__init__ and call Store.__init__ directly
@@ -646,7 +645,6 @@ def __init__(self, name: str | None = None, *, path: str = "", read_only: bool =
646645
# Get or create a managed dict from the registry
647646
self._store_dict, self._name = _managed_store_dict_registry.get_or_create(name)
648647
self.path = normalize_path(path)
649-
self._created_pid = os.getpid()
650648

651649
def __str__(self) -> str:
652650
return _dereference_path(f"memory://{self._name}", self.path)
@@ -667,18 +665,6 @@ def name(self) -> str:
667665
"""The name of this store, used in the memory:// URL."""
668666
return self._name
669667

670-
def _check_same_process(self) -> None:
671-
"""Raise an error if this store is being used in a different process."""
672-
current_pid = os.getpid()
673-
if self._created_pid != current_pid:
674-
raise RuntimeError(
675-
f"ManagedMemoryStore '{self._name}' was created in process {self._created_pid} "
676-
f"but is being used in process {current_pid}. "
677-
"ManagedMemoryStore instances cannot be shared across processes because "
678-
"their backing dict is not serialized. Use a persistent store (e.g., "
679-
"LocalStore, ZipStore) for cross-process data sharing."
680-
)
681-
682668
@classmethod
683669
def _from_managed_dict(
684670
cls,
@@ -694,7 +680,6 @@ def _from_managed_dict(
694680
store._store_dict = managed_dict
695681
store._name = name
696682
store.path = normalize_path(path)
697-
store._created_pid = os.getpid()
698683
return store
699684

700685
def with_read_only(self, read_only: bool = False) -> ManagedMemoryStore:
@@ -754,7 +739,6 @@ async def get(
754739
byte_range: ByteRequest | None = None,
755740
) -> Buffer | None:
756741
# docstring inherited
757-
self._check_same_process()
758742
return await super().get(
759743
_dereference_path(self.path, key), prototype=prototype, byte_range=byte_range
760744
)
@@ -765,43 +749,36 @@ async def get_partial_values(
765749
key_ranges: Iterable[tuple[str, ByteRequest | None]],
766750
) -> list[Buffer | None]:
767751
# docstring inherited
768-
self._check_same_process()
769752
key_ranges = [
770753
(_dereference_path(self.path, key), byte_range) for key, byte_range in key_ranges
771754
]
772755
return await super().get_partial_values(prototype, key_ranges)
773756

774757
async def exists(self, key: str) -> bool:
775758
# docstring inherited
776-
self._check_same_process()
777759
return await super().exists(_dereference_path(self.path, key))
778760

779761
async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None:
780762
# docstring inherited
781-
self._check_same_process()
782763
return await super().set(_dereference_path(self.path, key), value, byte_range=byte_range)
783764

784765
async def set_if_not_exists(self, key: str, value: Buffer) -> None:
785766
# docstring inherited
786-
self._check_same_process()
787767
return await super().set_if_not_exists(_dereference_path(self.path, key), value)
788768

789769
async def delete(self, key: str) -> None:
790770
# docstring inherited
791-
self._check_same_process()
792771
return await super().delete(_dereference_path(self.path, key))
793772

794773
async def list(self) -> AsyncIterator[str]:
795774
# docstring inherited
796-
self._check_same_process()
797775
prefix = self.path + "/" if self.path else ""
798776
async for key in super().list():
799777
if key.startswith(prefix):
800778
yield key.removeprefix(prefix)
801779

802780
async def list_prefix(self, prefix: str) -> AsyncIterator[str]:
803781
# docstring inherited
804-
self._check_same_process()
805782
# Don't use _dereference_path here because it strips trailing slashes,
806783
# which would break prefix matching (e.g., "fo/" vs "foo/")
807784
full_prefix = f"{self.path}/{prefix}" if self.path else prefix
@@ -811,7 +788,6 @@ async def list_prefix(self, prefix: str) -> AsyncIterator[str]:
811788

812789
async def list_dir(self, prefix: str) -> AsyncIterator[str]:
813790
# docstring inherited
814-
self._check_same_process()
815791
full_prefix = _dereference_path(self.path, prefix)
816792
async for key in super().list_dir(full_prefix):
817793
yield key
@@ -829,16 +805,16 @@ def __reduce__(
829805
identity (name, path, read_only) is preserved. If the original store has
830806
been garbage collected, the unpickled store will have an empty dict.
831807
832-
The original process ID is preserved so that cross-process usage can be
833-
detected and will raise an error.
808+
The current process ID is preserved so that cross-process unpickling can be
809+
detected and will raise an error at unpickle time.
834810
"""
835811
return (
836812
self.__class__,
837813
(self._name,),
838814
{
839815
"path": self.path,
840816
"read_only": self.read_only,
841-
"created_pid": self._created_pid,
817+
"created_pid": os.getpid(),
842818
},
843819
)
844820

@@ -847,8 +823,17 @@ def __setstate__(self, state: dict[str, Any]) -> None:
847823
# The __reduce__ method returns (cls, (name,), state)
848824
# Python calls cls(name) then __setstate__(state)
849825
# But __init__ already set up _store_dict and _name from the registry
850-
# We just need to restore path, read_only, and the original process ID
826+
# We just need to restore path and read_only
851827
self.path = normalize_path(state.get("path", ""))
852828
self._read_only = state.get("read_only", False)
853-
# Preserve the original process ID to detect cross-process usage
854-
self._created_pid = state.get("created_pid", os.getpid())
829+
830+
# Check for cross-process usage - fail fast at unpickle time
831+
created_pid = state.get("created_pid")
832+
if created_pid is not None and created_pid != os.getpid():
833+
raise RuntimeError(
834+
f"ManagedMemoryStore '{self._name}' was created in process {created_pid} "
835+
f"but is being unpickled in process {os.getpid()}. "
836+
"ManagedMemoryStore instances cannot be shared across processes because "
837+
"their backing dict is not serialized. Use a persistent store (e.g., "
838+
"LocalStore, ZipStore) for cross-process data sharing."
839+
)

tests/test_store/test_memory.py

Lines changed: 9 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -279,44 +279,25 @@ def test_pickle_after_gc(self) -> None:
279279

280280
async def test_cross_process_detection(self) -> None:
281281
"""
282-
Test that using a ManagedMemoryStore in a different process raises an error.
282+
Test that unpickling a ManagedMemoryStore in a different process raises an error.
283283
284284
This prevents silent data loss when a store is pickled and unpickled
285285
in a different process (e.g., with multiprocessing).
286286
"""
287-
import pickle
287+
import os
288288

289289
store = ManagedMemoryStore(name="cross-process-test")
290290
await store.set("key", self.buffer_cls.from_bytes(b"value"))
291291

292-
# Simulate unpickling in a different process by manipulating _created_pid
293-
pickled = pickle.dumps(store)
294-
store2 = pickle.loads(pickled)
295-
296-
# Manually change the created_pid to simulate a different process
297-
store2._created_pid = store2._created_pid + 1
298-
299-
# All operations should raise RuntimeError
300-
with pytest.raises(RuntimeError, match="was created in process"):
301-
await store2.get("key")
302-
303-
with pytest.raises(RuntimeError, match="was created in process"):
304-
await store2.set("key", self.buffer_cls.from_bytes(b"value"))
305-
306-
with pytest.raises(RuntimeError, match="was created in process"):
307-
await store2.exists("key")
308-
309-
with pytest.raises(RuntimeError, match="was created in process"):
310-
await store2.delete("key")
311-
312-
with pytest.raises(RuntimeError, match="was created in process"):
313-
[k async for k in store2.list()]
314-
315-
with pytest.raises(RuntimeError, match="was created in process"):
316-
[k async for k in store2.list_prefix("")]
292+
# Get the reduce tuple and modify the state to simulate a different process
293+
cls, args, state = store.__reduce__()
294+
state["created_pid"] = os.getpid() + 1 # Fake a different process ID
317295

296+
# Manually reconstruct what pickle.loads would do
297+
# This simulates unpickling data that was pickled in a different process
298+
reconstructed = cls(*args)
318299
with pytest.raises(RuntimeError, match="was created in process"):
319-
[k async for k in store2.list_dir("")]
300+
reconstructed.__setstate__(state)
320301

321302
def test_store_supports_writes(self, store: ManagedMemoryStore) -> None:
322303
assert store.supports_writes

0 commit comments

Comments
 (0)