@@ -22,24 +22,27 @@ def test_compute_plan(global_execution_env):
22
22
algo_2 = session_2 .add_algo (spec )
23
23
24
24
# create compute plan
25
- cp_spec = factory .create_compute_plan (algo = algo_2 , objective = objective_1 )
25
+ cp_spec = factory .create_compute_plan (objective = objective_1 )
26
26
27
27
# TODO add a testtuple in the compute plan
28
28
29
29
traintuple_spec_1 = cp_spec .add_traintuple (
30
+ algo = algo_2 ,
30
31
dataset = dataset_1 ,
31
32
data_samples = dataset_1 .train_data_sample_keys ,
32
33
)
33
34
34
35
traintuple_spec_2 = cp_spec .add_traintuple (
36
+ algo = algo_2 ,
35
37
dataset = dataset_2 ,
36
38
data_samples = dataset_2 .train_data_sample_keys ,
37
39
)
38
40
39
- _ = cp_spec .add_traintuple (
41
+ cp_spec .add_traintuple (
42
+ algo = algo_2 ,
40
43
dataset = dataset_1 ,
41
44
data_samples = dataset_1 .train_data_sample_keys ,
42
- traintuple_specs = [traintuple_spec_1 , traintuple_spec_2 ],
45
+ in_models_tuples = [traintuple_spec_1 , traintuple_spec_2 ],
43
46
)
44
47
45
48
# submit compute plan and wait for it to complete
@@ -83,25 +86,28 @@ def test_compute_plan_single_session_success(global_execution_env):
83
86
spec = factory .create_algo ()
84
87
algo = session .add_algo (spec )
85
88
86
- cp_spec = factory .create_compute_plan (algo = algo , objective = objective )
89
+ cp_spec = factory .create_compute_plan (objective = objective )
87
90
88
91
traintuple_spec_1 = cp_spec .add_traintuple (
92
+ algo = algo ,
89
93
dataset = dataset ,
90
94
data_samples = [data_sample_1 ]
91
95
)
92
96
cp_spec .add_testtuple (traintuple_spec_1 )
93
97
94
98
traintuple_spec_2 = cp_spec .add_traintuple (
99
+ algo = algo ,
95
100
dataset = dataset ,
96
101
data_samples = [data_sample_2 ],
97
- traintuple_specs = [traintuple_spec_1 ]
102
+ in_models_tuples = [traintuple_spec_1 ]
98
103
)
99
104
cp_spec .add_testtuple (traintuple_spec_2 )
100
105
101
106
traintuple_spec_3 = cp_spec .add_traintuple (
107
+ algo = algo ,
102
108
dataset = dataset ,
103
109
data_samples = [data_sample_3 ],
104
- traintuple_specs = [traintuple_spec_2 ]
110
+ in_models_tuples = [traintuple_spec_2 ]
105
111
)
106
112
cp_spec .add_testtuple (traintuple_spec_3 )
107
113
@@ -150,25 +156,28 @@ def test_compute_plan_single_session_failure(global_execution_env):
150
156
spec = factory .create_algo (py_script = sbt .factory .INVALID_ALGO_SCRIPT )
151
157
algo = session .add_algo (spec )
152
158
153
- cp_spec = factory .create_compute_plan (algo = algo , objective = objective )
159
+ cp_spec = factory .create_compute_plan (objective = objective )
154
160
155
161
traintuple_spec_1 = cp_spec .add_traintuple (
162
+ algo = algo ,
156
163
dataset = dataset ,
157
164
data_samples = [data_sample_1 ]
158
165
)
159
166
cp_spec .add_testtuple (traintuple_spec_1 )
160
167
161
168
traintuple_spec_2 = cp_spec .add_traintuple (
169
+ algo = algo ,
162
170
dataset = dataset ,
163
171
data_samples = [data_sample_2 ],
164
- traintuple_specs = [traintuple_spec_1 ]
172
+ in_models_tuples = [traintuple_spec_1 ]
165
173
)
166
174
cp_spec .add_testtuple (traintuple_spec_2 )
167
175
168
176
traintuple_spec_3 = cp_spec .add_traintuple (
177
+ algo = algo ,
169
178
dataset = dataset ,
170
179
data_samples = [data_sample_3 ],
171
- traintuple_specs = [traintuple_spec_2 ]
180
+ in_models_tuples = [traintuple_spec_2 ]
172
181
)
173
182
cp_spec .add_testtuple (traintuple_spec_3 )
174
183
@@ -193,3 +202,123 @@ def test_compute_plan_single_session_failure(global_execution_env):
193
202
assert cp .compute_plan_id == compute_plan .compute_plan_id
194
203
assert set (cp .traintuple_keys ) == set (compute_plan .traintuples )
195
204
assert set (cp .testtuple_keys ) == set (compute_plan .testtuples )
205
+
206
+
207
+ def test_compute_plan_aggregate_composite_traintuples (factory , session_1 , session_2 ):
208
+ """
209
+ Compute plan version of the `test_aggregate_composite_traintuples` method from `test_execution.py`
210
+ """
211
+ aggregate_worker = session_1 .node_id
212
+ sessions = [session_1 , session_2 ]
213
+ number_of_rounds = 2
214
+
215
+ # register objectives, datasets, and data samples
216
+ datasets = []
217
+ for s in sessions :
218
+ # register one dataset per node
219
+ spec = factory .create_dataset ()
220
+ dataset = s .add_dataset (spec )
221
+ datasets .append (dataset )
222
+
223
+ # register one data sample per dataset per round of aggregation
224
+ for _ in range (number_of_rounds ):
225
+ spec = factory .create_data_sample (test_only = False , datasets = [dataset ])
226
+ s .add_data_sample (spec )
227
+ # reload datasets (to ensure they are properly linked with the created data samples)
228
+ datasets = [
229
+ sessions [i ].get_dataset (d .key )
230
+ for i , d in enumerate (list (datasets ))
231
+ ]
232
+ # register test data on first node
233
+ spec = factory .create_data_sample (test_only = True , datasets = [datasets [0 ]])
234
+ test_data_sample = sessions [0 ].add_data_sample (spec )
235
+ # register objective on first node
236
+ spec = factory .create_objective (
237
+ dataset = datasets [0 ],
238
+ data_samples = [test_data_sample ],
239
+ )
240
+ objective = sessions [0 ].add_objective (spec )
241
+
242
+ # register algos on first node
243
+ spec = factory .create_composite_algo ()
244
+ composite_algo = sessions [0 ].add_composite_algo (spec )
245
+ spec = factory .create_aggregate_algo ()
246
+ aggregate_algo = sessions [0 ].add_aggregate_algo (spec )
247
+
248
+ # launch execution
249
+ previous_aggregatetuple_spec = None
250
+ previous_composite_traintuple_specs = []
251
+
252
+ cp_spec = factory .create_compute_plan (objective = objective )
253
+
254
+ for round_ in range (number_of_rounds ):
255
+ # create composite traintuple on each node
256
+ composite_traintuple_specs = []
257
+ for index , dataset in enumerate (datasets ):
258
+ kwargs = {}
259
+ if previous_aggregatetuple_spec :
260
+ kwargs = {
261
+ 'in_head_model_tuple' : previous_composite_traintuple_specs [index ],
262
+ 'in_trunk_model_tuple' : previous_aggregatetuple_spec ,
263
+ }
264
+ spec = cp_spec .add_composite_traintuple (
265
+ composite_algo = composite_algo ,
266
+ dataset = dataset ,
267
+ data_samples = [dataset .train_data_sample_keys [0 + round_ ]],
268
+ ** kwargs ,
269
+ )
270
+ composite_traintuple_specs .append (spec )
271
+
272
+ # create aggregate on its node
273
+ spec = cp_spec .add_aggregatetuple (
274
+ aggregate_algo = aggregate_algo ,
275
+ worker = aggregate_worker ,
276
+ in_models_tuples = composite_traintuple_specs ,
277
+ )
278
+
279
+ # save state of round
280
+ previous_aggregatetuple_spec = spec
281
+ previous_composite_traintuple_specs = composite_traintuple_specs
282
+
283
+ # last round: create associated testtuple
284
+ for composite_traintuple_spec in previous_composite_traintuple_specs :
285
+ cp_spec .add_testtuple (
286
+ traintuple_spec = composite_traintuple_spec ,
287
+ )
288
+
289
+ session_1 .add_compute_plan (cp_spec ).future ().wait ()
290
+
291
+
292
+ def test_compute_plan_circular_dependency_failure (factory , session ):
293
+ spec = factory .create_dataset ()
294
+ dataset = session .add_dataset (spec )
295
+
296
+ spec = factory .create_algo ()
297
+ algo = session .add_algo (spec )
298
+
299
+ spec = factory .create_data_sample (test_only = False , datasets = [dataset ])
300
+ data_sample = session .add_data_sample (spec )
301
+
302
+ spec = factory .create_objective (dataset = dataset )
303
+ objective = session .add_objective (spec )
304
+
305
+ cp_spec = factory .create_compute_plan (objective = objective )
306
+
307
+ traintuple_spec_1 = cp_spec .add_traintuple (
308
+ dataset = dataset ,
309
+ algo = algo ,
310
+ data_samples = [data_sample ]
311
+ )
312
+
313
+ traintuple_spec_2 = cp_spec .add_traintuple (
314
+ dataset = dataset ,
315
+ algo = algo ,
316
+ data_samples = [data_sample ]
317
+ )
318
+
319
+ traintuple_spec_1 .in_models_ids .append (traintuple_spec_2 .id )
320
+ traintuple_spec_2 .in_models_ids .append (traintuple_spec_1 .id )
321
+
322
+ # TODO make sur the creation is rejected
323
+ cp = session .add_compute_plan (cp_spec )
324
+ assert False
0 commit comments