Skip to content

Commit 5c48649

Browse files
committed
generic compute plan
1 parent ece44e2 commit 5c48649

File tree

2 files changed

+225
-20
lines changed

2 files changed

+225
-20
lines changed

substratest/factory.py

Lines changed: 87 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ class Permissions:
142142

143143

144144
DEFAULT_PERMISSIONS = Permissions(public=True, authorized_ids=[])
145+
DEFAULT_OUT_MODEL_PERMISSIONS = Permissions(public=False, authorized_ids=[])
145146

146147

147148
@dataclasses.dataclass
@@ -243,12 +244,46 @@ class TesttupleSpec(_Spec):
243244

244245
@dataclasses.dataclass
245246
class ComputePlanTraintupleSpec:
247+
algo_key: str
246248
data_manager_key: str
247249
train_data_sample_keys: str
248250
traintuple_id: str
249251
in_models_ids: typing.List[str]
250252
tag: str
251253

254+
@property
255+
def id(self):
256+
return self.traintuple_id
257+
258+
259+
@dataclasses.dataclass
260+
class ComputePlanAggregatetupleSpec(_Spec):
261+
aggregatetuple_id: str
262+
algo_key: str
263+
worker: str
264+
in_models_ids: typing.List[str]
265+
tag: str
266+
267+
@property
268+
def id(self):
269+
return self.aggregatetuple_id
270+
271+
272+
@dataclasses.dataclass
273+
class ComputePlanCompositeTraintupleSpec(_Spec):
274+
composite_traintuple_id: str
275+
algo_key: str
276+
data_manager_key: str
277+
train_data_sample_keys: str
278+
in_head_model_id: str
279+
in_trunk_model_id: str
280+
tag: str
281+
out_trunk_model_permissions: typing.Dict
282+
283+
@property
284+
def id(self):
285+
return self.composite_traintuple_id
286+
252287

253288
@dataclasses.dataclass
254289
class ComputePlanTesttupleSpec:
@@ -276,26 +311,69 @@ def _get_keys(obj, field='key'):
276311

277312
@dataclasses.dataclass
278313
class ComputePlanSpec(_Spec):
279-
algo_key: str
280314
objective_key: str
281315
traintuples: typing.List[ComputePlanTraintupleSpec]
316+
composite_traintuples: typing.List[ComputePlanCompositeTraintupleSpec]
317+
aggregatetuples: typing.List[ComputePlanAggregatetupleSpec]
282318
testtuples: typing.List[ComputePlanTesttupleSpec]
283319

284-
def add_traintuple(self, dataset, data_samples, traintuple_specs=None, tag=None):
285-
traintuple_specs = traintuple_specs or []
320+
def add_traintuple(self, algo, dataset, data_samples, in_models_tuples=None, tag=''):
321+
in_models_tuples = in_models_tuples or []
286322
spec = ComputePlanTraintupleSpec(
323+
algo_key=algo.key,
287324
traintuple_id=random_uuid(),
288325
data_manager_key=dataset.key,
289326
train_data_sample_keys=_get_keys(data_samples),
290-
in_models_ids=[t.traintuple_id for t in traintuple_specs],
291-
tag=tag or '',
327+
in_models_ids=[t.id for t in in_models_tuples],
328+
tag=tag,
292329
)
293330
self.traintuples.append(spec)
294331
return spec
295332

