Skip to content

Commit 4a77c21

Browse files
Merge pull request #27 from twosigma/pickle-compatibility
2 parents 952105d + d4e5531 commit 4a77c21

File tree

9 files changed

+94
-17
lines changed

9 files changed

+94
-17
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ packages = [
1414
]
1515
readme = "README.md"
1616
repository = "https://github.com/twosigma/uberjob"
17-
version = "1.0.1"
17+
version = "1.0.2"
1818

1919
[tool.poetry.dependencies]
2020
networkx = "^2.5"

src/uberjob/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
__author__ = "Daniel Shields, Timothy Shields"
3030
__maintainer__ = "Daniel Shields, Timothy Shields"
3131
32-
__version__ = "1.0.1"
32+
__version__ = "1.0.2"
3333

3434
from uberjob import graph, progress, stores
3535
from uberjob._errors import CallError, NotTransformedError

src/uberjob/_plan.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,14 @@ def __init__(self):
4949
self._scope = ()
5050
self._scope_lock = RLock()
5151

52+
def __getstate__(self) -> dict:
53+
return {"graph": self.graph}
54+
55+
def __setstate__(self, state: dict) -> None:
56+
self.graph = state["graph"]
57+
self._scope = ()
58+
self._scope_lock = RLock()
59+
5260
def _call(self, stack_frame, fn: Callable, *args, **kwargs) -> Call:
5361
call = Call(fn, scope=self._scope, stack_frame=stack_frame)
5462
self.graph.add_node(call)

src/uberjob/_transformations/caching.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,17 @@
2929

3030

3131
class BarrierType:
32-
__slots__ = ()
32+
def __new__(cls):
33+
return Barrier
34+
35+
def __getnewargs__(self):
36+
return ()
3337

3438
def __repr__(self):
3539
return "Barrier"
3640

3741

38-
Barrier = BarrierType()
42+
Barrier = object.__new__(BarrierType)
3943

4044

4145
def _to_naive_utc_time(value: dt.datetime | None) -> dt.datetime | None:

src/uberjob/_util/__init__.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,23 +37,31 @@ def is_ipython():
3737

3838

3939
class OmittedType:
40-
__slots__ = ()
40+
def __new__(cls):
41+
return Omitted
42+
43+
def __getnewargs__(self):
44+
return ()
4145

4246
def __repr__(self):
4347
return "<...>"
4448

4549

46-
Omitted = OmittedType()
50+
Omitted = object.__new__(OmittedType)
4751

4852

4953
class MissingType:
50-
__slots__ = ()
54+
def __new__(cls):
55+
return Missing
56+
57+
def __getnewargs__(self):
58+
return ()
5159

5260
def __repr__(self):
5361
return "Missing"
5462

5563

56-
Missing = MissingType()
64+
Missing = object.__new__(MissingType)
5765

5866

5967
class Slot:

src/uberjob/_util/traceback.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,20 +48,18 @@ def __repr__(self):
4848
)
4949

5050

51-
TruncatedStackFrame = None
51+
class TruncatedStackFrameType:
52+
def __new__(cls):
53+
return TruncatedStackFrame
5254

55+
def __getnewargs__(self):
56+
return ()
5357

54-
class TruncatedStackFrameType:
5558
def __repr__(self):
5659
return "TruncatedStackFrame"
5760

58-
def __new__(cls, *args, **kwargs):
59-
if TruncatedStackFrame is not None:
60-
return TruncatedStackFrame
61-
return super().__new__(cls, *args, **kwargs)
62-
6361

64-
TruncatedStackFrame = TruncatedStackFrameType()
62+
TruncatedStackFrame = object.__new__(TruncatedStackFrameType)
6563

6664

6765
MAX_TRACEBACK_DEPTH = 3

tests/test_atoms.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#
2+
# Copyright 2025 Two Sigma Open Source, LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
17+
import pickle
18+
import unittest
19+
20+
from uberjob._transformations.caching import BarrierType, Barrier
21+
from uberjob._util import MissingType, Missing, OmittedType, Omitted
22+
from uberjob._util.traceback import TruncatedStackFrameType, TruncatedStackFrame
23+
24+
ATOMS = [
25+
(BarrierType, Barrier, "Barrier"),
26+
(MissingType, Missing, "Missing"),
27+
(OmittedType, Omitted, "<...>"),
28+
(TruncatedStackFrameType, TruncatedStackFrame, "TruncatedStackFrame"),
29+
]
30+
31+
32+
class AtomTestCase(unittest.TestCase):
33+
def test_atom_singleton(self):
34+
for atom_type, atom, _ in ATOMS:
35+
with self.subTest(atom_type_name=atom_type.__name__):
36+
self.assertIs(atom_type(), atom_type())
37+
self.assertIs(atom_type(), atom)
38+
39+
def test_atom_pickle_round_trip_is_atom(self):
40+
for atom_type, atom, _ in ATOMS:
41+
with self.subTest(atom_type_name=atom_type.__name__):
42+
self.assertIs(pickle.loads(pickle.dumps(atom)), atom)
43+
44+
def test_type_of_atom_is_atom_type(self):
45+
for atom_type, atom, _ in ATOMS:
46+
with self.subTest(atom_type_name=atom_type.__name__):
47+
self.assertIs(type(atom), atom_type)
48+
49+
def test_atom_repr(self):
50+
for atom_type, atom, atom_repr in ATOMS:
51+
with self.subTest(atom_type_name=atom_type.__name__):
52+
self.assertEqual(repr(atom), atom_repr)

tests/test_plan.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,3 +405,10 @@ def test_serialize_node_error(self):
405405
self.assertIsInstance(unpickled_exception, NodeError)
406406
self.assertIsInstance(unpickled_exception.node, Call)
407407
self.assertIs(unpickled_exception.node.fn, pow)
408+
409+
def test_plan_pickle_round_trip(self):
410+
plan = uberjob.Plan()
411+
result = plan.call(pow, 3, 2)
412+
plan2, result2 = pickle.loads(pickle.dumps([plan, result]))
413+
self.assertEqual(plan2._scope, ())
414+
self.assertEqual(uberjob.run(plan2, output=result2), 9)

tests/test_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import uberjob
2020

21-
EXPECTED_VERSION = "1.0.1"
21+
EXPECTED_VERSION = "1.0.2"
2222
REPOSITORY_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
2323

2424

0 commit comments

Comments
 (0)