Skip to content

Commit a0b9acb

Browse files
committed
refactor ComputePlanFuture
1 parent 22810a4 commit a0b9acb

File tree

2 files changed

+52
-60
lines changed

2 files changed

+52
-60
lines changed

substratest/assets.py

Lines changed: 43 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,17 @@
1111
FUTURE_TIMEOUT = 120 # seconds
1212

1313

14-
class Future:
14+
class BaseFuture(abc.ABC):
15+
@abc.abstractmethod
16+
def wait(self, timeout=FUTURE_TIMEOUT, raises=True):
17+
raise NotImplementedError
18+
19+
@abc.abstractmethod
20+
def get(self):
21+
raise NotImplementedError
22+
23+
24+
class Future(BaseFuture):
1525
"""Future asset."""
1626
# mapper from asset class name to client getter method
1727
_methods = {
@@ -50,43 +60,25 @@ def get(self):
5060
return self._asset
5161

5262

53-
class ComputePlanFuture(Future):
54-
_keys_properties = {
55-
'ComputePlan': {
56-
'traintuple_keys': 'traintuples',
57-
'composite_traintuple_keys': 'composite_traintuples',
58-
'aggregatetuple_keys': 'aggregatetuples',
59-
'testtuple_keys': 'testtuples',
60-
},
61-
'ComputePlanCreated': {
62-
'traintuple_keys': 'traintuple_keys',
63-
'composite_traintuple_keys': 'composite_traintuple_keys',
64-
'aggregatetuple_keys': 'aggregatetuple_keys',
65-
'testtuple_keys': 'testtuple_keys',
66-
67-
},
68-
}
69-
63+
class ComputePlanFuture(BaseFuture):
7064
def __init__(self, asset, session):
7165
self._asset = asset
7266
self._getter = session.get_compute_plan
73-
for k, v in enumerate(self._keys_properties[asset.__class__.name]):
74-
setattr(self, f'_{k}', getattr(asset, v))
7567
self._get_traintuple = session.get_traintuple
7668
self._get_composite_traintuple = session.get_composite_traintuple
7769
self._get_aggregatetuple = session.get_aggregatetuple
7870
self._get_testtuple = session.get_testtuple
7971

80-
def wait(self, timeout=FUTURE_TIMEOUT, raises=True):
72+
def wait(self, timeout=FUTURE_TIMEOUT):
8173
"""wait until all tuples are completed (done or failed)."""
82-
for key in self._traintuple_keys:
83-
self._get_traintuple(key).future().wait(timeout, raises)
84-
for key in self._composite_traintuple_keys:
85-
self._get_composite_traintuple(key).future().wait(timeout, raises)
86-
for key in self._aggregatetuple_keys:
87-
self._get_aggregatetuple(key).future().wait(timeout, raises)
88-
for key in self._testtuple_keys:
89-
self._get_testtuple(key).future().wait(timeout, raises)
74+
for key in self._asset.traintuple_keys:
75+
self._get_traintuple(key).future().wait(timeout, raises=False)
76+
for key in self._asset.composite_traintuple_keys:
77+
self._get_composite_traintuple(key).future().wait(timeout, raises=False)
78+
for key in self._asset.aggregatetuple_keys:
79+
self._get_aggregatetuple(key).future().wait(timeout, raises=False)
80+
for key in self._asset.testtuple_keys:
81+
self._get_testtuple(key).future().wait(timeout, raises=False)
9082

9183
return self.get()
9284

@@ -95,6 +87,8 @@ def get(self):
9587

9688

9789
class _FutureMixin(abc.ABC):
90+
_future_cls = Future
91+
9892
def attach(self, session):
9993
"""Attach session to asset."""
10094
self._session = session
@@ -104,11 +98,11 @@ def future(self):
10498
"""Returns future from asset."""
10599
assert hasattr(self, 'status')
106100
assert hasattr(self, 'key')
101+
return self._future_cls(self, self._session)
107102

108-
try:
109-
return self.Meta.FutureCls(self, self._session)
110-
except AttributeError:
111-
return Future(self, self._session)
103+
104+
class _ComputePlanFutureMixin(abc.ABC):
105+
_future_cls = ComputePlanFuture
112106

113107

114108
def _convert(name):
@@ -376,27 +370,36 @@ class ComputePlanCreated(_Asset, _FutureMixin):
376370
aggregatetuple_keys: typing.List[str]
377371
testtuple_keys: typing.List[str]
378372

379-
class Meta:
380-
FutureCls = ComputePlanFuture
381-
382373

383374
@dataclasses.dataclass
384375
class ComputePlan(_Asset):
385376
compute_plan_id: str
386-
algo_key: str
387377
objective_key: str
388378
traintuples: typing.List[str]
389379
composite_traintuples: typing.List[str]
390380
aggregatetuples: typing.List[str]
391381
testtuples: typing.List[str]
392382

393-
class Meta:
394-
FutureCls = ComputePlanFuture
395-
396383
def __post_init__(self):
397384
if self.testtuples is None:
398385
self.testtuples = []
399386

387+
@property
388+
def traintuple_keys(self):
389+
return self.traintuples
390+
391+
@property
392+
def composite_traintuple_keys(self):
393+
return self.composite_traintuples
394+
395+
@property
396+
def aggregatetuple_keys(self):
397+
return self.aggregatetuples
398+
399+
@property
400+
def testtuple_keys(self):
401+
return self.testtuples
402+
400403
def list_traintuples(self, session):
401404
return session.list_traintuples(filters=[f'traintuple:computePlanId:{self.compute_plan_id}'])
402405

tests/test_execution_compute_plan.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import pytest
2-
2+
import substra
33
import substratest as sbt
44

55

@@ -46,12 +46,9 @@ def test_compute_plan(global_execution_env):
4646
)
4747