333+
def add_aggregatetuple(self, aggregate_algo, worker, in_models_tuples=None, tag=''):
334+
in_models_tuples = in_models_tuples or []
335+
336+
for t in in_models_tuples:
337+
assert isinstance(t, (ComputePlanTraintupleSpec, ComputePlanCompositeTraintupleSpec))
338+
339+
spec = ComputePlanAggregatetupleSpec(
340+
aggregatetuple_id=random_uuid(),
341+
algo_key=aggregate_algo.key,
342+
worker=worker,
343+
in_models_ids=[t.id for t in in_models_tuples],
344+
tag=tag,
345+
)
346+
self.aggregatetuples.append(spec)
347+
return spec
348+
349+
def add_composite_traintuple(self, composite_algo, dataset=None, data_samples=None,
350+
in_head_model_tuple=None, in_trunk_model_tuple=None,
351+
out_trunk_model_permissions=None, tag=''):
352+
data_samples = data_samples or []
353+
354+
if in_head_model_tuple and in_trunk_model_tuple:
355+
assert isinstance(in_head_model_tuple, ComputePlanCompositeTraintupleSpec)
356+
assert isinstance(
357+
in_trunk_model_tuple,
358+
(ComputePlanCompositeTraintupleSpec, ComputePlanAggregatetupleSpec)
359+
)
360+
361+
spec = ComputePlanCompositeTraintupleSpec(
362+
composite_traintuple_id=random_uuid(),
363+
algo_key=composite_algo.key,
364+
data_manager_key=dataset.key if dataset else None,
365+
train_data_sample_keys=_get_keys(data_samples),
366+
in_head_model_id=in_head_model_tuple.id if in_head_model_tuple else None,
367+
in_trunk_model_id=in_trunk_model_tuple.id if in_trunk_model_tuple else None,
368+
out_trunk_model_permissions=out_trunk_model_permissions or DEFAULT_OUT_MODEL_PERMISSIONS,
369+
tag=tag,
370+
)
371+
self.composite_traintuples.append(spec)
372+
return spec
373+
296374
def add_testtuple(self, traintuple_spec, tag=None):
297375
spec = ComputePlanTesttupleSpec(
298-
traintuple_id=traintuple_spec.traintuple_id,
376+
traintuple_id=traintuple_spec.id,
299377
tag=tag or '',
300378
)
301379
self.testtuples.append(spec)
@@ -486,8 +564,6 @@ def create_composite_traintuple(self, algo=None, objective=None, dataset=None,
486564
permissions=None):
487565
data_samples = data_samples or []
488566

489-
kwargs = {}
490-
491567
if head_traintuple and trunk_traintuple:
492568
assert isinstance(head_traintuple, assets.CompositeTraintuple)
493569
assert isinstance(
@@ -511,7 +587,6 @@ def create_composite_traintuple(self, algo=None, objective=None, dataset=None,
511587
compute_plan_id=compute_plan_id,
512588
rank=rank,
513589
out_trunk_model_permissions=permissions or DEFAULT_PERMISSIONS,
514-
**kwargs,
515590
)
516591

517592
def create_testtuple(self, traintuple=None, tag=None):
@@ -520,10 +595,11 @@ def create_testtuple(self, traintuple=None, tag=None):
520595
tag=tag,
521596
)
522597

523-
def create_compute_plan(self, algo=None, objective=None):
598+
def create_compute_plan(self, objective=None):
524599
return ComputePlanSpec(
525-
algo_key=algo.key if algo else None,
526600
objective_key=objective.key if objective else None,
527601
traintuples=[],
602+
composite_traintuples=[],
603+
aggregatetuples=[],
528604
testtuples=[],
529605
)

tests/test_execution_compute_plan.py

Lines changed: 138 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,24 +22,27 @@ def test_compute_plan(global_execution_env):
2222
algo_2 = session_2.add_algo(spec)
2323

2424
# create compute plan
25-
cp_spec = factory.create_compute_plan(algo=algo_2, objective=objective_1)
25+
cp_spec = factory.create_compute_plan(objective=objective_1)
2626

2727
# TODO add a testtuple in the compute plan
2828

2929
traintuple_spec_1 = cp_spec.add_traintuple(
30+
algo=algo_2,
3031
dataset=dataset_1,
3132
data_samples=dataset_1.train_data_sample_keys,
3233
)
3334

3435
traintuple_spec_2 = cp_spec.add_traintuple(
36+
algo=algo_2,
3537
dataset=dataset_2,
3638
data_samples=dataset_2.train_data_sample_keys,
3739
)
3840

