Skip to content

Commit 22810a4

Browse files
committed
compute plan future
1 parent 5c48649 commit 22810a4

File tree

3 files changed

+89
-18
lines changed

3 files changed

+89
-18
lines changed

substratest/assets.py

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,50 @@ def get(self):
5050
return self._asset
5151

5252

53+
class ComputePlanFuture(Future):
54+
_keys_properties = {
55+
'ComputePlan': {
56+
'traintuple_keys': 'traintuples',
57+
'composite_traintuple_keys': 'composite_traintuples',
58+
'aggregatetuple_keys': 'aggregatetuples',
59+
'testtuple_keys': 'testtuples',
60+
},
61+
'ComputePlanCreated': {
62+
'traintuple_keys': 'traintuple_keys',
63+
'composite_traintuple_keys': 'composite_traintuple_keys',
64+
'aggregatetuple_keys': 'aggregatetuple_keys',
65+
'testtuple_keys': 'testtuple_keys',
66+
67+
},
68+
}
69+
70+
def __init__(self, asset, session):
71+
self._asset = asset
72+
self._getter = session.get_compute_plan
73+
for k, v in enumerate(self._keys_properties[asset.__class__.name]):
74+
setattr(self, f'_{k}', getattr(asset, v))
75+
self._get_traintuple = session.get_traintuple
76+
self._get_composite_traintuple = session.get_composite_traintuple
77+
self._get_aggregatetuple = session.get_aggregatetuple
78+
self._get_testtuple = session.get_testtuple
79+
80+
def wait(self, timeout=FUTURE_TIMEOUT, raises=True):
81+
"""wait until all tuples are completed (done or failed)."""
82+
for key in self._traintuple_keys:
83+
self._get_traintuple(key).future().wait(timeout, raises)
84+
for key in self._composite_traintuple_keys:
85+
self._get_composite_traintuple(key).future().wait(timeout, raises)
86+
for key in self._aggregatetuple_keys:
87+
self._get_aggregatetuple(key).future().wait(timeout, raises)
88+
for key in self._testtuple_keys:
89+
self._get_testtuple(key).future().wait(timeout, raises)
90+
91+
return self.get()
92+
93+
def get(self):
94+
return self._getter(self._asset.key)
95+
96+
5397
class _FutureMixin(abc.ABC):
5498
def attach(self, session):
5599
"""Attach session to asset."""
@@ -60,7 +104,11 @@ def future(self):
60104
"""Returns future from asset."""
61105
assert hasattr(self, 'status')
62106
assert hasattr(self, 'key')
63-
return Future(self, self._session)
107+
108+
try:
109+
return self.Meta.FutureCls(self, self._session)
110+
except AttributeError:
111+
return Future(self, self._session)
64112

65113

66114
def _convert(name):
@@ -321,24 +369,48 @@ class Meta:
321369

322370

323371
@dataclasses.dataclass
324-
class ComputePlanCreated(_Asset):
372+
class ComputePlanCreated(_Asset, _FutureMixin):
325373
compute_plan_id: str
326374
traintuple_keys: typing.List[str]
375+
composite_traintuple_keys: typing.List[str]
376+
aggregatetuple_keys: typing.List[str]
327377
testtuple_keys: typing.List[str]
328378

379+
class Meta:
380+
FutureCls = ComputePlanFuture
381+
329382

330383
@dataclasses.dataclass
331384
class ComputePlan(_Asset):
332385
compute_plan_id: str
333386
algo_key: str
334387
objective_key: str
335388
traintuples: typing.List[str]
389+
composite_traintuples: typing.List[str]
390+
aggregatetuples: typing.List[str]
336391
testtuples: typing.List[str]
337392

393+
class Meta:
394+
FutureCls = ComputePlanFuture
395+
338396
def __post_init__(self):
339397
if self.testtuples is None:
340398
self.testtuples = []
341399

400+
def list_traintuples(self, session):
401+
return session.list_traintuples(filters=[f'traintuple:computePlanId:{self.compute_plan_id}'])
402+
403+
def list_composite_traintuples(self, session):
404+
return session.list_composite_traintuples(
405+
filters=[f'composite_traintuple:computePlanId:{self.compute_plan_id}']
406+
)
407+
408+
def list_aggregatetuples(self, session):
409+
return session.list_aggregatetuples(filters=[f'aggregatetuple:computePlanId:{self.compute_plan_id}'])
410+
411+
def list_testtuples(self, session):
412+
return session.list_testtuples(filters=[f'testtuple:computePlanId:{self.compute_plan_id}'])
413+
342414

343415
@dataclasses.dataclass(frozen=True)
344416
class Node(_Asset):

substratest/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def add_testtuple(self, spec):
117117

118118
def add_compute_plan(self, spec):
119119
res = self._client.add_compute_plan(spec.to_dict())
120-
compute_plan = assets.ComputePlanCreated.load(res)
120+
compute_plan = assets.ComputePlanCreated.load(res).attach(self)
121121
return compute_plan
122122

123123
def list_compute_plan(self, *args, **kwargs):

tests/test_execution_compute_plan.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -182,26 +182,19 @@ def test_compute_plan_single_session_failure(global_execution_env):
182182
cp_spec.add_testtuple(traintuple_spec_3)
183183

184184
# Submit compute plan and wait for it to complete
185-
cp = session.add_compute_plan(cp_spec)
185+
cp_created = session.add_compute_plan(cp_spec)
186+
cp = cp_created.future().wait()
186187

187-
traintuples = [
188-
session.get_traintuple(key).future().wait(raises=False)
189-
for key in cp.traintuple_keys
190-
]
191-
192-
testtuples = [
193-
session.get_testtuple(key).future().wait(raises=False)
194-
for key in cp.testtuple_keys
195-
]
188+
traintuples = cp.list_traintuples(session)
189+
testtuples = cp.list_testtuples(session)
196190

197191
# All the train/test tuples should be marked as failed
198192
for t in traintuples + testtuples:
199193
assert t.status == 'failed'
200194

201-
compute_plan = session.get_compute_plan(cp.compute_plan_id)
202-
assert cp.compute_plan_id == compute_plan.compute_plan_id
203-
assert set(cp.traintuple_keys) == set(compute_plan.traintuples)
204-
assert set(cp.testtuple_keys) == set(compute_plan.testtuples)
195+
assert cp_created.compute_plan_id == cp.compute_plan_id
196+
assert set(cp_created.traintuple_keys) == set(cp.traintuples)
197+
assert set(cp_created.testtuple_keys) == set(cp.testtuples)
205198

206199

207200
def test_compute_plan_aggregate_composite_traintuples(factory, session_1, session_2):
@@ -286,7 +279,13 @@ def test_compute_plan_aggregate_composite_traintuples(factory, session_1, sessio
286279
traintuple_spec=composite_traintuple_spec,
287280
)
288281

289-
session_1.add_compute_plan(cp_spec).future().wait()
282+
cp = session_1.add_compute_plan(cp_spec).future().wait()
283+
tuples = (cp.list_traintuple(session_1) +
284+
cp.list_composite_traintuples(session_1) +
285+
cp.list_aggregate_tuples(session_1) +
286+
cp.list_testtuples(session_1))
287+
for t in tuples:
288+
assert t.status == 'done'
290289

291290

292291
def test_compute_plan_circular_dependency_failure(factory, session):

0 commit comments

Comments
 (0)