Skip to content

Generic compute plan #21

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Dec 5, 2019
93 changes: 79 additions & 14 deletions substratest/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,17 @@
FUTURE_TIMEOUT = 120 # seconds


class Future:
class BaseFuture(abc.ABC):
    """Abstract interface implemented by every future in this module."""

    @abc.abstractmethod
    def wait(self, timeout=FUTURE_TIMEOUT, raises=True):
        """Wait for the tracked asset; concrete futures define the details.

        Args:
            timeout: maximum wait, defaults to ``FUTURE_TIMEOUT``.
            raises: presumably toggles raising on failure -- confirm against
                the concrete implementations.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get(self):
        """Return the tracked asset; concrete futures define how."""
        raise NotImplementedError


class Future(BaseFuture):
"""Future asset."""
# mapper from asset class name to client getter method
_methods = {
Expand Down Expand Up @@ -50,17 +60,57 @@ def get(self):
return self._asset


class _FutureMixin(abc.ABC):
class ComputePlanFuture(BaseFuture):
    """Future that tracks all the tuples belonging to a compute plan."""

    def __init__(self, compute_plan, session):
        """Store the compute plan asset and the session used to query it."""
        self._compute_plan = compute_plan
        self._session = session

    def wait(self, timeout=FUTURE_TIMEOUT, raises=False):
        """Wait until all tuples are completed (done or failed).

        Args:
            timeout: forwarded unchanged to the future of every tuple, so it
                bounds each individual wait rather than the whole plan.
            raises: forwarded to the future of every tuple. Accepted so that
                the signature is compatible with ``BaseFuture.wait``; it
                defaults to ``False`` because this method has always waited
                on its tuples with ``raises=False``.

        Returns:
            The compute plan freshly fetched from the session (see ``get``).
        """
        plan = self._compute_plan
        session = self._session

        ranked = (
            plan.list_traintuple(session)
            + plan.list_composite_traintuple(session)
            + plan.list_aggregatetuple(session)
        )
        ranked = sorted(ranked, key=lambda t: t.rank)
        # testtuples do not have a rank attribute, so they are appended
        # after the rank-ordered tuples instead of being sorted with them
        tuples = ranked + plan.list_testtuple(session)

        for tuple_ in tuples:
            tuple_.future().wait(timeout, raises=raises)

        return self.get()

    def get(self):
        """Fetch the compute plan from the session by its compute plan id."""
        return self._session.get_compute_plan(self._compute_plan.compute_plan_id)


class _BaseFutureMixin(abc.ABC):
    """Mixin that lets an asset hold a session and build its future."""

    # future class instantiated by ``future()``; set by concrete mixins
    _future_cls = None

    def attach(self, session):
        """Attach session to asset and return the asset itself."""
        self._session = session
        return self

    def future(self):
        """Returns future from asset, built with the attached session."""
        future_factory = self._future_cls
        return future_factory(self, self._session)


class _FutureMixin(_BaseFutureMixin):
_future_cls = Future

def attach(self, session):
self._session = session
return self

def future(self):
assert hasattr(self, 'status')
assert hasattr(self, 'key')
return Future(self, self._session)
return super().future()


# Future mixin for compute plans: only selects ``ComputePlanFuture`` as the
# class instantiated by the ``future()`` method of ``_BaseFutureMixin``.
class _ComputePlanFutureMixin(_BaseFutureMixin):
_future_cls = ComputePlanFuture


def _convert(name):
Expand Down Expand Up @@ -321,23 +371,38 @@ class Meta:


@dataclasses.dataclass
class ComputePlanCreated(_Asset):
class ComputePlan(_Asset, _ComputePlanFutureMixin):
compute_plan_id: str
objective_key: str
traintuple_keys: typing.List[str]
composite_traintuple_keys: typing.List[str]
aggregatetuple_keys: typing.List[str]
testtuple_keys: typing.List[str]

def __post_init__(self):
if self.composite_traintuple_keys is None:
self.composite_traintuple_keys = []

@dataclasses.dataclass
class ComputePlan(_Asset):
compute_plan_id: str
algo_key: str
objective_key: str
traintuples: typing.List[str]
testtuples: typing.List[str]
if self.aggregatetuple_keys is None:
self.aggregatetuple_keys = []

def __post_init__(self):
if self.testtuples is None:
self.testtuples = []
if self.testtuple_keys is None:
self.testtuple_keys = []

def list_traintuple(self, session):
    """Return what ``session.list_traintuple`` yields for this compute plan id."""
    plan_filter = f'traintuple:computePlanId:{self.compute_plan_id}'
    return session.list_traintuple(filters=[plan_filter])

def list_composite_traintuple(self, session):
    """Return what ``session.list_composite_traintuple`` yields for this plan id."""
    plan_filter = f'composite_traintuple:computePlanId:{self.compute_plan_id}'
    return session.list_composite_traintuple(filters=[plan_filter])

def list_aggregatetuple(self, session):
    """Return what ``session.list_aggregatetuple`` yields for this compute plan id."""
    plan_filter = f'aggregatetuple:computePlanId:{self.compute_plan_id}'
    return session.list_aggregatetuple(filters=[plan_filter])

def list_testtuple(self, session):
    """Return the testtuples whose keys are recorded on this compute plan.

    One ``testtuple:key:<key>`` filter is built per entry of
    ``self.testtuple_keys`` and passed to ``session.list_testtuple``.

    Returns:
        The result of ``session.list_testtuple``, or an empty list when the
        compute plan has no testtuple keys.
    """
    if not self.testtuple_keys:
        # Without keys the filter list would be empty and the query would not
        # be restricted to this compute plan (presumably returning every
        # testtuple known to the session), so answer directly.
        return []

    filters = [f'testtuple:key:{k}' for k in self.testtuple_keys]
    return session.list_testtuple(filters=filters)


@dataclasses.dataclass(frozen=True)
Expand Down
12 changes: 6 additions & 6 deletions substratest/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def add_testtuple(self, spec):

def add_compute_plan(self, spec):
res = self._client.add_compute_plan(spec.to_dict())
compute_plan = assets.ComputePlanCreated.load(res)
compute_plan = assets.ComputePlan.load(res).attach(self)
self.state.compute_plans.append(compute_plan)
return compute_plan

Expand All @@ -156,7 +156,7 @@ def list_compute_plan(self, *args, **kwargs):

def get_compute_plan(self, *args, **kwargs):
res = self._client.get_compute_plan(*args, **kwargs)
compute_plan = assets.ComputePlan.load(res)
compute_plan = assets.ComputePlan.load(res).attach(self)
self.state.update_compute_plan(compute_plan)
return compute_plan

Expand Down Expand Up @@ -214,7 +214,7 @@ def get_traintuple(self, *args, **kwargs):

def list_traintuple(self, *args, **kwargs):
res = self._client.list_traintuple(*args, **kwargs)
return [assets.Traintuple.load(x) for x in res]
return [assets.Traintuple.load(x).attach(self) for x in res]

def get_aggregatetuple(self, *args, **kwargs):
res = self._client.get_aggregatetuple(*args, **kwargs)
Expand All @@ -224,7 +224,7 @@ def get_aggregatetuple(self, *args, **kwargs):

def list_aggregatetuple(self, *args, **kwargs):
res = self._client.list_aggregatetuple(*args, **kwargs)
return [assets.Aggregatetuple.load(x) for x in res]
return [assets.Aggregatetuple.load(x).attach(self) for x in res]

def get_composite_traintuple(self, *args, **kwargs):
res = self._client.get_composite_traintuple(*args, **kwargs)
Expand All @@ -234,7 +234,7 @@ def get_composite_traintuple(self, *args, **kwargs):

def list_composite_traintuple(self, *args, **kwargs):
res = self._client.list_composite_traintuple(*args, **kwargs)
return [assets.CompositeTraintuple.load(x) for x in res]
return [assets.CompositeTraintuple.load(x).attach(self) for x in res]

def get_testtuple(self, *args, **kwargs):
res = self._client.get_testtuple(*args, **kwargs)
Expand All @@ -244,7 +244,7 @@ def get_testtuple(self, *args, **kwargs):

def list_testtuple(self, *args, **kwargs):
res = self._client.list_testtuple(*args, **kwargs)
return [assets.Testtuple.load(x) for x in res]
return [assets.Testtuple.load(x).attach(self) for x in res]

def list_node(self, *args, **kwargs):
res = self._client.list_node(*args, **kwargs)
Expand Down
98 changes: 87 additions & 11 deletions substratest/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ class Permissions:


DEFAULT_PERMISSIONS = Permissions(public=True, authorized_ids=[])
DEFAULT_OUT_MODEL_PERMISSIONS = Permissions(public=False, authorized_ids=[])


@dataclasses.dataclass
Expand Down Expand Up @@ -243,12 +244,46 @@ class TesttupleSpec(_Spec):

@dataclasses.dataclass
class ComputePlanTraintupleSpec:
    """Traintuple entry of a compute plan specification."""

    algo_key: str
    data_manager_key: str
    # filled from ``_get_keys(data_samples)`` by ``ComputePlanSpec.add_traintuple``:
    # a collection of data sample keys, not a single string
    train_data_sample_keys: typing.List[str]
    traintuple_id: str
    # ids of the compute plan tuple specs used as input models
    in_models_ids: typing.List[str]
    tag: str

    @property
    def id(self):
        """Generic alias of ``traintuple_id`` shared with the other tuple specs."""
        return self.traintuple_id


@dataclasses.dataclass
class ComputePlanAggregatetupleSpec(_Spec):
    """Aggregatetuple entry of a compute plan specification."""

    aggregatetuple_id: str
    algo_key: str
    worker: str
    # ids of the compute plan tuple specs whose models are aggregated
    in_models_ids: typing.List[str]
    tag: str

    @property
    def id(self):
        """Generic alias of ``aggregatetuple_id`` shared with the other tuple specs."""
        return self.aggregatetuple_id


@dataclasses.dataclass
class ComputePlanCompositeTraintupleSpec(_Spec):
    """Composite traintuple entry of a compute plan specification.

    The annotations mirror what ``ComputePlanSpec.add_composite_traintuple``
    stores: the data manager key and both input model ids are ``None`` when
    the corresponding argument is not provided.
    """

    composite_traintuple_id: str
    algo_key: str
    data_manager_key: typing.Optional[str]
    # filled from ``_get_keys(data_samples)``: a collection of keys
    train_data_sample_keys: typing.List[str]
    in_head_model_id: typing.Optional[str]
    in_trunk_model_id: typing.Optional[str]
    tag: str
    # NOTE(review): the factory falls back to DEFAULT_OUT_MODEL_PERMISSIONS,
    # which is a ``Permissions`` instance rather than a dict -- confirm the
    # intended type before tightening this annotation.
    out_trunk_model_permissions: typing.Dict

    @property
    def id(self):
        """Generic alias of ``composite_traintuple_id`` shared with the other tuple specs."""
        return self.composite_traintuple_id


@dataclasses.dataclass
class ComputePlanTesttupleSpec:
Expand Down Expand Up @@ -276,26 +311,69 @@ def _get_keys(obj, field='key'):

@dataclasses.dataclass
class ComputePlanSpec(_Spec):
algo_key: str
objective_key: str
traintuples: typing.List[ComputePlanTraintupleSpec]
composite_traintuples: typing.List[ComputePlanCompositeTraintupleSpec]
aggregatetuples: typing.List[ComputePlanAggregatetupleSpec]
testtuples: typing.List[ComputePlanTesttupleSpec]

def add_traintuple(self, dataset, data_samples, traintuple_specs=None, tag=None):
traintuple_specs = traintuple_specs or []
def add_traintuple(self, algo, dataset, data_samples, in_models=None, tag=''):
in_models = in_models or []
spec = ComputePlanTraintupleSpec(
algo_key=algo.key,
traintuple_id=random_uuid(),
data_manager_key=dataset.key,
train_data_sample_keys=_get_keys(data_samples),
in_models_ids=[t.traintuple_id for t in traintuple_specs],
tag=tag or '',
in_models_ids=[t.id for t in in_models],
tag=tag,
)
self.traintuples.append(spec)
return spec

def add_aggregatetuple(self, aggregate_algo, worker, in_models=None, tag=''):
    """Append an aggregatetuple spec to the compute plan and return it.

    Every entry of ``in_models`` must be a traintuple or composite traintuple
    spec of the compute plan; their ids become the input models of the new
    aggregatetuple.
    """
    parents = in_models or []
    accepted_specs = (ComputePlanTraintupleSpec, ComputePlanCompositeTraintupleSpec)

    for parent in parents:
        assert isinstance(parent, accepted_specs)

    new_spec = ComputePlanAggregatetupleSpec(
        aggregatetuple_id=random_uuid(),
        algo_key=aggregate_algo.key,
        worker=worker,
        in_models_ids=[parent.id for parent in parents],
        tag=tag,
    )
    self.aggregatetuples.append(new_spec)
    return new_spec

def add_composite_traintuple(self, composite_algo, dataset=None, data_samples=None,
                             in_head_model=None, in_trunk_model=None,
                             out_trunk_model_permissions=None, tag=''):
    """Append a composite traintuple spec to the compute plan and return it.

    The input model types are only checked when both a head and a trunk model
    are given: the head must be a composite traintuple spec, the trunk a
    composite traintuple or aggregatetuple spec. Missing optional arguments
    are stored as ``None`` (or as DEFAULT_OUT_MODEL_PERMISSIONS for the
    permissions).
    """
    data_samples = data_samples or []

    if in_head_model and in_trunk_model:
        trunk_specs = (ComputePlanCompositeTraintupleSpec, ComputePlanAggregatetupleSpec)
        assert isinstance(in_head_model, ComputePlanCompositeTraintupleSpec)
        assert isinstance(in_trunk_model, trunk_specs)

    # Values are computed in the same order as the constructor arguments so
    # that any error surfaces at the same point as before.
    tuple_id = random_uuid()
    algo_key = composite_algo.key

    if dataset:
        manager_key = dataset.key
    else:
        manager_key = None

    sample_keys = _get_keys(data_samples)

    if in_head_model:
        head_id = in_head_model.id
    else:
        head_id = None

    if in_trunk_model:
        trunk_id = in_trunk_model.id
    else:
        trunk_id = None

    permissions = out_trunk_model_permissions or DEFAULT_OUT_MODEL_PERMISSIONS

    new_spec = ComputePlanCompositeTraintupleSpec(
        composite_traintuple_id=tuple_id,
        algo_key=algo_key,
        data_manager_key=manager_key,
        train_data_sample_keys=sample_keys,
        in_head_model_id=head_id,
        in_trunk_model_id=trunk_id,
        out_trunk_model_permissions=permissions,
        tag=tag,
    )
    self.composite_traintuples.append(new_spec)
    return new_spec

def add_testtuple(self, traintuple_spec, tag=None):
spec = ComputePlanTesttupleSpec(
traintuple_id=traintuple_spec.traintuple_id,
traintuple_id=traintuple_spec.id,
tag=tag or '',
)
self.testtuples.append(spec)
Expand Down Expand Up @@ -486,8 +564,6 @@ def create_composite_traintuple(self, algo=None, objective=None, dataset=None,
permissions=None):
data_samples = data_samples or []

kwargs = {}

if head_traintuple and trunk_traintuple:
assert isinstance(head_traintuple, assets.CompositeTraintuple)
assert isinstance(
Expand All @@ -511,7 +587,6 @@ def create_composite_traintuple(self, algo=None, objective=None, dataset=None,
compute_plan_id=compute_plan_id,
rank=rank,
out_trunk_model_permissions=permissions or DEFAULT_PERMISSIONS,
**kwargs,
)

def create_testtuple(self, traintuple=None, tag=None):
Expand All @@ -520,10 +595,11 @@ def create_testtuple(self, traintuple=None, tag=None):
tag=tag,
)

def create_compute_plan(self, algo=None, objective=None):
def create_compute_plan(self, objective=None):
return ComputePlanSpec(
algo_key=algo.key if algo else None,
objective_key=objective.key if objective else None,
traintuples=[],
composite_traintuples=[],
aggregatetuples=[],
testtuples=[],
)
Loading