@@ -50,6 +50,50 @@ def get(self):
50
50
return self ._asset
51
51
52
52
53
+ class ComputePlanFuture (Future ):
54
+ _keys_properties = {
55
+ 'ComputePlan' : {
56
+ 'traintuple_keys' : 'traintuples' ,
57
+ 'composite_traintuple_keys' : 'composite_traintuples' ,
58
+ 'aggregatetuple_keys' : 'aggregatetuples' ,
59
+ 'testtuple_keys' : 'testtuples' ,
60
+ },
61
+ 'ComputePlanCreated' : {
62
+ 'traintuple_keys' : 'traintuple_keys' ,
63
+ 'composite_traintuple_keys' : 'composite_traintuple_keys' ,
64
+ 'aggregatetuple_keys' : 'aggregatetuple_keys' ,
65
+ 'testtuple_keys' : 'testtuple_keys' ,
66
+
67
+ },
68
+ }
69
+
70
+ def __init__ (self , asset , session ):
71
+ self ._asset = asset
72
+ self ._getter = session .get_compute_plan
73
+ for k , v in enumerate (self ._keys_properties [asset .__class__ .name ]):
74
+ setattr (self , f'_{ k } ' , getattr (asset , v ))
75
+ self ._get_traintuple = session .get_traintuple
76
+ self ._get_composite_traintuple = session .get_composite_traintuple
77
+ self ._get_aggregatetuple = session .get_aggregatetuple
78
+ self ._get_testtuple = session .get_testtuple
79
+
80
+ def wait (self , timeout = FUTURE_TIMEOUT , raises = True ):
81
+ """wait until all tuples are completed (done or failed)."""
82
+ for key in self ._traintuple_keys :
83
+ self ._get_traintuple (key ).future ().wait (timeout , raises )
84
+ for key in self ._composite_traintuple_keys :
85
+ self ._get_composite_traintuple (key ).future ().wait (timeout , raises )
86
+ for key in self ._aggregatetuple_keys :
87
+ self ._get_aggregatetuple (key ).future ().wait (timeout , raises )
88
+ for key in self ._testtuple_keys :
89
+ self ._get_testtuple (key ).future ().wait (timeout , raises )
90
+
91
+ return self .get ()
92
+
93
+ def get (self ):
94
+ return self ._getter (self ._asset .key )
95
+
96
+
53
97
class _FutureMixin (abc .ABC ):
54
98
def attach (self , session ):
55
99
"""Attach session to asset."""
@@ -60,7 +104,11 @@ def future(self):
60
104
"""Returns future from asset."""
61
105
assert hasattr (self , 'status' )
62
106
assert hasattr (self , 'key' )
63
- return Future (self , self ._session )
107
+
108
+ try :
109
+ return self .Meta .FutureCls (self , self ._session )
110
+ except AttributeError :
111
+ return Future (self , self ._session )
64
112
65
113
66
114
def _convert (name ):
@@ -321,24 +369,48 @@ class Meta:
321
369
322
370
323
371
@dataclasses .dataclass
324
- class ComputePlanCreated (_Asset ):
372
+ class ComputePlanCreated (_Asset , _FutureMixin ):
325
373
compute_plan_id : str
326
374
traintuple_keys : typing .List [str ]
375
+ composite_traintuple_keys : typing .List [str ]
376
+ aggregatetuple_keys : typing .List [str ]
327
377
testtuple_keys : typing .List [str ]
328
378
379
+ class Meta :
380
+ FutureCls = ComputePlanFuture
381
+
329
382
330
383
@dataclasses .dataclass
331
384
class ComputePlan (_Asset ):
332
385
compute_plan_id : str
333
386
algo_key : str
334
387
objective_key : str
335
388
traintuples : typing .List [str ]
389
+ composite_traintuples : typing .List [str ]
390
+ aggregatetuples : typing .List [str ]
336
391
testtuples : typing .List [str ]
337
392
393
+ class Meta :
394
+ FutureCls = ComputePlanFuture
395
+
338
396
def __post_init__ (self ):
339
397
if self .testtuples is None :
340
398
self .testtuples = []
341
399
400
+ def list_traintuples (self , session ):
401
+ return session .list_traintuples (filters = [f'traintuple:computePlanId:{ self .compute_plan_id } ' ])
402
+
403
+ def list_composite_traintuples (self , session ):
404
+ return session .list_composite_traintuples (
405
+ filters = [f'composite_traintuple:computePlanId:{ self .compute_plan_id } ' ]
406
+ )
407
+
408
+ def list_aggregatetuples (self , session ):
409
+ return session .list_aggregatetuples (filters = [f'aggregatetuple:computePlanId:{ self .compute_plan_id } ' ])
410
+
411
+ def list_testtuples (self , session ):
412
+ return session .list_testtuples (filters = [f'testtuple:computePlanId:{ self .compute_plan_id } ' ])
413
+
342
414
343
415
@dataclasses .dataclass (frozen = True )
344
416
class Node (_Asset ):
0 commit comments