39-
_ = cp_spec.add_traintuple(
41+
cp_spec.add_traintuple(
42+
algo=algo_2,
4043
dataset=dataset_1,
4144
data_samples=dataset_1.train_data_sample_keys,
42-
traintuple_specs=[traintuple_spec_1, traintuple_spec_2],
45+
in_models_tuples=[traintuple_spec_1, traintuple_spec_2],
4346
)
4447

4548
# submit compute plan and wait for it to complete
@@ -83,25 +86,28 @@ def test_compute_plan_single_session_success(global_execution_env):
8386
spec = factory.create_algo()
8487
algo = session.add_algo(spec)
8588

86-
cp_spec = factory.create_compute_plan(algo=algo, objective=objective)
89+
cp_spec = factory.create_compute_plan(objective=objective)
8790

8891
traintuple_spec_1 = cp_spec.add_traintuple(
92+
algo=algo,
8993
dataset=dataset,
9094
data_samples=[data_sample_1]
9195
)
9296
cp_spec.add_testtuple(traintuple_spec_1)
9397

9498
traintuple_spec_2 = cp_spec.add_traintuple(
99+
algo=algo,
95100
dataset=dataset,
96101
data_samples=[data_sample_2],
97-
traintuple_specs=[traintuple_spec_1]
102+
in_models_tuples=[traintuple_spec_1]
98103
)
99104
cp_spec.add_testtuple(traintuple_spec_2)
100105

101106
traintuple_spec_3 = cp_spec.add_traintuple(
107+
algo=algo,
102108
dataset=dataset,
103109
data_samples=[data_sample_3],
104-
traintuple_specs=[traintuple_spec_2]
110+
in_models_tuples=[traintuple_spec_2]
105111
)
106112
cp_spec.add_testtuple(traintuple_spec_3)
107113

@@ -150,25 +156,28 @@ def test_compute_plan_single_session_failure(global_execution_env):
150156
spec = factory.create_algo(py_script=sbt.factory.INVALID_ALGO_SCRIPT)
151157
algo = session.add_algo(spec)
152158

153-
cp_spec = factory.create_compute_plan(algo=algo, objective=objective)
159+
cp_spec = factory.create_compute_plan(objective=objective)
154160

155161
traintuple_spec_1 = cp_spec.add_traintuple(
162+
algo=algo,
156163
dataset=dataset,
157164
data_samples=[data_sample_1]
158165
)
159166
cp_spec.add_testtuple(traintuple_spec_1)
160167

161168
traintuple_spec_2 = cp_spec.add_traintuple(
169+
algo=algo,
162170
dataset=dataset,
163171
data_samples=[data_sample_2],
164-
traintuple_specs=[traintuple_spec_1]
172+
in_models_tuples=[traintuple_spec_1]
165173
)
166174
cp_spec.add_testtuple(traintuple_spec_2)
167175

168176
traintuple_spec_3 = cp_spec.add_traintuple(
177+
algo=algo,
169178
dataset=dataset,
170179
data_samples=[data_sample_3],
171-
traintuple_specs=[traintuple_spec_2]
180+
in_models_tuples=[traintuple_spec_2]
172181
)
173182
cp_spec.add_testtuple(traintuple_spec_3)
174183

@@ -193,3 +202,123 @@ def test_compute_plan_single_session_failure(global_execution_env):
193202
assert cp.compute_plan_id == compute_plan.compute_plan_id
194203
assert set(cp.traintuple_keys) == set(compute_plan.traintuples)
195204
assert set(cp.testtuple_keys) == set(compute_plan.testtuples)
205+
206+
207+
def test_compute_plan_aggregate_composite_traintuples(factory, session_1, session_2):
208+
"""
209+
Compute plan version of the `test_aggregate_composite_traintuples` method from `test_execution.py`
210+
"""
211+
aggregate_worker = session_1.node_id
212+
sessions = [session_1, session_2]
213+
number_of_rounds = 2
214+
215+
# register objectives, datasets, and data samples
216+
datasets = []
217+
for s in sessions:
218+
# register one dataset per node
219+
spec = factory.create_dataset()
220+
dataset = s.add_dataset(spec)
221+
datasets.append(dataset)
222+
223+
# register one data sample per dataset per round of aggregation
224+
for _ in range(number_of_rounds):
225+
spec = factory.create_data_sample(test_only=False, datasets=[dataset])
226+
s.add_data_sample(spec)
227+
# reload datasets (to ensure they are properly linked with the created data samples)
228+
datasets = [
229+
sessions[i].get_dataset(d.key)
230+
for i, d in enumerate(list(datasets))
231+
]
232+
# register test data on first node
233+
spec = factory.create_data_sample(test_only=True, datasets=[datasets[0]])
234+
test_data_sample = sessions[0].add_data_sample(spec)
235+
# register objective on first node
236+
spec = factory.create_objective(
237+
dataset=datasets[0],
238+
data_samples=[test_data_sample],
239+
)
240+
objective = sessions[0].add_objective(spec)
241+
242+
# register algos on first node
243+
spec = factory.create_composite_algo()
244+
composite_algo = sessions[0].add_composite_algo(spec)
245+
spec = factory.create_aggregate_algo()
246+
aggregate_algo = sessions[0].add_aggregate_algo(spec)
247+
248+
# launch execution
249+
previous_aggregatetuple_spec = None
250+
previous_composite_traintuple_specs = []
251+
252+
cp_spec = factory.create_compute_plan(objective=objective)
253+
254+
for round_ in range(number_of_rounds):
255+
# create composite traintuple on each node
256+
composite_traintuple_specs = []
257+
for index, dataset in enumerate(datasets):
258+
kwargs = {}
259+
if previous_aggregatetuple_spec:
260+
kwargs = {
261+
'in_head_model_tuple': previous_composite_traintuple_specs[index],
262+
'in_trunk_model_tuple': previous_aggregatetuple_spec,
263+
}
264+
spec = cp_spec.add_composite_traintuple(
265+
composite_algo=composite_algo,
266+
dataset=dataset,
267+
data_samples=[dataset.train_data_sample_keys[0 + round_]],
268+
**kwargs,
269+
)
270+
composite_traintuple_specs.append(spec)
271+
272+
# create aggregate on its node
273+
spec = cp_spec.add_aggregatetuple(
274+
aggregate_algo=aggregate_algo,
275+
worker=aggregate_worker,
276+
in_models_tuples=composite_traintuple_specs,
277+
)
278+
279+
# save state of round
280+
previous_aggregatetuple_spec = spec
281+
previous_composite_traintuple_specs = composite_traintuple_specs
282+
283+
# last round: create associated testtuple
284+
for composite_traintuple_spec in previous_composite_traintuple_specs:
285+
cp_spec.add_testtuple(
286+
traintuple_spec=composite_traintuple_spec,
287+
)
288+
289+
session_1.add_compute_plan(cp_spec).future().wait()
290+
291+
292+
def test_compute_plan_circular_dependency_failure(factory, session):
293+
spec = factory.create_dataset()
294+
dataset = session.add_dataset(spec)
295+
296+
spec = factory.create_algo()
297+
algo = session.add_algo(spec)
298+
299+
spec = factory.create_data_sample(test_only=False, datasets=[dataset])
300+
data_sample = session.add_data_sample(spec)
301+
302+
spec = factory.create_objective(dataset=dataset)
303+
objective = session.add_objective(spec)
304+
305+
cp_spec = factory.create_compute_plan(objective=objective)
306+
307+
traintuple_spec_1 = cp_spec.add_traintuple(
308+
dataset=dataset,
309+
algo=algo,
310+
data_samples=[data_sample]
311+
)
312+
313+
traintuple_spec_2 = cp_spec.add_traintuple(
314+
dataset=dataset,
315+
algo=algo,
316+
data_samples=[data_sample]
317+
)
318+
319+
traintuple_spec_1.in_models_ids.append(traintuple_spec_2.id)
320+
traintuple_spec_2.in_models_ids.append(traintuple_spec_1.id)
321+
322+
# TODO make sur the creation is rejected
323+
cp = session.add_compute_plan(cp_spec)
324+
assert False

0 commit comments

Comments
 (0)