
feature: add support for SageMaker workflow tuning step #2497

Merged
merged 21 commits on Jul 1, 2021
Changes from all commits
Commits (21)
868db55
add helper function to generate no-op (data ingestion only) recipe
jerrypeng7773 May 11, 2021
21bedbb
Merge branch 'aws:master' into master
jerrypeng7773 May 11, 2021
854dd10
separate flow generation by source input type + move generation helpe…
jerrypeng7773 May 11, 2021
8798b65
Merge branch 'aws:master' into master
jerrypeng7773 May 11, 2021
69ae4bd
create an internal helper function to generate output node
jerrypeng7773 May 12, 2021
a6a8449
Merge branch 'master' of github.com:jerrypeng7773/sagemaker-python-sdk
jerrypeng7773 May 12, 2021
2aa256e
Merge branch 'aws:master' into master
jerrypeng7773 May 18, 2021
06557a8
add ingestion test using dw processor via pipeline execution
jerrypeng7773 May 19, 2021
dcbfd13
Merge branch 'aws:master' into master
jerrypeng7773 May 19, 2021
fc6522e
verify the fg query df
jerrypeng7773 May 19, 2021
b6f9371
Merge branch 'master' into master
ahsan-z-khan May 19, 2021
86fa47d
fix tests
jerrypeng7773 May 19, 2021
05ccfa6
Merge branch 'master' into master
ahsan-z-khan May 20, 2021
0716e9f
Merge branch 'aws:master' into master
jerrypeng7773 Jun 14, 2021
7ca5af4
add tuning step support
jerrypeng7773 Jun 24, 2021
8cf18b8
fix docstyle check
jerrypeng7773 Jun 24, 2021
1f95b82
add helper function to get tuning step top performing model s3 uri
jerrypeng7773 Jun 29, 2021
1b9d66b
Merge branch 'aws:master' into master
jerrypeng7773 Jun 30, 2021
603b934
Merge branch 'aws:master' into master
jerrypeng7773 Jun 30, 2021
664f2a8
Merge branch 'master' of github.com:jerrypeng7773/sagemaker-python-sdk
jerrypeng7773 Jun 30, 2021
a8755ec
Merge branch 'master' into master
apogupta2018 Jul 1, 2021
43 changes: 40 additions & 3 deletions src/sagemaker/session.py
@@ -2033,6 +2033,45 @@ def create_tuning_job(
"Only one of training_config and training_config_list should be provided."
)

tune_request = self._get_tuning_request(
job_name=job_name,
tuning_config=tuning_config,
training_config=training_config,
training_config_list=training_config_list,
warm_start_config=warm_start_config,
tags=tags,
)

LOGGER.info("Creating hyperparameter tuning job with name: %s", job_name)
LOGGER.debug("tune request: %s", json.dumps(tune_request, indent=4))
self.sagemaker_client.create_hyper_parameter_tuning_job(**tune_request)

def _get_tuning_request(
self,
job_name,
tuning_config,
training_config=None,
training_config_list=None,
warm_start_config=None,
tags=None,
):
"""Construct CreateHyperParameterTuningJob request

Args:
job_name (str): Name of the tuning job being created.
tuning_config (dict): Configuration to launch the tuning job.
training_config (dict): Configuration to launch training jobs under the tuning job
using a single algorithm.
training_config_list (list[dict]): A list of configurations to launch training jobs
under the tuning job using one or multiple algorithms. Either training_config
or training_config_list should be provided, but not both.
warm_start_config (dict): Configuration defining the type of warm start and
other required configurations.
tags (list[dict]): List of tags for labeling the tuning job. For more, see
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
Returns:
dict: A dictionary for CreateHyperParameterTuningJob request
"""
tune_request = {
"HyperParameterTuningJobName": job_name,
"HyperParameterTuningJobConfig": self._map_tuning_config(**tuning_config),
@@ -2053,9 +2092,7 @@ def create_tuning_job(
if tags is not None:
tune_request["Tags"] = tags

LOGGER.info("Creating hyperparameter tuning job with name: %s", job_name)
LOGGER.debug("tune request: %s", json.dumps(tune_request, indent=4))
self.sagemaker_client.create_hyper_parameter_tuning_job(**tune_request)
return tune_request

def describe_tuning_job(self, job_name):
"""Calls DescribeHyperParameterTuningJob API for the given job name, returns the response.
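The session.py change splits request construction out of `create_tuning_job`: the new private helper `_get_tuning_request` builds and returns the CreateHyperParameterTuningJob dict, while `create_tuning_job` only logs it and calls the API. A minimal sketch (not part of this diff) of how a caller could use the helper, assuming `sagemaker_session`, `tuning_config`, and `training_config` are already built elsewhere:

request = sagemaker_session._get_tuning_request(
    job_name="my-tuning-job",
    tuning_config=tuning_config,        # e.g. strategy, objective, parameter ranges
    training_config=training_config,    # estimator image, resources, input channels
)
# `request` is a plain dict ("HyperParameterTuningJobName",
# "HyperParameterTuningJobConfig", "TrainingJobDefinition", ...).
# create_tuning_job() passes it to
# sagemaker_client.create_hyper_parameter_tuning_job(**request), while the new
# TuningStep reuses the same dict (minus the job name) as its step arguments.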
24 changes: 21 additions & 3 deletions src/sagemaker/tuner.py
@@ -346,7 +346,9 @@ def _prepare_static_hyperparameters_for_tuning(self, include_cls_metadata=False)
estimator_name: self._prepare_static_hyperparameters(
estimator,
self._hyperparameter_ranges_dict[estimator_name],
include_cls_metadata.get(estimator_name, False),
include_cls_metadata.get(estimator_name, False)
if isinstance(include_cls_metadata, dict)
else include_cls_metadata,
)
for (estimator_name, estimator) in self.estimator_dict.items()
}
@@ -1460,6 +1462,23 @@ def start_new(cls, tuner, inputs):
sagemaker.tuner._TuningJob: Constructed object that captures all
information about the started job.
"""
tuner_args = cls._get_tuner_args(tuner, inputs)
tuner.sagemaker_session.create_tuning_job(**tuner_args)

return cls(tuner.sagemaker_session, tuner._current_job_name)

@classmethod
def _get_tuner_args(cls, tuner, inputs):
"""Gets a dict of arguments for a new Amazon SageMaker tuning job from the tuner

Args:
tuner (:class:`~sagemaker.tuner.HyperparameterTuner`):
The ``HyperparameterTuner`` instance that started the job.
inputs: Information about the training data. Please refer to the
``fit()`` method of the associated estimator.
Returns:
dict: Arguments for the `sagemaker.session.Session.create_tuning_job` method.
"""
warm_start_config_req = None
if tuner.warm_start_config:
warm_start_config_req = tuner.warm_start_config.to_input_req()
@@ -1506,8 +1525,7 @@ def start_new(cls, tuner, inputs):
for estimator_name in sorted(tuner.estimator_dict.keys())
]

tuner.sagemaker_session.create_tuning_job(**tuner_args)
return cls(tuner.sagemaker_session, tuner._current_job_name)
return tuner_args

@staticmethod
def _prepare_training_config(
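In tuner.py the same pattern is applied to `_TuningJob.start_new`: argument assembly moves into the class method `_get_tuner_args`, which returns a plain dict without making any API call. A hedged sketch of the resulting call flow, assuming `tuner` is an existing `HyperparameterTuner` and `inputs` its training data:

from sagemaker.tuner import _TuningJob

tuner_args = _TuningJob._get_tuner_args(tuner, inputs)  # builds kwargs only, no API call
# start_new() now reduces to:
#   tuner.sagemaker_session.create_tuning_job(**tuner_args)
#   return _TuningJob(tuner.sagemaker_session, tuner._current_job_name)
# TuningStep.arguments (see steps.py below) calls _get_tuner_args directly so it
# can render the request without starting a tuning job.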
42 changes: 26 additions & 16 deletions src/sagemaker/workflow/properties.py
@@ -13,7 +13,7 @@
"""The properties definitions for workflow."""
from __future__ import absolute_import

from typing import Dict, Union
from typing import Dict, Union, List

import attr

@@ -40,27 +40,35 @@ def __new__(mcs, *args, **kwargs):
class Properties(metaclass=PropertiesMeta):
"""Properties for use in workflow expressions."""

def __init__(self, path: str, shape_name: str = None):
def __init__(
self,
path: str,
shape_name: str = None,
shape_names: List[str] = None,
):
"""Create a Properties instance representing the given shape.

Args:
path (str): The parent path of the Properties instance.
shape_name (str): The botocore sagemaker service model shape name.
shape_names (List[str]): A list of botocore sagemaker service model shape names.
"""
self._path = path
self._shape_name = shape_name

shape = Properties._shapes.get(self._shape_name, {})
shape_type = shape.get("type")
if shape_type in Properties._primitive_types:
self.__str__ = shape_name
elif shape_type == "structure":
members = shape["members"]
for key, info in members.items():
if Properties._shapes.get(info["shape"], {}).get("type") == "list":
self.__dict__[key] = PropertiesList(f"{path}.{key}", info["shape"])
else:
self.__dict__[key] = Properties(f"{path}.{key}", info["shape"])
shape_names = [] if shape_names is None else shape_names
self._shape_names = shape_names if shape_name is None else [shape_name] + shape_names

for name in self._shape_names:
shape = Properties._shapes.get(name, {})
shape_type = shape.get("type")
if shape_type in Properties._primitive_types:
self.__str__ = name
elif shape_type == "structure":
members = shape["members"]
for key, info in members.items():
if Properties._shapes.get(info["shape"], {}).get("type") == "list":
self.__dict__[key] = PropertiesList(f"{path}.{key}", info["shape"])
else:
self.__dict__[key] = Properties(f"{path}.{key}", info["shape"])

@property
def expr(self):
@@ -77,8 +85,10 @@ def __init__(self, path: str, shape_name: str = None):
Args:
path (str): The parent path of the PropertiesList instance.
shape_name (str): The botocore sagemaker service model shape name.
"""
super(PropertiesList, self).__init__(path, shape_name)
self.shape_name = shape_name
self._items: Dict[Union[int, str], Properties] = dict()

def __getitem__(self, item: Union[int, str]):
@@ -88,7 +98,7 @@ def __getitem__(self, item: Union[int, str]):
item (Union[int, str]): The index of the item in sequence.
"""
if item not in self._items.keys():
shape = Properties._shapes.get(self._shape_name)
shape = Properties._shapes.get(self.shape_name)
member = shape["member"]["shape"]
if isinstance(item, str):
property_item = Properties(f"{self._path}['{item}']", member)
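The properties.py change lets a `Properties` object be backed by several botocore shapes at once, which the tuning step needs because its properties span two API responses. A sketch of the resulting behavior, assuming the service model shapes are loaded as usual:

from sagemaker.workflow.properties import Properties

props = Properties(
    path="Steps.MyTuningStep",
    shape_names=[
        "DescribeHyperParameterTuningJobResponse",
        "ListTrainingJobsForHyperParameterTuningJobResponse",
    ],
)
# Members of both shapes hang off the same object, e.g.
# props.HyperParameterTuningJobStatus (Describe response) and
# props.TrainingJobSummaries[0].TrainingJobName (List response); each resolves
# to a pipeline expression such as
# {"Get": "Steps.MyTuningStep.TrainingJobSummaries[0].TrainingJobName"}.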
133 changes: 133 additions & 0 deletions src/sagemaker/workflow/steps.py
@@ -30,6 +30,7 @@
Processor,
)
from sagemaker.transformer import Transformer, _TransformJob
from sagemaker.tuner import HyperparameterTuner, _TuningJob
from sagemaker.workflow.entities import (
DefaultEnumMeta,
Entity,
@@ -39,6 +40,7 @@
PropertyFile,
Properties,
)
from sagemaker.workflow.functions import Join


class StepTypeEnum(Enum, metaclass=DefaultEnumMeta):
@@ -51,6 +53,7 @@ class StepTypeEnum(Enum, metaclass=DefaultEnumMeta):
TRAINING = "Training"
TRANSFORM = "Transform"
CALLBACK = "Callback"
TUNING = "Tuning"


@attr.s
@@ -92,6 +95,7 @@ def add_depends_on(self, step_names: List[str]):
"""Add step names to the current step depends on list"""
if not step_names:
return

if not self.depends_on:
self.depends_on = []
self.depends_on.extend(step_names)
@@ -429,3 +433,132 @@ def to_request(self) -> RequestType:
property_file.expr for property_file in self.property_files
]
return request_dict


class TuningStep(Step):
"""Tuning step for workflow."""

def __init__(
self,
name: str,
tuner: HyperparameterTuner,
inputs=None,
job_arguments: List[str] = None,
cache_config: CacheConfig = None,
depends_on: List[str] = None,
):
"""Construct a TuningStep, given a `HyperparameterTuner` instance.

In addition to the tuner instance, the other arguments are those that are supplied to
the `fit` method of the `sagemaker.tuner.HyperparameterTuner`.

Args:
name (str): The name of the tuning step.
tuner (HyperparameterTuner): A `sagemaker.tuner.HyperparameterTuner` instance.
inputs: Information about the training data. Please refer to the
``fit()`` method of the associated estimator, as this can take
any of the following forms:

* (str) - The S3 location where training data is saved.
* (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) -
If using multiple channels for training data, you can specify
a dict mapping channel names to strings or
:func:`~sagemaker.inputs.TrainingInput` objects.
* (sagemaker.inputs.TrainingInput) - Channel configuration for S3 data sources
that can provide additional information about the training dataset.
See :func:`sagemaker.inputs.TrainingInput` for full details.
* (sagemaker.session.FileSystemInput) - channel configuration for
a file system data source that can provide additional information as well as
the path to the training dataset.
* (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
Amazon :class:~`Record` objects serialized and stored in S3.
For use with an estimator for an Amazon algorithm.
* (sagemaker.amazon.amazon_estimator.FileSystemRecordSet) -
Amazon SageMaker channel configuration for a file system data source for
Amazon algorithms.
* (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
:class:~`sagemaker.amazon.amazon_estimator.RecordSet` objects,
where each instance is a different channel of training data.
* (list[sagemaker.amazon.amazon_estimator.FileSystemRecordSet]) - A list of
:class:~`sagemaker.amazon.amazon_estimator.FileSystemRecordSet` objects,
where each instance is a different channel of training data.
job_arguments (List[str]): A list of strings to be passed into the tuning job.
Defaults to `None`.
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TuningStep`
depends on.
"""
super(TuningStep, self).__init__(name, StepTypeEnum.TUNING, depends_on)
self.tuner = tuner
self.inputs = inputs
self.job_arguments = job_arguments
self._properties = Properties(
path=f"Steps.{name}",
shape_names=[
"DescribeHyperParameterTuningJobResponse",
"ListTrainingJobsForHyperParameterTuningJobResponse",
],
)
self.cache_config = cache_config

@property
def arguments(self) -> RequestType:
"""The arguments dict that is used to call `create_hyper_parameter_tuning_job`.

NOTE: The CreateHyperParameterTuningJob request is not quite the
args list that workflow needs.
The HyperParameterTuningJobName attribute cannot be included.
"""
if self.tuner.estimator is not None:
self.tuner.estimator._prepare_for_training()
else:
for _, estimator in self.tuner.estimator_dict.items():
estimator._prepare_for_training()

self.tuner._prepare_for_tuning()
tuner_args = _TuningJob._get_tuner_args(self.tuner, self.inputs)
request_dict = self.tuner.sagemaker_session._get_tuning_request(**tuner_args)
request_dict.pop("HyperParameterTuningJobName")

return request_dict

@property
def properties(self):
"""A Properties object representing

`DescribeHyperParameterTuningJobResponse` and
`ListTrainingJobsForHyperParameterTuningJobResponse` data model.
"""
return self._properties

def to_request(self) -> RequestType:
"""Updates the dictionary with cache configuration."""
request_dict = super().to_request()
if self.cache_config:
request_dict.update(self.cache_config.config)

return request_dict

def get_top_model_s3_uri(self, top_k: int, s3_bucket: str, prefix: str = ""):
"""Get the model artifact s3 uri from the top performing training jobs.

Args:
top_k (int): the index of the top performing training job
tuning step stores up to 50 top performing training jobs, hence
a valid top_k value is from 0 to 49. The best training job
model is at index 0
s3_bucket (str): the s3 bucket to store the training job output artifact
prefix (str): the s3 key prefix to store the training job output artifact
"""
values = ["s3:/", s3_bucket]
if prefix != "" and prefix is not None:
values.append(prefix)

return Join(
on="/",
values=values
+ [
self.properties.TrainingJobSummaries[top_k].TrainingJobName,
"output/model.tar.gz",
],
)
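Putting the pieces together, a usage sketch for the new `TuningStep` and its `get_top_model_s3_uri` helper. `estimator` and `bucket` are assumed to exist already, and the objective metric name and regex are illustrative rather than taken from this PR's tests:

from sagemaker.inputs import TrainingInput
from sagemaker.tuner import ContinuousParameter, HyperparameterTuner
from sagemaker.workflow.steps import TuningStep

tuner = HyperparameterTuner(
    estimator=estimator,  # assumed: any SageMaker estimator, e.g. PyTorch
    objective_metric_name="average test loss",
    objective_type="Minimize",
    hyperparameter_ranges={"lr": ContinuousParameter(0.001, 0.1)},
    metric_definitions=[
        {"Name": "average test loss",
         "Regex": "Test set: Average loss: ([0-9\\.]+)"}
    ],
    max_jobs=4,
    max_parallel_jobs=2,
)

step_tune = TuningStep(
    name="MyTuningStep",
    tuner=tuner,
    inputs={"training": TrainingInput(s3_data=f"s3://{bucket}/train")},
)

# S3 URI of the best model artifact (index 0), built with the workflow Join
# function so it resolves at pipeline execution time, e.g. for a subsequent
# model creation or registration step.
best_model_uri = step_tune.get_top_model_s3_uri(top_k=0, s3_bucket=bucket)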
2 changes: 1 addition & 1 deletion tests/data/pytorch_mnist/mnist.py
@@ -182,7 +182,7 @@ def train(args):
accuracy = test(model, test_loader, device)
save_model(model, args.model_dir)

logger.debug("Overall test accuracy: {}".format(accuracy))
logger.debug("Overall test accuracy: {};".format(accuracy))


def test(model, test_loader, device):