diff --git a/src/stepfunctions/steps/sagemaker.py b/src/stepfunctions/steps/sagemaker.py index 7163674..30e3d7c 100644 --- a/src/stepfunctions/steps/sagemaker.py +++ b/src/stepfunctions/steps/sagemaker.py @@ -15,7 +15,8 @@ import logging from enum import Enum -from stepfunctions.inputs import ExecutionInput, StepInput + +from stepfunctions.inputs import Placeholder from stepfunctions.steps.states import Task from stepfunctions.steps.fields import Field from stepfunctions.steps.utils import tags_dict_to_kv_list @@ -29,7 +30,6 @@ SAGEMAKER_SERVICE_NAME = "sagemaker" - class SageMakerApi(Enum): CreateTrainingJob = "createTrainingJob" CreateTransformJob = "createTransformJob" @@ -47,7 +47,7 @@ class TrainingStep(Task): Creates a Task State to execute a `SageMaker Training Job `_. The TrainingStep will also create a model by default, and the model shares the same name as the training job. """ - def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=None, mini_batch_size=None, experiment_config=None, wait_for_completion=True, tags=None, **kwargs): + def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=None, mini_batch_size=None, experiment_config=None, wait_for_completion=True, tags=None, output_data_config_path=None, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. @@ -55,7 +55,7 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non job_name (str or Placeholder): Specify a training job name, this is required for the training job to run. We recommend to use :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution. data: Information about the training data. Please refer to the ``fit()`` method of the associated estimator, as this can take any of the following forms: - * (str) - The S3 location where training data is saved. + * (str or Placeholder) - The S3 location where training data is saved. * (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) - If using multiple channels for training data, you can specify a dict mapping channel names to strings or :func:`~sagemaker.inputs.TrainingInput` objects. @@ -75,6 +75,8 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non experiment_config (dict, optional): Specify the experiment config for the training. (Default: None) wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the training job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the training job and proceed to the next step. (default: True) tags (list[dict], optional): `List to tags `_ to associate with the resource. + output_data_config_path (str or Placeholder, optional): S3 location for saving the training result (model + artifacts and output files). If specified, it overrides the `output_path` property of `estimator`. """ self.estimator = estimator self.job_name = job_name @@ -94,6 +96,11 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, SageMakerApi.CreateTrainingJob) + # Convert `data` Placeholder to a JSONPath string because sagemaker.workflow.airflow.training_config does not + # accept Placeholder in the `input` argument. We will suffix the 'S3Uri' key in `parameters` with ".$" later. + is_data_placeholder = isinstance(data, Placeholder) + if is_data_placeholder: + data = data.to_jsonpath() if isinstance(job_name, str): parameters = training_config(estimator=estimator, inputs=data, job_name=job_name, mini_batch_size=mini_batch_size) @@ -106,9 +113,18 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non if estimator.rules != None: parameters['DebugRuleConfigurations'] = [rule.to_debugger_rule_config_dict() for rule in estimator.rules] - if isinstance(job_name, (ExecutionInput, StepInput)): + if isinstance(job_name, Placeholder): parameters['TrainingJobName'] = job_name + if output_data_config_path is not None: + parameters['OutputDataConfig']['S3OutputPath'] = output_data_config_path + + if data is not None and is_data_placeholder: + # Replace the 'S3Uri' key with one that supports JSONpath value. + # Support for uri str only: The list will only contain 1 element + data_uri = parameters['InputDataConfig'][0]['DataSource']['S3DataSource'].pop('S3Uri', None) + parameters['InputDataConfig'][0]['DataSource']['S3DataSource']['S3Uri.$'] = data_uri + if hyperparameters is not None: if estimator.hyperparameters() is not None: hyperparameters = self.__merge_hyperparameters(hyperparameters, estimator.hyperparameters()) @@ -237,7 +253,7 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type= join_source=join_source ) - if isinstance(job_name, (ExecutionInput, StepInput)): + if isinstance(job_name, Placeholder): parameters['TransformJobName'] = job_name parameters['ModelName'] = model_name @@ -506,7 +522,7 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp else: parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id) - if isinstance(job_name, (ExecutionInput, StepInput)): + if isinstance(job_name, Placeholder): parameters['ProcessingJobName'] = job_name if experiment_config is not None: diff --git a/src/stepfunctions/steps/states.py b/src/stepfunctions/steps/states.py index 9c86457..9669a73 100644 --- a/src/stepfunctions/steps/states.py +++ b/src/stepfunctions/steps/states.py @@ -17,7 +17,7 @@ from stepfunctions.exceptions import DuplicateStatesInChain from stepfunctions.steps.fields import Field -from stepfunctions.inputs import ExecutionInput, StepInput +from stepfunctions.inputs import Placeholder, StepInput logger = logging.getLogger('stepfunctions.states') @@ -53,7 +53,7 @@ def _replace_placeholders(self, params): return params modified_parameters = {} for k, v in params.items(): - if isinstance(v, (ExecutionInput, StepInput)): + if isinstance(v, Placeholder): modified_key = "{key}.$".format(key=k) modified_parameters[modified_key] = v.to_jsonpath() elif isinstance(v, dict): diff --git a/tests/integ/test_sagemaker_steps.py b/tests/integ/test_sagemaker_steps.py index 73d1a5b..63c060a 100644 --- a/tests/integ/test_sagemaker_steps.py +++ b/tests/integ/test_sagemaker_steps.py @@ -104,6 +104,7 @@ def test_training_step(pca_estimator_fixture, record_set_fixture, sfn_client, sf state_machine_delete_wait(sfn_client, workflow.state_machine_arn) # End of Cleanup + def test_model_step(trained_estimator, sfn_client, sagemaker_session, sfn_role_arn): # Build workflow definition model_name = generate_job_name() diff --git a/tests/unit/test_sagemaker_steps.py b/tests/unit/test_sagemaker_steps.py index 7645f85..c643468 100644 --- a/tests/unit/test_sagemaker_steps.py +++ b/tests/unit/test_sagemaker_steps.py @@ -26,6 +26,7 @@ from sagemaker.processing import ProcessingInput, ProcessingOutput from unittest.mock import MagicMock, patch +from stepfunctions.inputs import ExecutionInput, StepInput from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep, ProcessingStep from stepfunctions.steps.sagemaker import tuning_config @@ -224,6 +225,7 @@ def test_training_step_creation(pca_estimator): 'TrialName': 'pca_trial', 'TrialComponentDisplayName': 'Training' }, + output_data_config_path='s3://sagemaker-us-east-1-111111111111', tags=DEFAULT_TAGS, ) assert step.to_dict() == { @@ -234,7 +236,7 @@ def test_training_step_creation(pca_estimator): 'TrainingInputMode': 'File' }, 'OutputDataConfig': { - 'S3OutputPath': 's3://sagemaker/models' + 'S3OutputPath': 's3://sagemaker-us-east-1-111111111111' }, 'StoppingCondition': { 'MaxRuntimeInSeconds': 86400 @@ -265,6 +267,81 @@ def test_training_step_creation(pca_estimator): } +@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') +def test_training_step_creation_with_placeholders(pca_estimator): + execution_input = ExecutionInput(schema={ + 'Data': str, + 'OutputPath': str, + }) + + step_input = StepInput(schema={ + 'JobName': str, + }) + + step = TrainingStep('Training', + estimator=pca_estimator, + job_name=step_input['JobName'], + data=execution_input['Data'], + output_data_config_path=execution_input['OutputPath'], + experiment_config={ + 'ExperimentName': 'pca_experiment', + 'TrialName': 'pca_trial', + 'TrialComponentDisplayName': 'Training' + }, + tags=DEFAULT_TAGS, + ) + assert step.to_dict() == { + 'Type': 'Task', + 'Parameters': { + 'AlgorithmSpecification': { + 'TrainingImage': PCA_IMAGE, + 'TrainingInputMode': 'File' + }, + 'OutputDataConfig': { + 'S3OutputPath.$': "$$.Execution.Input['OutputPath']" + }, + 'StoppingCondition': { + 'MaxRuntimeInSeconds': 86400 + }, + 'ResourceConfig': { + 'InstanceCount': 1, + 'InstanceType': 'ml.c4.xlarge', + 'VolumeSizeInGB': 30 + }, + 'RoleArn': EXECUTION_ROLE, + 'HyperParameters': { + 'feature_dim': '50000', + 'num_components': '10', + 'subtract_mean': 'True', + 'algorithm_mode': 'randomized', + 'mini_batch_size': '200' + }, + 'InputDataConfig': [ + { + 'ChannelName': 'training', + 'DataSource': { + 'S3DataSource': { + 'S3DataDistributionType': 'FullyReplicated', + 'S3DataType': 'S3Prefix', + 'S3Uri.$': "$$.Execution.Input['Data']" + } + } + } + ], + 'ExperimentConfig': { + 'ExperimentName': 'pca_experiment', + 'TrialName': 'pca_trial', + 'TrialComponentDisplayName': 'Training' + }, + 'TrainingJobName.$': "$['JobName']", + 'Tags': DEFAULT_TAGS_LIST + }, + 'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync', + 'End': True + } + + @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) @patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_training_step_creation_with_debug_hook(pca_estimator_with_debug_hook):