4848
# submit compute plan and wait for it to complete
49-
cp = session_1.add_compute_plan(cp_spec)
49+
cp = session_1.add_compute_plan(cp_spec).future().wait()
5050

51-
traintuples = [
52-
session_1.get_traintuple(key).future().wait()
53-
for key in cp.traintuple_keys
54-
]
51+
traintuples = cp.list_traintuples(session_1)
5552

5653
# check all traintuples are done and check they have been executed on the expected
5754
# node
@@ -112,20 +109,10 @@ def test_compute_plan_single_session_success(global_execution_env):
112109
cp_spec.add_testtuple(traintuple_spec_3)
113110

114111
# Submit compute plan and wait for it to complete
115-
cp = session.add_compute_plan(cp_spec)
116-
117-
traintuples = [
118-
session.get_traintuple(key).future().wait()
119-
for key in cp.traintuple_keys
120-
]
121-
122-
testtuples = [
123-
session.get_testtuple(key).future().wait()
124-
for key in cp.testtuple_keys
125-
]
112+
cp = session.add_compute_plan(cp_spec).future().wait()
126113

127114
# All the train/test tuples should succeed
128-
for t in traintuples + testtuples:
115+
for t in cp.list_traintuples(session) + cp.list_testtuples(session):
129116
assert t.status == 'done'
130117

131118
compute_plan = session.get_compute_plan(cp.compute_plan_id)
@@ -319,5 +306,7 @@ def test_compute_plan_circular_dependency_failure(factory, session):
319306
traintuple_spec_2.in_models_ids.append(traintuple_spec_1.id)
320307

321308
# TODO make sur the creation is rejected
322-
cp = session.add_compute_plan(cp_spec)
323-
assert False
309+
with pytest.raises(substra.exceptions.InvalidRequest) as e:
310+
session.add_compute_plan(cp_spec)
311+
312+
assert 'circular' in str(e)

0 commit comments

Comments
 (0)