From 927b24fb3dd7bd00eeb3809f02079192e6416246 Mon Sep 17 00:00:00 2001 From: Carolyn Nguyen Date: Fri, 16 Jul 2021 10:06:11 -0700 Subject: [PATCH 01/20] documentation: Add setup instructions to run/debug tests locally --- CONTRIBUTING.md | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 99a078b..1228768 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -19,6 +19,8 @@ information to effectively respond to your bug report or contribution. * [Committing Your Change](#committing-your-change) * [Sending a Pull Request](#sending-a-pull-request) * [Finding Contributions to Work On](#finding-contributions-to-work-on) +* [Setting Up Your Development Environment](#setting-up-your-development-environment) + * [PyCharm](#pycharm) * [Code of Conduct](#code-of-conduct) * [Security Issue Notifications](#security-issue-notifications) * [Licensing](#licensing) @@ -168,6 +170,29 @@ Please remember to: Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels ((enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any ['help wanted'](https://github.com/aws/aws-step-functions-data-science-sdk-python/labels/help%20wanted) issues is a great place to start. +## Setting Up Your Development Environment + +Setting up your IDE for debugging your tests locally will save you a lot of time. +You might be able to `Run` and `Debug` the tests directly in your IDE with your default settings, but if that's not the case, +follow the steps described in this section. + +### PyCharm +1. Set your Default test runner to `pytest` in _Preferences → Tools → Python Integrated Tools_ +1. Go to _Preferences → Build, Execution, Deployment → Python Debugger_ and set the options with the following values: + + | Option | Value | + |------------------------------------------------------------:|:----------------------| + | Attach subprocess automatically while debugging | `Enabled` | + | Collect run-time types information for code insight | `Enabled` | + | Gevent compatible | `Disabled` | + | Drop into debugger on failed tests | `Enabled` | + | PyQt compatible | `Auto` | + | For Attach to Process show processes with names containing | `python` | +1. Right-click on a test or test file and select `Run/Debug` + + _Note: This can also be done by clicking on the green arrow next to the test definition_ + + ## Code of Conduct This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). From a7700a6c6a795b7d76a39f51e915a143bed7ea05 Mon Sep 17 00:00:00 2001 From: Carolyn Nguyen Date: Mon, 9 Aug 2021 18:48:20 -0700 Subject: [PATCH 02/20] Added subsection for debug setup and linked to instructions for running tests --- CONTRIBUTING.md | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 1228768..dd02612 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -14,13 +14,15 @@ information to effectively respond to your bug report or contribution.
* [Contributing via Pull Requests (PRs)](#contributing-via-pull-requests-prs) * [Pulling Down the Code](#pulling-down-the-code) * [Running the Unit Tests](#running-the-unit-tests) + * [Running Unit Tests and Debugging in PyCharm](#running-unit-tests-and-debugging-in-pycharm) * [Running the Integration Tests](#running-the-integration-tests) * [Making and Testing Your Change](#making-and-testing-your-change) * [Committing Your Change](#committing-your-change) * [Sending a Pull Request](#sending-a-pull-request) * [Finding Contributions to Work On](#finding-contributions-to-work-on) * [Setting Up Your Development Environment](#setting-up-your-development-environment) - * [PyCharm](#pycharm) + * [Setting Up Your Environment for Debugging](#setting-up-your-environment-for-debugging) + * [PyCharm](#pycharm) * [Code of Conduct](#code-of-conduct) * [Security Issue Notifications](#security-issue-notifications) * [Licensing](#licensing) @@ -67,6 +69,11 @@ You can also run a single test with the following command: `tox -e py36 -- -s -v * Note that the coverage test will fail if you only run a single test, so make sure to surround the command with `export IGNORE_COVERAGE=-` and `unset IGNORE_COVERAGE` * Example: `export IGNORE_COVERAGE=- ; tox -e py36 -- -s -vv tests/unit/test_sagemaker_steps.py::test_training_step_creation_with_model ; unset IGNORE_COVERAGE` +#### Running Unit Tests and Debugging in PyCharm +You can also run the unit tests with the following options: +* Right-click on a test file in the Project tree and select `Run/Debug 'pytest' for ...` +* Right-click on the test definition and select `Run/Debug 'pytest' for ...` +* Click on the green arrow next to the test definition ### Running the Integration Tests @@ -172,13 +179,15 @@ Looking at the existing issues is a great way to find something to contribute on ## Setting Up Your Development Environment -Setting up your IDE for debugging your tests locally will save you a lot of time. +### Setting Up Your Environment for Debugging + +Setting up your IDE for debugging tests locally will save you a lot of time. You might be able to `Run` and `Debug` the tests directly in your IDE with your default settings, but if that's not the case, follow the steps described in this section. -### PyCharm +#### PyCharm 1. Set your Default test runner to `pytest` in _Preferences → Tools → Python Integrated Tools_ -1. Go to _Preferences → Build, Execution, Deployment → Python Debugger_ and set the options with the following values: +1. If you are using `PyCharm Professional Edition`, go to _Preferences → Build, Execution, Deployment → Python Debugger_ and set the options with the following values: | Option | Value | |------------------------------------------------------------:|:----------------------| | Attach subprocess automatically while debugging | `Enabled` | | Collect run-time types information for code insight | `Enabled` | | Gevent compatible | `Disabled` | | Drop into debugger on failed tests | `Enabled` | | PyQt compatible | `Auto` | | For Attach to Process show processes with names containing | `python` | -1. Right-click on a test or test file and select `Run/Debug` - - _Note: This can also be done by clicking on the green arrow next to the test definition_ + This will allow you to break into all subprocesses of the process being debugged and preserve function types while debugging. +1. 
Debug tests in PyCharm as per [Running Unit Tests and Debugging in PyCharm](#running-unit-tests-and-debugging-in-pycharm) +_Note: This setup was tested and confirmed to work with +`PyCharm 2020.3.5 (Professional Edition)` and `PyCharm 2021.1.1 (Professional Edition)`_ ## Code of Conduct This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). From 6b6443aac0ba61b444bf56ef0ed135f58f2c4f7b Mon Sep 17 00:00:00 2001 From: Carolyn Nguyen Date: Wed, 11 Aug 2021 17:55:48 -0700 Subject: [PATCH 03/20] Update table --- CONTRIBUTING.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index dd02612..6a0f342 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -190,14 +190,15 @@ follow the steps described in this section. 1. If you are using `PyCharm Professional Edition`, go to _Preferences → Build, Execution, Deployment → Python Debugger_ and set the options with the following values: | Option | Value | -|------------------------------------------------------------:|:----------------------| +|:------------------------------------------------------------ |:----------------------| | Attach subprocess automatically while debugging | `Enabled` | | Collect run-time types information for code insight | `Enabled` | | Gevent compatible | `Disabled` | | Drop into debugger on failed tests | `Enabled` | | PyQt compatible | `Auto` | | For Attach to Process show processes with names containing | `python` | - This will allow you to break into all subprocesses of the process being debugged and preserve function types while debugging. + + This will allow you to break into all subprocesses of the process being debugged and preserve function types while debugging. 1. Debug tests in PyCharm as per [Running Unit Tests and Debugging in PyCharm](#running-unit-tests-and-debugging-in-pycharm) _Note: This setup was tested and confirmed to work with `PyCharm 2020.3.5 (Professional Edition)` and `PyCharm 2021.1.1 (Professional Edition)`_ ## Code of Conduct From 7f6ef3052c19afe15f12073d67a047d9a76f3de6 Mon Sep 17 00:00:00 2001 From: Carolyn Nguyen Date: Thu, 12 Aug 2021 10:32:04 -0700 Subject: [PATCH 04/20] Support placeholders for processor parameters in ProcessingStep --- src/stepfunctions/exceptions.py | 8 +- src/stepfunctions/steps/constants.py | 30 +++++++ src/stepfunctions/steps/fields.py | 15 +++- src/stepfunctions/steps/sagemaker.py | 120 +++++++++++++++++++++++++-- tests/integ/test_sagemaker_steps.py | 84 +++++++++++++++++++ tests/unit/test_sagemaker_steps.py | 118 +++++++++++++++++++++++++- 6 files changed, 365 insertions(+), 10 deletions(-) create mode 100644 src/stepfunctions/steps/constants.py diff --git a/src/stepfunctions/exceptions.py b/src/stepfunctions/exceptions.py index 7e9a4d7..56be3ea 100644 --- a/src/stepfunctions/exceptions.py +++ b/src/stepfunctions/exceptions.py @@ -22,4 +22,10 @@ class MissingRequiredParameter(Exception): class DuplicateStatesInChain(Exception): - pass \ No newline at end of file + pass + + +class InvalidPathToPlaceholderParameter(Exception): + + def __init__(self, message): + super(InvalidPathToPlaceholderParameter, self).__init__(message) diff --git a/src/stepfunctions/steps/constants.py b/src/stepfunctions/steps/constants.py new file mode 100644 index 0000000..1b308c8 --- /dev/null +++ b/src/stepfunctions/steps/constants.py @@ -0,0 +1,30 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. 
+# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. +from enum import Enum +from stepfunctions.steps.fields import Field + +# Path to SageMaker placeholder parameters +placeholder_paths = { + # Paths taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateProcessingJob.html + 'ProcessingStep': { + Field.Role.value: ['RoleArn'], + Field.ImageUri.value: ['AppSpecification', 'ImageUri'], + Field.InstanceCount.value: ['ProcessingResources', 'ClusterConfig', 'InstanceCount'], + Field.InstanceType.value: ['ProcessingResources', 'ClusterConfig', 'InstanceType'], + Field.Entrypoint.value: ['AppSpecification', 'ContainerEntrypoint'], + Field.VolumeSizeInGB.value: ['ProcessingResources', 'ClusterConfig', 'VolumeSizeInGB'], + Field.VolumeKMSKey.value: ['ProcessingResources', 'ClusterConfig', 'VolumeKmsKeyId'], + Field.Env.value: ['Environment'], + Field.Tags.value: ['Tags'], + } +} diff --git a/src/stepfunctions/steps/fields.py b/src/stepfunctions/steps/fields.py index 24c3949..fab4aa6 100644 --- a/src/stepfunctions/steps/fields.py +++ b/src/stepfunctions/steps/fields.py @@ -59,10 +59,23 @@ class Field(Enum): HeartbeatSeconds = 'heartbeat_seconds' HeartbeatSecondsPath = 'heartbeat_seconds_path' - # Retry and catch fields ErrorEquals = 'error_equals' IntervalSeconds = 'interval_seconds' MaxAttempts = 'max_attempts' BackoffRate = 'backoff_rate' NextStep = 'next_step' + + # Sagemaker step fields + # Processing Step: Processor + Role = 'role' + ImageUri = 'image_uri' + InstanceCount = 'instance_count' + InstanceType = 'instance_type' + Entrypoint = 'entrypoint' + VolumeSizeInGB = 'volume_size_in_gb' + VolumeKMSKey = 'volume_kms_key' + OutputKMSKey = 'output_kms_key' + MaxRuntimeInSeconds = 'max_runtime_in_seconds' + Env = 'env' + Tags = 'tags' \ No newline at end of file diff --git a/src/stepfunctions/steps/sagemaker.py b/src/stepfunctions/steps/sagemaker.py index 30e3d7c..d2f3740 100644 --- a/src/stepfunctions/steps/sagemaker.py +++ b/src/stepfunctions/steps/sagemaker.py @@ -13,10 +13,14 @@ from __future__ import absolute_import import logging +import operator from enum import Enum +from functools import reduce +from stepfunctions.exceptions import InvalidPathToPlaceholderParameter from stepfunctions.inputs import Placeholder +from stepfunctions.steps.constants import placeholder_paths from stepfunctions.steps.states import Task from stepfunctions.steps.fields import Field from stepfunctions.steps.utils import tags_dict_to_kv_list @@ -25,6 +29,7 @@ from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config, processing_config from sagemaker.model import Model, FrameworkModel from sagemaker.model_monitor import DataCaptureConfig +from sagemaker.processing import ProcessingJob logger = logging.getLogger('stepfunctions.sagemaker') @@ -41,6 +46,104 @@ class SageMakerApi(Enum): CreateProcessingJob = "createProcessingJob" +class SageMakerTask(Task): + + """ + Task State causes the interpreter to execute the work identified by the state’s `resource` field. + """ + + def __init__(self, state_id, step_type, tags, **kwargs): + """ + Args: + state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. 
State names **must be** unique within the scope of the whole state machine. + resource (str): A URI that uniquely identifies the specific task to execute. The States language does not constrain the URI scheme nor any other part of the URI. + timeout_seconds (int, optional): Positive integer specifying timeout for the state in seconds. If the state runs longer than the specified timeout, then the interpreter fails the state with a `States.Timeout` Error Name. (default: 60) + timeout_seconds_path (str, optional): Path specifying the state's timeout value in seconds from the state input. When resolved, the path must select a field whose value is a positive integer. + heartbeat_seconds (int, optional): Positive integer specifying heartbeat timeout for the state in seconds. This value should be lower than the one specified for `timeout_seconds`. If more time than the specified heartbeat elapses between heartbeats from the task, then the interpreter fails the state with a `States.Timeout` Error Name. + heartbeat_seconds_path (str, optional): Path specifying the state's heartbeat value in seconds from the state input. When resolved, the path must select a field whose value is a positive integer. + comment (str, optional): Human-readable comment or description. (default: None) + input_path (str, optional): Path applied to the state’s raw input to select some or all of it; that selection is used by the state. (default: '$') + parameters (dict, optional): The value of this field becomes the effective input for the state. + result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') + output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. 
(default: '$') + """ + self._replace_sagemaker_placeholders(step_type, kwargs) + if tags: + self.set_tags_config(tags, kwargs[Field.Parameters.value], step_type) + + super(SageMakerTask, self).__init__(state_id, **kwargs) + + + def allowed_fields(self): + sagemaker_fields = [ + # ProcessingStep: Processor + Field.Role, + Field.ImageUri, + Field.InstanceCount, + Field.InstanceType, + Field.Entrypoint, + Field.VolumeSizeInGB, + Field.VolumeKMSKey, + Field.OutputKMSKey, + Field.MaxRuntimeInSeconds, + Field.Env, + Field.Tags, + ] + + return super(SageMakerTask, self).allowed_fields() + sagemaker_fields + + + def _replace_sagemaker_placeholders(self, step_type, args): + # Fetch path from type + sagemaker_parameters = args[Field.Parameters.value] + paths = placeholder_paths.get(step_type) + treated_args = [] + + for arg_name, value in args.items(): + if arg_name in [Field.Parameters.value]: + continue + if arg_name in paths.keys(): + path = paths.get(arg_name) + if self._set_placeholder(sagemaker_parameters, path, value, arg_name): + treated_args.append(arg_name) + + SageMakerTask.remove_treated_args(treated_args, args) + + @staticmethod + def get_value_from_path(parameters, path): + value_from_path = reduce(operator.getitem, path, parameters) + return value_from_path + # return reduce(operator.getitem, path, parameters) + + @staticmethod + def _set_placeholder(parameters, path, value, arg_name): + is_set = False + try: + SageMakerTask.get_value_from_path(parameters, path[:-1])[path[-1]] = value + is_set = True + except KeyError as e: + message = f"Invalid path {path} for {arg_name}: {e}" + raise InvalidPathToPlaceholderParameter(message) + return is_set + + @staticmethod + def remove_treated_args(treated_args, args): + for treated_arg in treated_args: + try: + del args[treated_arg] + except KeyError as e: + pass + + def set_tags_config(self, tags, parameters, step_type): + if isinstance(tags, Placeholder): + # Replace with placeholder + path = placeholder_paths.get(step_type).get(Field.Tags.value) + if path: + self._set_placeholder(parameters, path, tags, Field.Tags.value) + else: + parameters['Tags'] = tags_dict_to_kv_list(tags) + + class TrainingStep(Task): """ @@ -473,13 +576,15 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta super(TuningStep, self).__init__(state_id, **kwargs) -class ProcessingStep(Task): +class ProcessingStep(SageMakerTask): """ Creates a Task State to execute a SageMaker Processing Job. """ - def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, experiment_config=None, container_arguments=None, container_entrypoint=None, kms_key_id=None, wait_for_completion=True, tags=None, **kwargs): + def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, experiment_config=None, + container_arguments=None, container_entrypoint=None, kms_key_id=None, wait_for_completion=True, + tags=None, max_runtime_in_seconds=None, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. @@ -499,7 +604,8 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp ARN of a KMS key, alias of a KMS key, or alias of a KMS key. The KmsKeyId is applied to all outputs. wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the processing job to complete before proceeding to the next step in the workflow. 
Set to `False` if the Task state should submit the processing job and proceed to the next step. (default: True) - tags (list[dict], optional): `List of tags `_ to associate with the resource. + tags (list[dict] or Placeholder, optional): `List of tags `_ to associate with the resource. + max_runtime_in_seconds (int or Placeholder): Specifies the maximum runtime in seconds for the processing job. """ if wait_for_completion: """ @@ -528,12 +634,12 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp if experiment_config is not None: parameters['ExperimentConfig'] = experiment_config - if tags: - parameters['Tags'] = tags_dict_to_kv_list(tags) - if 'S3Operations' in parameters: del parameters['S3Operations'] + + if max_runtime_in_seconds: + parameters['StoppingCondition'] = ProcessingJob.prepare_stopping_condition(max_runtime_in_seconds) kwargs[Field.Parameters.value] = parameters - super(ProcessingStep, self).__init__(state_id, **kwargs) + super(ProcessingStep, self).__init__(state_id, __class__.__name__, tags, **kwargs) diff --git a/tests/integ/test_sagemaker_steps.py b/tests/integ/test_sagemaker_steps.py index 63c060a..6400529 100644 --- a/tests/integ/test_sagemaker_steps.py +++ b/tests/integ/test_sagemaker_steps.py @@ -29,7 +29,9 @@ from sagemaker.tuner import HyperparameterTuner from sagemaker.processing import ProcessingInput, ProcessingOutput +from stepfunctions.inputs import ExecutionInput from stepfunctions.steps import Chain +from stepfunctions.steps.fields import Field from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep, TuningStep, ProcessingStep from stepfunctions.workflow import Workflow @@ -352,3 +354,85 @@ def test_processing_step(sklearn_processor_fixture, sagemaker_session, sfn_clien # Cleanup state_machine_delete_wait(sfn_client, workflow.state_machine_arn) # End of Cleanup + + +def test_processing_step_with_placeholders(sklearn_processor_fixture, sagemaker_session, sfn_client, sfn_role_arn, + sagemaker_role_arn): + region = boto3.session.Session().region_name + input_data = 's3://sagemaker-sample-data-{}/processing/census/census-income.csv'.format(region) + + input_s3 = sagemaker_session.upload_data( + path=os.path.join(DATA_DIR, 'sklearn_processing'), + bucket=sagemaker_session.default_bucket(), + key_prefix='integ-test-data/sklearn_processing/code' + ) + + output_s3 = 's3://' + sagemaker_session.default_bucket() + '/integ-test-data/sklearn_processing' + + inputs = [ + ProcessingInput(source=input_data, destination='/opt/ml/processing/input', input_name='input-1'), + ProcessingInput(source=input_s3 + '/preprocessor.py', destination='/opt/ml/processing/input/code', + input_name='code'), + ] + + outputs = [ + ProcessingOutput(source='/opt/ml/processing/train', destination=output_s3 + '/train_data', + output_name='train_data'), + ProcessingOutput(source='/opt/ml/processing/test', destination=output_s3 + '/test_data', + output_name='test_data'), + ] + + # Build workflow definition + execution_input = ExecutionInput(schema={ + Field.ImageUri.value: str, + Field.InstanceCount.value: int, + Field.Entrypoint.value: str, + Field.Role.value: str, + Field.VolumeSizeInGB.value: int, + Field.MaxRuntimeInSeconds.value: int + }) + + job_name = generate_job_name() + processing_step = ProcessingStep('create_processing_job_step', + processor=sklearn_processor_fixture, + job_name=job_name, + inputs=inputs, + outputs=outputs, + container_arguments=['--train-test-split-ratio', '0.2'], + 
container_entrypoint=execution_input[Field.Entrypoint.value], + image_uri=execution_input[Field.ImageUri.value], + instance_count=execution_input[Field.InstanceCount.value], + role=execution_input[Field.Role.value], + volume_size_in_gb=execution_input[Field.VolumeSizeInGB.value], + max_runtime_in_seconds=execution_input[Field.MaxRuntimeInSeconds.value] + ) + workflow_graph = Chain([processing_step]) + + with timeout(minutes=DEFAULT_TIMEOUT_MINUTES): + # Create workflow and check definition + workflow = create_workflow_and_check_definition( + workflow_graph=workflow_graph, + workflow_name=unique_name_from_base("integ-test-processing-step-workflow"), + sfn_client=sfn_client, + sfn_role_arn=sfn_role_arn + ) + + execution_input = { + Field.ImageUri.value: '683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:0.20.0-cpu-py3', + Field.InstanceCount.value: 1, + Field.Entrypoint.value: ['python3', '/opt/ml/processing/input/code/preprocessor.py'], + Field.Role.value: sagemaker_role_arn, + Field.VolumeSizeInGB.value: 30, + Field.MaxRuntimeInSeconds.value: 500 + } + + # Execute workflow + execution = workflow.execute(inputs=execution_input) + execution_output = execution.get_output(wait=True) + + # Check workflow output + assert execution_output.get("ProcessingJobStatus") == "Completed" + + # Cleanup + state_machine_delete_wait(sfn_client, workflow.state_machine_arn) + # End of Cleanup diff --git a/tests/unit/test_sagemaker_steps.py b/tests/unit/test_sagemaker_steps.py index c643468..33da0ca 100644 --- a/tests/unit/test_sagemaker_steps.py +++ b/tests/unit/test_sagemaker_steps.py @@ -27,7 +27,9 @@ from unittest.mock import MagicMock, patch from stepfunctions.inputs import ExecutionInput, StepInput -from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep, ProcessingStep +from stepfunctions.steps.fields import Field +from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep,\ + ProcessingStep from stepfunctions.steps.sagemaker import tuning_config from tests.unit.utils import mock_boto_api_call @@ -962,3 +964,117 @@ def test_processing_step_creation(sklearn_processor): 'Resource': 'arn:aws:states:::sagemaker:createProcessingJob.sync', 'End': True } + + +def test_processing_step_creation_with_placeholders(sklearn_processor): + execution_input = ExecutionInput(schema={ + Field.ImageUri.value: str, + Field.InstanceCount.value: int, + Field.Entrypoint.value: str, + Field.OutputKMSKey.value: str, + Field.Role.value: str, + Field.Env.value: str, + Field.VolumeSizeInGB.value: int, + Field.VolumeKMSKey.value: str, + Field.MaxRuntimeInSeconds.value: int, + Field.Tags.value: [{str: str}] + }) + + step_input = StepInput(schema={ + Field.InstanceType.value: str + }) + + inputs = [ProcessingInput(source='dataset.csv', destination='/opt/ml/processing/input')] + outputs = [ + ProcessingOutput(source='/opt/ml/processing/output/train'), + ProcessingOutput(source='/opt/ml/processing/output/validation'), + ProcessingOutput(source='/opt/ml/processing/output/test') + ] + step = ProcessingStep( + 'Feature Transformation', + sklearn_processor, + 'MyProcessingJob', + container_entrypoint=execution_input[Field.Entrypoint.value], + kms_key_id=execution_input[Field.OutputKMSKey.value], + inputs=inputs, + outputs=outputs, + image_uri=execution_input[Field.ImageUri.value], + instance_count=execution_input[Field.InstanceCount.value], + instance_type=step_input[Field.InstanceType.value], + 
role=execution_input[Field.Role.value], + env=execution_input[Field.Env.value], + volume_size_in_gb=execution_input[Field.VolumeSizeInGB.value], + volume_kms_key=execution_input[Field.VolumeKMSKey.value], + max_runtime_in_seconds=execution_input[Field.MaxRuntimeInSeconds.value], + tags=execution_input[Field.Tags.value], + ) + assert step.to_dict() == { + 'Type': 'Task', + 'Parameters': { + 'AppSpecification': { + 'ContainerEntrypoint.$': "$$.Execution.Input['entrypoint']", + 'ImageUri.$': "$$.Execution.Input['image_uri']" + }, + 'Environment.$': "$$.Execution.Input['env']", + 'ProcessingInputs': [ + { + 'InputName': None, + 'AppManaged': False, + 'S3Input': { + 'LocalPath': '/opt/ml/processing/input', + 'S3CompressionType': 'None', + 'S3DataDistributionType': 'FullyReplicated', + 'S3DataType': 'S3Prefix', + 'S3InputMode': 'File', + 'S3Uri': 'dataset.csv' + } + } + ], + 'ProcessingOutputConfig': { + 'KmsKeyId.$': "$$.Execution.Input['output_kms_key']", + 'Outputs': [ + { + 'OutputName': None, + 'AppManaged': False, + 'S3Output': { + 'LocalPath': '/opt/ml/processing/output/train', + 'S3UploadMode': 'EndOfJob', + 'S3Uri': None + } + }, + { + 'OutputName': None, + 'AppManaged': False, + 'S3Output': { + 'LocalPath': '/opt/ml/processing/output/validation', + 'S3UploadMode': 'EndOfJob', + 'S3Uri': None + } + }, + { + 'OutputName': None, + 'AppManaged': False, + 'S3Output': { + 'LocalPath': '/opt/ml/processing/output/test', + 'S3UploadMode': 'EndOfJob', + 'S3Uri': None + } + } + ] + }, + 'ProcessingResources': { + 'ClusterConfig': { + 'InstanceCount.$': "$$.Execution.Input['instance_count']", + 'InstanceType.$': "$['instance_type']", + 'VolumeKmsKeyId.$': "$$.Execution.Input['volume_kms_key']", + 'VolumeSizeInGB.$': "$$.Execution.Input['volume_size_in_gb']" + } + }, + 'ProcessingJobName': 'MyProcessingJob', + 'RoleArn.$': "$$.Execution.Input['role']", + 'Tags.$': "$$.Execution.Input['tags']", + 'StoppingCondition': {'MaxRuntimeInSeconds.$': "$$.Execution.Input['max_runtime_in_seconds']"}, + }, + 'Resource': 'arn:aws:states:::sagemaker:createProcessingJob.sync', + 'End': True + } From 00830f390d9f49475bf702aba2e96c52d9fc23c1 Mon Sep 17 00:00:00 2001 From: Carolyn Nguyen Date: Thu, 12 Aug 2021 12:30:04 -0700 Subject: [PATCH 05/20] Added doc --- src/stepfunctions/steps/sagemaker.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/stepfunctions/steps/sagemaker.py b/src/stepfunctions/steps/sagemaker.py index d2f3740..e05742c 100644 --- a/src/stepfunctions/steps/sagemaker.py +++ b/src/stepfunctions/steps/sagemaker.py @@ -580,6 +580,10 @@ class ProcessingStep(SageMakerTask): """ Creates a Task State to execute a SageMaker Processing Job. 
+ + The following properties can be passed down as kwargs to the sagemaker.processing.Processor to be used dynamically + in the processing job (compatible with Placeholders): role, image_uri, instance_count, instance_type, + volume_size_in_gb, volume_kms_key, output_kms_key """ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, experiment_config=None, From c708da79924e0094174844c053ada0fe4ec8bfba Mon Sep 17 00:00:00 2001 From: Carolyn Nguyen Date: Thu, 12 Aug 2021 12:45:29 -0700 Subject: [PATCH 06/20] Removed contributing changes (included in another PR) --- CONTRIBUTING.md | 36 ------------------------------------ 1 file changed, 36 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 6a0f342..99a078b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -14,15 +14,11 @@ information to effectively respond to your bug report or contribution. * [Contributing via Pull Requests (PRs)](#contributing-via-pull-requests-prs) * [Pulling Down the Code](#pulling-down-the-code) * [Running the Unit Tests](#running-the-unit-tests) - * [Running Unit Tests and Debugging in PyCharm](#running-unit-tests-and-debugging-in-pycharm) * [Running the Integration Tests](#running-the-integration-tests) * [Making and Testing Your Change](#making-and-testing-your-change) * [Committing Your Change](#committing-your-change) * [Sending a Pull Request](#sending-a-pull-request) * [Finding Contributions to Work On](#finding-contributions-to-work-on) -* [Setting Up Your Development Environment](#setting-up-your-development-environment) - * [Setting Up Your Environment for Debugging](#setting-up-your-environment-for-debugging) - * [PyCharm](#pycharm) * [Code of Conduct](#code-of-conduct) * [Security Issue Notifications](#security-issue-notifications) * [Licensing](#licensing) @@ -69,11 +65,6 @@ You can also run a single test with the following command: `tox -e py36 -- -s -v * Note that the coverage test will fail if you only run a single test, so make sure to surround the command with `export IGNORE_COVERAGE=-` and `unset IGNORE_COVERAGE` * Example: `export IGNORE_COVERAGE=- ; tox -e py36 -- -s -vv tests/unit/test_sagemaker_steps.py::test_training_step_creation_with_model ; unset IGNORE_COVERAGE` -#### Running Unit Tests and Debugging in PyCharm -You can also run the unit tests with the following options: -* Right-click on a test file in the Project tree and select `Run/Debug 'pytest' for ...` -* Right-click on the test definition and select `Run/Debug 'pytest' for ...` -* Click on the green arrow next to the test definition ### Running the Integration Tests @@ -177,33 +168,6 @@ Please remember to: Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels ((enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any ['help wanted'](https://github.com/aws/aws-step-functions-data-science-sdk-python/labels/help%20wanted) issues is a great place to start. -## Setting Up Your Development Environment - -### Setting Up Your Environment for Debugging - -Setting up your IDE for debugging tests locally will save you a lot of time. -You might be able to `Run` and `Debug` the tests directly in your IDE with your default settings, but if that's not the case, -follow the steps described in this section. - -#### PyCharm -1. Set your Default test runner to `pytest` in _Preferences → Tools → Python Integrated Tools_ -1. 
If you are using `PyCharm Professional Edition`, go to _Preferences → Build, Execution, Deployment → Python Debugger_ and set the options with the following values: - - | Option | Value | - |:------------------------------------------------------------ |:----------------------| - | Attach subprocess automatically while debugging | `Enabled` | - | Collect run-time types information for code insight | `Enabled` | - | Gevent compatible | `Disabled` | - | Drop into debugger on failed tests | `Enabled` | - | PyQt compatible | `Auto` | - | For Attach to Process show processes with names containing | `python` | - - This will allow you to break into all subprocesses of the process being debugged and preserve function types while debugging. -1. Debug tests in PyCharm as per [Running Unit Tests and Debugging in PyCharm](#running-unit-tests-and-debugging-in-pycharm) - -_Note: This setup was tested and confirmed to work with -`PyCharm 2020.3.5 (Professional Edition)` and `PyCharm 2021.1.1 (Professional Edition)`_ - ## Code of Conduct This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). From 2ea9e1fa3fd184f0163ff0ee3d3c9633066068f0 Mon Sep 17 00:00:00 2001 From: Carolyn Nguyen Date: Mon, 16 Aug 2021 17:55:47 -0700 Subject: [PATCH 07/20] Merge SageMaker-generated parameters with placeholder-compatible parameters received in args --- src/stepfunctions/steps/constants.py | 30 ------ src/stepfunctions/steps/fields.py | 14 --- src/stepfunctions/steps/sagemaker.py | 148 +++++---------------------- src/stepfunctions/steps/utils.py | 22 ++++ tests/integ/test_sagemaker_steps.py | 54 ++++++---- tests/unit/test_sagemaker_steps.py | 63 ++++++++---- 6 files changed, 122 insertions(+), 209 deletions(-) delete mode 100644 src/stepfunctions/steps/constants.py diff --git a/src/stepfunctions/steps/constants.py b/src/stepfunctions/steps/constants.py deleted file mode 100644 index 1b308c8..0000000 --- a/src/stepfunctions/steps/constants.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the "license" file accompanying this file. This file is distributed -# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing -# 
-from enum import Enum -from stepfunctions.steps.fields import Field - -# Path to SageMaker placeholder parameters -placeholder_paths = { - # Paths taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateProcessingJob.html - 'ProcessingStep': { - Field.Role.value: ['RoleArn'], - Field.ImageUri.value: ['AppSpecification', 'ImageUri'], - Field.InstanceCount.value: ['ProcessingResources', 'ClusterConfig', 'InstanceCount'], - Field.InstanceType.value: ['ProcessingResources', 'ClusterConfig', 'InstanceType'], - Field.Entrypoint.value: ['AppSpecification', 'ContainerEntrypoint'], - Field.VolumeSizeInGB.value: ['ProcessingResources', 'ClusterConfig', 'VolumeSizeInGB'], - Field.VolumeKMSKey.value: ['ProcessingResources', 'ClusterConfig', 'VolumeKmsKeyId'], - Field.Env.value: ['Environment'], - Field.Tags.value: ['Tags'], - } -} diff --git a/src/stepfunctions/steps/fields.py b/src/stepfunctions/steps/fields.py index fab4aa6..8eb102d 100644 --- a/src/stepfunctions/steps/fields.py +++ b/src/stepfunctions/steps/fields.py @@ -65,17 +65,3 @@ class Field(Enum): MaxAttempts = 'max_attempts' BackoffRate = 'backoff_rate' NextStep = 'next_step' - - # Sagemaker step fields - # Processing Step: Processor - Role = 'role' - ImageUri = 'image_uri' - InstanceCount = 'instance_count' - InstanceType = 'instance_type' - Entrypoint = 'entrypoint' - VolumeSizeInGB = 'volume_size_in_gb' - VolumeKMSKey = 'volume_kms_key' - OutputKMSKey = 'output_kms_key' - MaxRuntimeInSeconds = 'max_runtime_in_seconds' - Env = 'env' - Tags = 'tags' \ No newline at end of file diff --git a/src/stepfunctions/steps/sagemaker.py b/src/stepfunctions/steps/sagemaker.py index e05742c..03ee401 100644 --- a/src/stepfunctions/steps/sagemaker.py +++ b/src/stepfunctions/steps/sagemaker.py @@ -13,28 +13,24 @@ from __future__ import absolute_import import logging -import operator from enum import Enum -from functools import reduce -from stepfunctions.exceptions import InvalidPathToPlaceholderParameter from stepfunctions.inputs import Placeholder -from stepfunctions.steps.constants import placeholder_paths from stepfunctions.steps.states import Task from stepfunctions.steps.fields import Field -from stepfunctions.steps.utils import tags_dict_to_kv_list +from stepfunctions.steps.utils import merge_dicts, tags_dict_to_kv_list from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config, processing_config from sagemaker.model import Model, FrameworkModel from sagemaker.model_monitor import DataCaptureConfig -from sagemaker.processing import ProcessingJob logger = logging.getLogger('stepfunctions.sagemaker') SAGEMAKER_SERVICE_NAME = "sagemaker" + class SageMakerApi(Enum): CreateTrainingJob = "createTrainingJob" CreateTransformJob = "createTransformJob" @@ -46,104 +42,6 @@ class SageMakerApi(Enum): CreateProcessingJob = "createProcessingJob" -class SageMakerTask(Task): - - """ - Task State causes the interpreter to execute the work identified by the state’s `resource` field. - """ - - def __init__(self, state_id, step_type, tags, **kwargs): - """ - Args: - state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. - resource (str): A URI that uniquely identifies the specific task to execute. The States language does not constrain the URI scheme nor any other part of the URI. 
- timeout_seconds (int, optional): Positive integer specifying timeout for the state in seconds. If the state runs longer than the specified timeout, then the interpreter fails the state with a `States.Timeout` Error Name. (default: 60) - timeout_seconds_path (str, optional): Path specifying the state's timeout value in seconds from the state input. When resolved, the path must select a field whose value is a positive integer. - heartbeat_seconds (int, optional): Positive integer specifying heartbeat timeout for the state in seconds. This value should be lower than the one specified for `timeout_seconds`. If more time than the specified heartbeat elapses between heartbeats from the task, then the interpreter fails the state with a `States.Timeout` Error Name. - heartbeat_seconds_path (str, optional): Path specifying the state's heartbeat value in seconds from the state input. When resolved, the path must select a field whose value is a positive integer. - comment (str, optional): Human-readable comment or description. (default: None) - input_path (str, optional): Path applied to the state’s raw input to select some or all of it; that selection is used by the state. (default: '$') - parameters (dict, optional): The value of this field becomes the effective input for the state. - result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') - output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') - """ - self._replace_sagemaker_placeholders(step_type, kwargs) - if tags: - self.set_tags_config(tags, kwargs[Field.Parameters.value], step_type) - - super(SageMakerTask, self).__init__(state_id, **kwargs) - - - def allowed_fields(self): - sagemaker_fields = [ - # ProcessingStep: Processor - Field.Role, - Field.ImageUri, - Field.InstanceCount, - Field.InstanceType, - Field.Entrypoint, - Field.VolumeSizeInGB, - Field.VolumeKMSKey, - Field.OutputKMSKey, - Field.MaxRuntimeInSeconds, - Field.Env, - Field.Tags, - ] - - return super(SageMakerTask, self).allowed_fields() + sagemaker_fields - - - def _replace_sagemaker_placeholders(self, step_type, args): - # Fetch path from type - sagemaker_parameters = args[Field.Parameters.value] - paths = placeholder_paths.get(step_type) - treated_args = [] - - for arg_name, value in args.items(): - if arg_name in [Field.Parameters.value]: - continue - if arg_name in paths.keys(): - path = paths.get(arg_name) - if self._set_placeholder(sagemaker_parameters, path, value, arg_name): - treated_args.append(arg_name) - - SageMakerTask.remove_treated_args(treated_args, args) - - @staticmethod - def get_value_from_path(parameters, path): - value_from_path = reduce(operator.getitem, path, parameters) - return value_from_path - # return reduce(operator.getitem, path, parameters) - - @staticmethod - def _set_placeholder(parameters, path, value, arg_name): - is_set = False - try: - SageMakerTask.get_value_from_path(parameters, path[:-1])[path[-1]] = value - is_set = True - except KeyError as e: - message = f"Invalid path {path} for {arg_name}: {e}" - raise InvalidPathToPlaceholderParameter(message) - return is_set - - @staticmethod - def remove_treated_args(treated_args, args): - for treated_arg in treated_args: - try: - del args[treated_arg] - except KeyError as e: - pass - - def set_tags_config(self, tags, parameters, step_type): - if isinstance(tags, Placeholder): - 
# Replace with placeholder - path = placeholder_paths.get(step_type).get(Field.Tags.value) - if path: - self._set_placeholder(parameters, path, tags, Field.Tags.value) - else: - parameters['Tags'] = tags_dict_to_kv_list(tags) - - class TrainingStep(Task): """ @@ -576,7 +474,7 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta super(TuningStep, self).__init__(state_id, **kwargs) -class ProcessingStep(SageMakerTask): +class ProcessingStep(Task): """ Creates a Task State to execute a SageMaker Processing Job. @@ -588,7 +486,7 @@ class ProcessingStep(SageMakerTask): def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, experiment_config=None, container_arguments=None, container_entrypoint=None, kms_key_id=None, wait_for_completion=True, - tags=None, max_runtime_in_seconds=None, **kwargs): + tags=None, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. processor (sagemaker.processing.Processor): The processor for the processing step. job_name (str or Placeholder): Specify a processing job name, this is required for the processing job to run. We recommend to use :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution. inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for the processing job. These must be provided as :class:`~sagemaker.processing.ProcessingInput` objects (default: None). outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for the processing job. These can be specified as either path strings or :class:`~sagemaker.processing.ProcessingOutput` objects (default: None). - experiment_config (dict, optional): Specify the experiment config for the processing. (Default: None) - container_arguments ([str]): The arguments for a container used to run a processing job. - container_entrypoint ([str]): The entrypoint for a container used to run a processing job. - kms_key_id (str): The AWS Key Management Service (AWS KMS) key that Amazon SageMaker + experiment_config (dict or Placeholder, optional): Specify the experiment config for the processing. (Default: None) + container_arguments ([str] or Placeholder): The arguments for a container used to run a processing job. + container_entrypoint ([str] or Placeholder): The entrypoint for a container used to run a processing job. + kms_key_id (str or Placeholder): The AWS Key Management Service (AWS KMS) key that Amazon SageMaker uses to encrypt the processing job output. KmsKeyId can be an ID of a KMS key, ARN of a KMS key, alias of a KMS key, or alias of a KMS key. The KmsKeyId is applied to all outputs. wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the processing job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the processing job and proceed to the next step. (default: True) tags (list[dict] or Placeholder, optional): `List of tags `_ to associate with the resource. - max_runtime_in_seconds (int or Placeholder): Specifies the maximum runtime in seconds for the processing job. + parameters (dict, optional): The value of this field becomes the effective input for the state. 
""" if wait_for_completion: """ @@ -628,22 +526,26 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp SageMakerApi.CreateProcessingJob) if isinstance(job_name, str): - parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name) + processing_parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name) else: - parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id) + processing_parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id) if isinstance(job_name, Placeholder): - parameters['ProcessingJobName'] = job_name + processing_parameters['ProcessingJobName'] = job_name if experiment_config is not None: - parameters['ExperimentConfig'] = experiment_config - - if 'S3Operations' in parameters: - del parameters['S3Operations'] + processing_parameters['ExperimentConfig'] = experiment_config - if max_runtime_in_seconds: - parameters['StoppingCondition'] = ProcessingJob.prepare_stopping_condition(max_runtime_in_seconds) - - kwargs[Field.Parameters.value] = parameters + if tags: + processing_parameters['Tags'] = tags if isinstance(tags, Placeholder) else tags_dict_to_kv_list(tags) + + if 'S3Operations' in processing_parameters: + del processing_parameters['S3Operations'] + + if Field.Parameters.value in kwargs and isinstance(kwargs[Field.Parameters.value], dict): + # Update processing_parameters with input parameters + merge_dicts(processing_parameters, kwargs[Field.Parameters.value], "Processing Parameters", + "Input Parameters") - super(ProcessingStep, self).__init__(state_id, __class__.__name__, tags, **kwargs) + kwargs[Field.Parameters.value] = processing_parameters + super(ProcessingStep, self).__init__(state_id, **kwargs) diff --git a/src/stepfunctions/steps/utils.py b/src/stepfunctions/steps/utils.py index 6f44481..b45d107 100644 --- a/src/stepfunctions/steps/utils.py +++ b/src/stepfunctions/steps/utils.py @@ -14,6 +14,7 @@ import boto3 import logging +from stepfunctions.inputs import Placeholder logger = logging.getLogger('stepfunctions') @@ -45,3 +46,24 @@ def get_aws_partition(): return cur_partition return cur_partition + + +def merge_dicts(first, second, first_name, second_name): + """ + Merges first and second dictionaries into the first one. + Values in the first dict are updated with the values of the second one. 
+ """ + if all(isinstance(d, dict) for d in [first, second]): + for key, value in second.items(): + if key in first: + if isinstance(first[key], dict) and isinstance(second[key], dict): + merge_dicts(first[key], second[key], first_name, second_name) + elif first[key] is value: + pass + else: + logger.info( + f"{first_name} property: <{key}> with value: <{first[key]}>" + f" will be overwritten with value provided in {second_name} : <{value}>") + first[key] = second[key] + else: + first[key] = second[key] diff --git a/tests/integ/test_sagemaker_steps.py b/tests/integ/test_sagemaker_steps.py index 6400529..77bca39 100644 --- a/tests/integ/test_sagemaker_steps.py +++ b/tests/integ/test_sagemaker_steps.py @@ -31,7 +31,6 @@ from stepfunctions.inputs import ExecutionInput from stepfunctions.steps import Chain -from stepfunctions.steps.fields import Field from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep, TuningStep, ProcessingStep from stepfunctions.workflow import Workflow @@ -384,27 +383,41 @@ def test_processing_step_with_placeholders(sklearn_processor_fixture, sagemaker_ # Build workflow definition execution_input = ExecutionInput(schema={ - Field.ImageUri.value: str, - Field.InstanceCount.value: int, - Field.Entrypoint.value: str, - Field.Role.value: str, - Field.VolumeSizeInGB.value: int, - Field.MaxRuntimeInSeconds.value: int + 'image_uri': str, + 'instance_count': int, + 'entrypoint': str, + 'role': str, + 'volume_size_in_gb': int, + 'max_runtime_in_seconds': int, + 'container_arguments': [str], }) + parameters = { + 'AppSpecification': { + 'ContainerEntrypoint': execution_input['entrypoint'], + 'ImageUri': execution_input['image_uri'] + }, + 'ProcessingResources': { + 'ClusterConfig': { + 'InstanceCount': execution_input['instance_count'], + 'VolumeSizeInGB': execution_input['volume_size_in_gb'] + } + }, + 'RoleArn': execution_input['role'], + 'StoppingCondition': { + 'MaxRuntimeInSeconds': execution_input['max_runtime_in_seconds'] + } + } + job_name = generate_job_name() processing_step = ProcessingStep('create_processing_job_step', processor=sklearn_processor_fixture, job_name=job_name, inputs=inputs, outputs=outputs, - container_arguments=['--train-test-split-ratio', '0.2'], - container_entrypoint=execution_input[Field.Entrypoint.value], - image_uri=execution_input[Field.ImageUri.value], - instance_count=execution_input[Field.InstanceCount.value], - role=execution_input[Field.Role.value], - volume_size_in_gb=execution_input[Field.VolumeSizeInGB.value], - max_runtime_in_seconds=execution_input[Field.MaxRuntimeInSeconds.value] + container_arguments=execution_input['container_arguments'], + container_entrypoint=execution_input['entrypoint'], + parameters=parameters ) workflow_graph = Chain([processing_step]) @@ -418,12 +431,13 @@ def test_processing_step_with_placeholders(sklearn_processor_fixture, sagemaker_ ) execution_input = { - Field.ImageUri.value: '683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:0.20.0-cpu-py3', - Field.InstanceCount.value: 1, - Field.Entrypoint.value: ['python3', '/opt/ml/processing/input/code/preprocessor.py'], - Field.Role.value: sagemaker_role_arn, - Field.VolumeSizeInGB.value: 30, - Field.MaxRuntimeInSeconds.value: 500 + 'image_uri': '683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:0.20.0-cpu-py3', + 'instance_count': 1, + 'entrypoint': ['python3', '/opt/ml/processing/input/code/preprocessor.py'], + 'role': sagemaker_role_arn, + 'volume_size_in_gb': 30, + 
'max_runtime_in_seconds': 500, + 'container_arguments': ['--train-test-split-ratio', '0.2'] } # Execute workflow execution = workflow.execute(inputs=execution_input) execution_output = execution.get_output(wait=True) # Check workflow output assert execution_output.get("ProcessingJobStatus") == "Completed" # Cleanup state_machine_delete_wait(sfn_client, workflow.state_machine_arn) # End of Cleanup diff --git a/tests/unit/test_sagemaker_steps.py b/tests/unit/test_sagemaker_steps.py index 33da0ca..664a498 100644 --- a/tests/unit/test_sagemaker_steps.py +++ b/tests/unit/test_sagemaker_steps.py @@ -968,22 +968,47 @@ def test_processing_step_creation(sklearn_processor): def test_processing_step_creation_with_placeholders(sklearn_processor): execution_input = ExecutionInput(schema={ - Field.ImageUri.value: str, - Field.InstanceCount.value: int, - Field.Entrypoint.value: str, - Field.OutputKMSKey.value: str, - Field.Role.value: str, - Field.Env.value: str, - Field.VolumeSizeInGB.value: int, - Field.VolumeKMSKey.value: str, - Field.MaxRuntimeInSeconds.value: int, - Field.Tags.value: [{str: str}] + 'image_uri': str, + 'instance_count': int, + 'entrypoint': str, + 'output_kms_key': str, + 'role': str, + 'env': str, + 'volume_size_in_gb': int, + 'volume_kms_key': str, + 'max_runtime_in_seconds': int, + 'tags': [{str: str}], + 'container_arguments': [str] }) step_input = StepInput(schema={ - Field.InstanceType.value: str + 'instance_type': str }) + parameters = { + 'AppSpecification': { + 'ContainerEntrypoint': execution_input['entrypoint'], + 'ImageUri': execution_input['image_uri'] + }, + 'Environment': execution_input['env'], + 'ProcessingOutputConfig': { + 'KmsKeyId': execution_input['output_kms_key'] + }, + 'ProcessingResources': { + 'ClusterConfig': { + 'InstanceCount': execution_input['instance_count'], + 'InstanceType': step_input['instance_type'], + 'VolumeKmsKeyId': execution_input['volume_kms_key'], + 'VolumeSizeInGB': execution_input['volume_size_in_gb'] + } + }, + 'RoleArn': execution_input['role'], + 'StoppingCondition': { + 'MaxRuntimeInSeconds': execution_input['max_runtime_in_seconds'] + }, + 'Tags': execution_input['tags'] + } + inputs = [ProcessingInput(source='dataset.csv', destination='/opt/ml/processing/input')] outputs = [ ProcessingOutput(source='/opt/ml/processing/output/train'), ProcessingOutput(source='/opt/ml/processing/output/validation'), ProcessingOutput(source='/opt/ml/processing/output/test') ] step = ProcessingStep( 'Feature Transformation', sklearn_processor, 'MyProcessingJob', - container_entrypoint=execution_input[Field.Entrypoint.value], - kms_key_id=execution_input[Field.OutputKMSKey.value], + container_entrypoint=execution_input['entrypoint'], + container_arguments=execution_input['container_arguments'], + kms_key_id=execution_input['output_kms_key'], inputs=inputs, outputs=outputs, - image_uri=execution_input[Field.ImageUri.value], - instance_count=execution_input[Field.InstanceCount.value], - instance_type=step_input[Field.InstanceType.value], - role=execution_input[Field.Role.value], - env=execution_input[Field.Env.value], - volume_size_in_gb=execution_input[Field.VolumeSizeInGB.value], - volume_kms_key=execution_input[Field.VolumeKMSKey.value], - max_runtime_in_seconds=execution_input[Field.MaxRuntimeInSeconds.value], - tags=execution_input[Field.Tags.value], + parameters=parameters ) assert step.to_dict() == { 'Type': 'Task', 'Parameters': { 'AppSpecification': { + 'ContainerArguments.$': "$$.Execution.Input['container_arguments']", 'ContainerEntrypoint.$': "$$.Execution.Input['entrypoint']", 'ImageUri.$': "$$.Execution.Input['image_uri']" },
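Taken together, these patches let a caller override any field of the SageMaker-generated CreateProcessingJob request by passing a placeholder-compatible `parameters` dict, which `merge_dicts` folds into the request built by `processing_config`. A minimal usage sketch, assuming a preconfigured `sklearn_processor` and a role supplied at execution time (the names below are illustrative, adapted from the tests above):

    from stepfunctions.inputs import ExecutionInput
    from stepfunctions.steps import Chain
    from stepfunctions.steps.sagemaker import ProcessingStep

    # Values supplied when the state machine is executed, not when it is defined
    execution_input = ExecutionInput(schema={'role': str, 'instance_count': int})

    # Placeholder-compatible overrides; merge_dicts() overwrites the matching
    # fields of the request generated from the processor's static configuration
    parameters = {
        'RoleArn': execution_input['role'],
        'ProcessingResources': {
            'ClusterConfig': {'InstanceCount': execution_input['instance_count']}
        }
    }

    step = ProcessingStep(
        'processing-with-placeholders',
        processor=sklearn_processor,  # assumed: a sagemaker.processing Processor built elsewhere
        job_name='MyProcessingJob',
        parameters=parameters
    )
    workflow_graph = Chain([step])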
+# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. +from enum import Enum +from stepfunctions.steps.fields import Field + +# Path to SageMaker placeholder parameters +placeholder_paths = { + # Paths taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateProcessingJob.html + 'ProcessingStep': { + Field.Role.value: ['RoleArn'], + Field.ImageUri.value: ['AppSpecification', 'ImageUri'], + Field.InstanceCount.value: ['ProcessingResources', 'ClusterConfig', 'InstanceCount'], + Field.InstanceType.value: ['ProcessingResources', 'ClusterConfig', 'InstanceType'], + Field.Entrypoint.value: ['AppSpecification', 'ContainerEntrypoint'], + Field.VolumeSizeInGB.value: ['ProcessingResources', 'ClusterConfig', 'VolumeSizeInGB'], + Field.VolumeKMSKey.value: ['ProcessingResources', 'ClusterConfig', 'VolumeKmsKeyId'], + Field.Env.value: ['Environment'], + Field.Tags.value: ['Tags'], + } +} diff --git a/src/stepfunctions/steps/fields.py b/src/stepfunctions/steps/fields.py index 24c3949..fab4aa6 100644 --- a/src/stepfunctions/steps/fields.py +++ b/src/stepfunctions/steps/fields.py @@ -59,10 +59,23 @@ class Field(Enum): HeartbeatSeconds = 'heartbeat_seconds' HeartbeatSecondsPath = 'heartbeat_seconds_path' - # Retry and catch fields ErrorEquals = 'error_equals' IntervalSeconds = 'interval_seconds' MaxAttempts = 'max_attempts' BackoffRate = 'backoff_rate' NextStep = 'next_step' + + # Sagemaker step fields + # Processing Step: Processor + Role = 'role' + ImageUri = 'image_uri' + InstanceCount = 'instance_count' + InstanceType = 'instance_type' + Entrypoint = 'entrypoint' + VolumeSizeInGB = 'volume_size_in_gb' + VolumeKMSKey = 'volume_kms_key' + OutputKMSKey = 'output_kms_key' + MaxRuntimeInSeconds = 'max_runtime_in_seconds' + Env = 'env' + Tags = 'tags' \ No newline at end of file diff --git a/src/stepfunctions/steps/sagemaker.py b/src/stepfunctions/steps/sagemaker.py index 30e3d7c..d2f3740 100644 --- a/src/stepfunctions/steps/sagemaker.py +++ b/src/stepfunctions/steps/sagemaker.py @@ -13,10 +13,14 @@ from __future__ import absolute_import import logging +import operator from enum import Enum +from functools import reduce +from stepfunctions.exceptions import InvalidPathToPlaceholderParameter from stepfunctions.inputs import Placeholder +from stepfunctions.steps.constants import placeholder_paths from stepfunctions.steps.states import Task from stepfunctions.steps.fields import Field from stepfunctions.steps.utils import tags_dict_to_kv_list @@ -25,6 +29,7 @@ from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config, processing_config from sagemaker.model import Model, FrameworkModel from sagemaker.model_monitor import DataCaptureConfig +from sagemaker.processing import ProcessingJob logger = logging.getLogger('stepfunctions.sagemaker') @@ -41,6 +46,104 @@ class SageMakerApi(Enum): CreateProcessingJob = "createProcessingJob" +class SageMakerTask(Task): + + """ + Task State causes the interpreter to execute the work identified by the state’s `resource` field. + """ + + def __init__(self, state_id, step_type, tags, **kwargs): + """ + Args: + state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. 
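The `placeholder_paths` mapping introduced above records, for each placeholder-compatible kwarg, the nested location of that field inside the `CreateProcessingJob` request. A minimal lookup sketch, assuming the `constants` module and the `Field` enum values exactly as added by this commit (the module is reworked later in the series):

```python
# Sketch only: where does the 'instance_count' kwarg land in the request payload?
# Assumes stepfunctions.steps.constants as added by this commit.
from stepfunctions.steps.constants import placeholder_paths
from stepfunctions.steps.fields import Field

path = placeholder_paths['ProcessingStep'][Field.InstanceCount.value]
print(path)  # ['ProcessingResources', 'ClusterConfig', 'InstanceCount']
```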
State names **must be** unique within the scope of the whole state machine. + resource (str): A URI that uniquely identifies the specific task to execute. The States language does not constrain the URI scheme nor any other part of the URI. + timeout_seconds (int, optional): Positive integer specifying timeout for the state in seconds. If the state runs longer than the specified timeout, then the interpreter fails the state with a `States.Timeout` Error Name. (default: 60) + timeout_seconds_path (str, optional): Path specifying the state's timeout value in seconds from the state input. When resolved, the path must select a field whose value is a positive integer. + heartbeat_seconds (int, optional): Positive integer specifying heartbeat timeout for the state in seconds. This value should be lower than the one specified for `timeout_seconds`. If more time than the specified heartbeat elapses between heartbeats from the task, then the interpreter fails the state with a `States.Timeout` Error Name. + heartbeat_seconds_path (str, optional): Path specifying the state's heartbeat value in seconds from the state input. When resolved, the path must select a field whose value is a positive integer. + comment (str, optional): Human-readable comment or description. (default: None) + input_path (str, optional): Path applied to the state’s raw input to select some or all of it; that selection is used by the state. (default: '$') + parameters (dict, optional): The value of this field becomes the effective input for the state. + result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') + output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. 
(default: '$') + """ + self._replace_sagemaker_placeholders(step_type, kwargs) + if tags: + self.set_tags_config(tags, kwargs[Field.Parameters.value], step_type) + + super(SageMakerTask, self).__init__(state_id, **kwargs) + + + def allowed_fields(self): + sagemaker_fields = [ + # ProcessingStep: Processor + Field.Role, + Field.ImageUri, + Field.InstanceCount, + Field.InstanceType, + Field.Entrypoint, + Field.VolumeSizeInGB, + Field.VolumeKMSKey, + Field.OutputKMSKey, + Field.MaxRuntimeInSeconds, + Field.Env, + Field.Tags, + ] + + return super(SageMakerTask, self).allowed_fields() + sagemaker_fields + + + def _replace_sagemaker_placeholders(self, step_type, args): + # Fetch path from type + sagemaker_parameters = args[Field.Parameters.value] + paths = placeholder_paths.get(step_type) + treated_args = [] + + for arg_name, value in args.items(): + if arg_name in [Field.Parameters.value]: + continue + if arg_name in paths.keys(): + path = paths.get(arg_name) + if self._set_placeholder(sagemaker_parameters, path, value, arg_name): + treated_args.append(arg_name) + + SageMakerTask.remove_treated_args(treated_args, args) + + @staticmethod + def get_value_from_path(parameters, path): + value_from_path = reduce(operator.getitem, path, parameters) + return value_from_path + # return reduce(operator.getitem, path, parameters) + + @staticmethod + def _set_placeholder(parameters, path, value, arg_name): + is_set = False + try: + SageMakerTask.get_value_from_path(parameters, path[:-1])[path[-1]] = value + is_set = True + except KeyError as e: + message = f"Invalid path {path} for {arg_name}: {e}" + raise InvalidPathToPlaceholderParameter(message) + return is_set + + @staticmethod + def remove_treated_args(treated_args, args): + for treated_arg in treated_args: + try: + del args[treated_arg] + except KeyError as e: + pass + + def set_tags_config(self, tags, parameters, step_type): + if isinstance(tags, Placeholder): + # Replace with placeholder + path = placeholder_paths.get(step_type).get(Field.Tags.value) + if path: + self._set_placeholder(parameters, path, tags, Field.Tags.value) + else: + parameters['Tags'] = tags_dict_to_kv_list(tags) + + class TrainingStep(Task): """ @@ -473,13 +576,15 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta super(TuningStep, self).__init__(state_id, **kwargs) -class ProcessingStep(Task): +class ProcessingStep(SageMakerTask): """ Creates a Task State to execute a SageMaker Processing Job. """ - def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, experiment_config=None, container_arguments=None, container_entrypoint=None, kms_key_id=None, wait_for_completion=True, tags=None, **kwargs): + def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, experiment_config=None, + container_arguments=None, container_entrypoint=None, kms_key_id=None, wait_for_completion=True, + tags=None, max_runtime_in_seconds=None, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. @@ -499,7 +604,8 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp ARN of a KMS key, alias of a KMS key, or alias of a KMS key. The KmsKeyId is applied to all outputs. wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the processing job to complete before proceeding to the next step in the workflow. 
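`_set_placeholder` above descends the generated request with `functools.reduce` and assigns the placeholder value at the leaf key; a `KeyError` on the way down is surfaced as `InvalidPathToPlaceholderParameter`. A self-contained sketch of the same idiom, with illustrative names and a dummy ARN (not taken from the SDK):

```python
import operator
from functools import reduce

# Illustrative request fragment; the real dict comes from processing_config().
request = {
    'ProcessingResources': {'ClusterConfig': {'InstanceCount': 1}},
    'RoleArn': 'arn:aws:iam::111122223333:role/Example',  # dummy value
}
path = ['ProcessingResources', 'ClusterConfig', 'InstanceCount']

# Walk every key except the last to reach the parent dict, then assign the
# leaf; a KeyError here means the configured path is absent from the request.
parent = reduce(operator.getitem, path[:-1], request)
parent[path[-1]] = "$$.Execution.Input['instance_count']"  # rendered placeholder

assert (request['ProcessingResources']['ClusterConfig']['InstanceCount']
        == "$$.Execution.Input['instance_count']")
```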
Set to `False` if the Task state should submit the processing job and proceed to the next step. (default: True) - tags (list[dict], optional): `List to tags `_ to associate with the resource. + tags (list[dict] or Placeholder, optional): `List to tags `_ to associate with the resource. + max_runtime_in_seconds (int or Placeholder): Specifies the maximum runtime in seconds for the processing job """ if wait_for_completion: """ @@ -528,12 +634,12 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp if experiment_config is not None: parameters['ExperimentConfig'] = experiment_config - if tags: - parameters['Tags'] = tags_dict_to_kv_list(tags) - if 'S3Operations' in parameters: del parameters['S3Operations'] + + if max_runtime_in_seconds: + parameters['StoppingCondition'] = ProcessingJob.prepare_stopping_condition(max_runtime_in_seconds) kwargs[Field.Parameters.value] = parameters - super(ProcessingStep, self).__init__(state_id, **kwargs) + super(ProcessingStep, self).__init__(state_id, __class__.__name__, tags, **kwargs) diff --git a/tests/integ/test_sagemaker_steps.py b/tests/integ/test_sagemaker_steps.py index 63c060a..6400529 100644 --- a/tests/integ/test_sagemaker_steps.py +++ b/tests/integ/test_sagemaker_steps.py @@ -29,7 +29,9 @@ from sagemaker.tuner import HyperparameterTuner from sagemaker.processing import ProcessingInput, ProcessingOutput +from stepfunctions.inputs import ExecutionInput from stepfunctions.steps import Chain +from stepfunctions.steps.fields import Field from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep, TuningStep, ProcessingStep from stepfunctions.workflow import Workflow @@ -352,3 +354,85 @@ def test_processing_step(sklearn_processor_fixture, sagemaker_session, sfn_clien # Cleanup state_machine_delete_wait(sfn_client, workflow.state_machine_arn) # End of Cleanup + + +def test_processing_step_with_placeholders(sklearn_processor_fixture, sagemaker_session, sfn_client, sfn_role_arn, + sagemaker_role_arn): + region = boto3.session.Session().region_name + input_data = 's3://sagemaker-sample-data-{}/processing/census/census-income.csv'.format(region) + + input_s3 = sagemaker_session.upload_data( + path=os.path.join(DATA_DIR, 'sklearn_processing'), + bucket=sagemaker_session.default_bucket(), + key_prefix='integ-test-data/sklearn_processing/code' + ) + + output_s3 = 's3://' + sagemaker_session.default_bucket() + '/integ-test-data/sklearn_processing' + + inputs = [ + ProcessingInput(source=input_data, destination='/opt/ml/processing/input', input_name='input-1'), + ProcessingInput(source=input_s3 + '/preprocessor.py', destination='/opt/ml/processing/input/code', + input_name='code'), + ] + + outputs = [ + ProcessingOutput(source='/opt/ml/processing/train', destination=output_s3 + '/train_data', + output_name='train_data'), + ProcessingOutput(source='/opt/ml/processing/test', destination=output_s3 + '/test_data', + output_name='test_data'), + ] + + # Build workflow definition + execution_input = ExecutionInput(schema={ + Field.ImageUri.value: str, + Field.InstanceCount.value: int, + Field.Entrypoint.value: str, + Field.Role.value: str, + Field.VolumeSizeInGB.value: int, + Field.MaxRuntimeInSeconds.value: int + }) + + job_name = generate_job_name() + processing_step = ProcessingStep('create_processing_job_step', + processor=sklearn_processor_fixture, + job_name=job_name, + inputs=inputs, + outputs=outputs, + container_arguments=['--train-test-split-ratio', '0.2'], + 
container_entrypoint=execution_input[Field.Entrypoint.value], + image_uri=execution_input[Field.ImageUri.value], + instance_count=execution_input[Field.InstanceCount.value], + role=execution_input[Field.Role.value], + volume_size_in_gb=execution_input[Field.VolumeSizeInGB.value], + max_runtime_in_seconds=execution_input[Field.MaxRuntimeInSeconds.value] + ) + workflow_graph = Chain([processing_step]) + + with timeout(minutes=DEFAULT_TIMEOUT_MINUTES): + # Create workflow and check definition + workflow = create_workflow_and_check_definition( + workflow_graph=workflow_graph, + workflow_name=unique_name_from_base("integ-test-processing-step-workflow"), + sfn_client=sfn_client, + sfn_role_arn=sfn_role_arn + ) + + execution_input = { + Field.ImageUri.value: '683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:0.20.0-cpu-py3', + Field.InstanceCount.value: 1, + Field.Entrypoint.value: ['python3', '/opt/ml/processing/input/code/preprocessor.py'], + Field.Role.value: sagemaker_role_arn, + Field.VolumeSizeInGB.value: 30, + Field.MaxRuntimeInSeconds.value: 500 + } + + # Execute workflow + execution = workflow.execute(inputs=execution_input) + execution_output = execution.get_output(wait=True) + + # Check workflow output + assert execution_output.get("ProcessingJobStatus") == "Completed" + + # Cleanup + state_machine_delete_wait(sfn_client, workflow.state_machine_arn) + # End of Cleanup diff --git a/tests/unit/test_sagemaker_steps.py b/tests/unit/test_sagemaker_steps.py index c643468..33da0ca 100644 --- a/tests/unit/test_sagemaker_steps.py +++ b/tests/unit/test_sagemaker_steps.py @@ -27,7 +27,9 @@ from unittest.mock import MagicMock, patch from stepfunctions.inputs import ExecutionInput, StepInput -from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep, ProcessingStep +from stepfunctions.steps.fields import Field +from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep,\ + ProcessingStep from stepfunctions.steps.sagemaker import tuning_config from tests.unit.utils import mock_boto_api_call @@ -962,3 +964,117 @@ def test_processing_step_creation(sklearn_processor): 'Resource': 'arn:aws:states:::sagemaker:createProcessingJob.sync', 'End': True } + + +def test_processing_step_creation_with_placeholders(sklearn_processor): + execution_input = ExecutionInput(schema={ + Field.ImageUri.value: str, + Field.InstanceCount.value: int, + Field.Entrypoint.value: str, + Field.OutputKMSKey.value: str, + Field.Role.value: str, + Field.Env.value: str, + Field.VolumeSizeInGB.value: int, + Field.VolumeKMSKey.value: str, + Field.MaxRuntimeInSeconds.value: int, + Field.Tags.value: [{str: str}] + }) + + step_input = StepInput(schema={ + Field.InstanceType.value: str + }) + + inputs = [ProcessingInput(source='dataset.csv', destination='/opt/ml/processing/input')] + outputs = [ + ProcessingOutput(source='/opt/ml/processing/output/train'), + ProcessingOutput(source='/opt/ml/processing/output/validation'), + ProcessingOutput(source='/opt/ml/processing/output/test') + ] + step = ProcessingStep( + 'Feature Transformation', + sklearn_processor, + 'MyProcessingJob', + container_entrypoint=execution_input[Field.Entrypoint.value], + kms_key_id=execution_input[Field.OutputKMSKey.value], + inputs=inputs, + outputs=outputs, + image_uri=execution_input[Field.ImageUri.value], + instance_count=execution_input[Field.InstanceCount.value], + instance_type=step_input[Field.InstanceType.value], + 
role=execution_input[Field.Role.value], + env=execution_input[Field.Env.value], + volume_size_in_gb=execution_input[Field.VolumeSizeInGB.value], + volume_kms_key=execution_input[Field.VolumeKMSKey.value], + max_runtime_in_seconds=execution_input[Field.MaxRuntimeInSeconds.value], + tags=execution_input[Field.Tags.value], + ) + assert step.to_dict() == { + 'Type': 'Task', + 'Parameters': { + 'AppSpecification': { + 'ContainerEntrypoint.$': "$$.Execution.Input['entrypoint']", + 'ImageUri.$': "$$.Execution.Input['image_uri']" + }, + 'Environment.$': "$$.Execution.Input['env']", + 'ProcessingInputs': [ + { + 'InputName': None, + 'AppManaged': False, + 'S3Input': { + 'LocalPath': '/opt/ml/processing/input', + 'S3CompressionType': 'None', + 'S3DataDistributionType': 'FullyReplicated', + 'S3DataType': 'S3Prefix', + 'S3InputMode': 'File', + 'S3Uri': 'dataset.csv' + } + } + ], + 'ProcessingOutputConfig': { + 'KmsKeyId.$': "$$.Execution.Input['output_kms_key']", + 'Outputs': [ + { + 'OutputName': None, + 'AppManaged': False, + 'S3Output': { + 'LocalPath': '/opt/ml/processing/output/train', + 'S3UploadMode': 'EndOfJob', + 'S3Uri': None + } + }, + { + 'OutputName': None, + 'AppManaged': False, + 'S3Output': { + 'LocalPath': '/opt/ml/processing/output/validation', + 'S3UploadMode': 'EndOfJob', + 'S3Uri': None + } + }, + { + 'OutputName': None, + 'AppManaged': False, + 'S3Output': { + 'LocalPath': '/opt/ml/processing/output/test', + 'S3UploadMode': 'EndOfJob', + 'S3Uri': None + } + } + ] + }, + 'ProcessingResources': { + 'ClusterConfig': { + 'InstanceCount.$': "$$.Execution.Input['instance_count']", + 'InstanceType.$': "$['instance_type']", + 'VolumeKmsKeyId.$': "$$.Execution.Input['volume_kms_key']", + 'VolumeSizeInGB.$': "$$.Execution.Input['volume_size_in_gb']" + } + }, + 'ProcessingJobName': 'MyProcessingJob', + 'RoleArn.$': "$$.Execution.Input['role']", + 'Tags.$': "$$.Execution.Input['tags']", + 'StoppingCondition': {'MaxRuntimeInSeconds.$': "$$.Execution.Input['max_runtime_in_seconds']"}, + }, + 'Resource': 'arn:aws:states:::sagemaker:createProcessingJob.sync', + 'End': True + } From 4c63229c0798d43a51d0016cebb51dbb3ae9eed5 Mon Sep 17 00:00:00 2001 From: Carolyn Nguyen Date: Thu, 12 Aug 2021 12:30:04 -0700 Subject: [PATCH 12/20] Added doc --- src/stepfunctions/steps/sagemaker.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/stepfunctions/steps/sagemaker.py b/src/stepfunctions/steps/sagemaker.py index d2f3740..e05742c 100644 --- a/src/stepfunctions/steps/sagemaker.py +++ b/src/stepfunctions/steps/sagemaker.py @@ -580,6 +580,10 @@ class ProcessingStep(SageMakerTask): """ Creates a Task State to execute a SageMaker Processing Job. 
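The expected state definition above relies on the Amazon States Language convention for dynamic values: a statically supplied parameter keeps its plain key, while a placeholder renders as the key suffixed with `.$` plus a JSONPath string resolved at run time. A small sketch of the three shapes the assertion checks for (values illustrative):

```python
# Static value: plain key, literal value.
static_form = {'ProcessingJobName': 'MyProcessingJob'}

# Placeholder from ExecutionInput: resolved from the workflow execution's input.
execution_form = {'RoleArn.$': "$$.Execution.Input['role']"}

# Placeholder from StepInput: resolved from this state's own input.
step_form = {'InstanceType.$': "$['instance_type']"}
```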
+ + The following properties can be passed down as kwargs to the sagemaker.processing.Processor to be used dynamically + in the processing job (compatible with Placeholders): role, image_uri, instance_count, instance_type, + volume_size_in_gb, volume_kms_key, output_kms_key """ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, experiment_config=None, From 34bb281af81d9a7ed205ca73c278bb761b588161 Mon Sep 17 00:00:00 2001 From: Carolyn Nguyen Date: Thu, 12 Aug 2021 12:45:29 -0700 Subject: [PATCH 13/20] Removed contributing changes (included in another PR) --- CONTRIBUTING.md | 36 ------------------------------------ 1 file changed, 36 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 6a0f342..99a078b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -14,15 +14,11 @@ information to effectively respond to your bug report or contribution. * [Contributing via Pull Requests (PRs)](#contributing-via-pull-requests-prs) * [Pulling Down the Code](#pulling-down-the-code) * [Running the Unit Tests](#running-the-unit-tests) - * [Running Unit Tests and Debugging in PyCharm](#running-unit-tests-and-debugging-in-pycharm) * [Running the Integration Tests](#running-the-integration-tests) * [Making and Testing Your Change](#making-and-testing-your-change) * [Committing Your Change](#committing-your-change) * [Sending a Pull Request](#sending-a-pull-request) * [Finding Contributions to Work On](#finding-contributions-to-work-on) -* [Setting Up Your Development Environment](#setting-up-your-development-environment) - * [Setting Up Your Environment for Debugging](#setting-up-your-environment-for-debugging) - * [PyCharm](#pycharm) * [Code of Conduct](#code-of-conduct) * [Security Issue Notifications](#security-issue-notifications) * [Licensing](#licensing) @@ -69,11 +65,6 @@ You can also run a single test with the following command: `tox -e py36 -- -s -v * Note that the coverage test will fail if you only run a single test, so make sure to surround the command with `export IGNORE_COVERAGE=-` and `unset IGNORE_COVERAGE` * Example: `export IGNORE_COVERAGE=- ; tox -e py36 -- -s -vv tests/unit/test_sagemaker_steps.py::test_training_step_creation_with_model ; unset IGNORE_COVERAGE` -#### Running Unit Tests and Debugging in PyCharm -You can also run the unit tests with the following options: -* Right click on a test file in the Project tree and select `Run/Debug 'pytest' for ...` -* Right click on the test definition and select `Run/Debug 'pytest' for ...` -* Click on the green arrow next to test definition ### Running the Integration Tests @@ -177,33 +168,6 @@ Please remember to: Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels ((enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any ['help wanted'](https://github.com/aws/aws-step-functions-data-science-sdk-python/labels/help%20wanted) issues is a great place to start. -## Setting Up Your Development Environment - -### Setting Up Your Environment for Debugging - -Setting up your IDE for debugging tests locally will save you a lot of time. -You might be able to `Run` and `Debug` the tests directly in your IDE with your default settings, but if it's not the case, -follow the steps described in this section. - -#### PyCharm -1. Set your Default test runner to `pytest` in _Preferences → Tools → Python Integrated Tools_ -1.

If you are using `PyCharm Professional Edition`, go to _Preferences → Build, Execution, Deployment → Python Debugger_ and set the options with following values: - - | Option | Value | - |:------------------------------------------------------------ |:----------------------| - | Attach subprocess automatically while debugging | `Enabled` | - | Collect run-time types information for code insight | `Enabled` | - | Gevent compatible | `Disabled` | - | Drop into debugger on failed tests | `Enabled` | - | PyQt compatible | `Auto` | - | For Attach to Process show processes with names containing | `python` | - - This will allow you to break into all subprocesses of the process being debugged and preserve functions types while debugging. -1. Debug tests in PyCharm as per [Running Unit Tests and Debugging in PyCharm](#running-unit-tests-and-debugging-in-pycharm) - -_Note: This setup was tested and confirmed to work with -`PyCharm 2020.3.5 (Professional Edition)` and `PyCharm 2021.1.1 (Professional Edition)`_ - ## Code of Conduct This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). From a098c61e52f12aa92466e06466aa7dee4c6cf154 Mon Sep 17 00:00:00 2001 From: Carolyn Nguyen Date: Mon, 16 Aug 2021 17:55:47 -0700 Subject: [PATCH 14/20] Merge sagemaker generated parameters with placeholder compatible parameters received in args --- src/stepfunctions/steps/constants.py | 30 ------ src/stepfunctions/steps/fields.py | 14 --- src/stepfunctions/steps/sagemaker.py | 148 +++++---------------------- src/stepfunctions/steps/utils.py | 22 ++++ tests/integ/test_sagemaker_steps.py | 54 ++++++---- tests/unit/test_sagemaker_steps.py | 63 ++++++++---- 6 files changed, 122 insertions(+), 209 deletions(-) delete mode 100644 src/stepfunctions/steps/constants.py diff --git a/src/stepfunctions/steps/constants.py b/src/stepfunctions/steps/constants.py deleted file mode 100644 index 1b308c8..0000000 --- a/src/stepfunctions/steps/constants.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the "license" file accompanying this file. This file is distributed -# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing -# permissions and limitations under the License. 
-from enum import Enum -from stepfunctions.steps.fields import Field - -# Path to SageMaker placeholder parameters -placeholder_paths = { - # Paths taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateProcessingJob.html - 'ProcessingStep': { - Field.Role.value: ['RoleArn'], - Field.ImageUri.value: ['AppSpecification', 'ImageUri'], - Field.InstanceCount.value: ['ProcessingResources', 'ClusterConfig', 'InstanceCount'], - Field.InstanceType.value: ['ProcessingResources', 'ClusterConfig', 'InstanceType'], - Field.Entrypoint.value: ['AppSpecification', 'ContainerEntrypoint'], - Field.VolumeSizeInGB.value: ['ProcessingResources', 'ClusterConfig', 'VolumeSizeInGB'], - Field.VolumeKMSKey.value: ['ProcessingResources', 'ClusterConfig', 'VolumeKmsKeyId'], - Field.Env.value: ['Environment'], - Field.Tags.value: ['Tags'], - } -} diff --git a/src/stepfunctions/steps/fields.py b/src/stepfunctions/steps/fields.py index fab4aa6..8eb102d 100644 --- a/src/stepfunctions/steps/fields.py +++ b/src/stepfunctions/steps/fields.py @@ -65,17 +65,3 @@ class Field(Enum): MaxAttempts = 'max_attempts' BackoffRate = 'backoff_rate' NextStep = 'next_step' - - # Sagemaker step fields - # Processing Step: Processor - Role = 'role' - ImageUri = 'image_uri' - InstanceCount = 'instance_count' - InstanceType = 'instance_type' - Entrypoint = 'entrypoint' - VolumeSizeInGB = 'volume_size_in_gb' - VolumeKMSKey = 'volume_kms_key' - OutputKMSKey = 'output_kms_key' - MaxRuntimeInSeconds = 'max_runtime_in_seconds' - Env = 'env' - Tags = 'tags' \ No newline at end of file diff --git a/src/stepfunctions/steps/sagemaker.py b/src/stepfunctions/steps/sagemaker.py index e05742c..03ee401 100644 --- a/src/stepfunctions/steps/sagemaker.py +++ b/src/stepfunctions/steps/sagemaker.py @@ -13,28 +13,24 @@ from __future__ import absolute_import import logging -import operator from enum import Enum -from functools import reduce -from stepfunctions.exceptions import InvalidPathToPlaceholderParameter from stepfunctions.inputs import Placeholder -from stepfunctions.steps.constants import placeholder_paths from stepfunctions.steps.states import Task from stepfunctions.steps.fields import Field -from stepfunctions.steps.utils import tags_dict_to_kv_list +from stepfunctions.steps.utils import merge_dicts, tags_dict_to_kv_list from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config, processing_config from sagemaker.model import Model, FrameworkModel from sagemaker.model_monitor import DataCaptureConfig -from sagemaker.processing import ProcessingJob logger = logging.getLogger('stepfunctions.sagemaker') SAGEMAKER_SERVICE_NAME = "sagemaker" + class SageMakerApi(Enum): CreateTrainingJob = "createTrainingJob" CreateTransformJob = "createTransformJob" @@ -46,104 +42,6 @@ class SageMakerApi(Enum): CreateProcessingJob = "createProcessingJob" -class SageMakerTask(Task): - - """ - Task State causes the interpreter to execute the work identified by the state’s `resource` field. - """ - - def __init__(self, state_id, step_type, tags, **kwargs): - """ - Args: - state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. - resource (str): A URI that uniquely identifies the specific task to execute. The States language does not constrain the URI scheme nor any other part of the URI. 
- timeout_seconds (int, optional): Positive integer specifying timeout for the state in seconds. If the state runs longer than the specified timeout, then the interpreter fails the state with a `States.Timeout` Error Name. (default: 60) - timeout_seconds_path (str, optional): Path specifying the state's timeout value in seconds from the state input. When resolved, the path must select a field whose value is a positive integer. - heartbeat_seconds (int, optional): Positive integer specifying heartbeat timeout for the state in seconds. This value should be lower than the one specified for `timeout_seconds`. If more time than the specified heartbeat elapses between heartbeats from the task, then the interpreter fails the state with a `States.Timeout` Error Name. - heartbeat_seconds_path (str, optional): Path specifying the state's heartbeat value in seconds from the state input. When resolved, the path must select a field whose value is a positive integer. - comment (str, optional): Human-readable comment or description. (default: None) - input_path (str, optional): Path applied to the state’s raw input to select some or all of it; that selection is used by the state. (default: '$') - parameters (dict, optional): The value of this field becomes the effective input for the state. - result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') - output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') - """ - self._replace_sagemaker_placeholders(step_type, kwargs) - if tags: - self.set_tags_config(tags, kwargs[Field.Parameters.value], step_type) - - super(SageMakerTask, self).__init__(state_id, **kwargs) - - - def allowed_fields(self): - sagemaker_fields = [ - # ProcessingStep: Processor - Field.Role, - Field.ImageUri, - Field.InstanceCount, - Field.InstanceType, - Field.Entrypoint, - Field.VolumeSizeInGB, - Field.VolumeKMSKey, - Field.OutputKMSKey, - Field.MaxRuntimeInSeconds, - Field.Env, - Field.Tags, - ] - - return super(SageMakerTask, self).allowed_fields() + sagemaker_fields - - - def _replace_sagemaker_placeholders(self, step_type, args): - # Fetch path from type - sagemaker_parameters = args[Field.Parameters.value] - paths = placeholder_paths.get(step_type) - treated_args = [] - - for arg_name, value in args.items(): - if arg_name in [Field.Parameters.value]: - continue - if arg_name in paths.keys(): - path = paths.get(arg_name) - if self._set_placeholder(sagemaker_parameters, path, value, arg_name): - treated_args.append(arg_name) - - SageMakerTask.remove_treated_args(treated_args, args) - - @staticmethod - def get_value_from_path(parameters, path): - value_from_path = reduce(operator.getitem, path, parameters) - return value_from_path - # return reduce(operator.getitem, path, parameters) - - @staticmethod - def _set_placeholder(parameters, path, value, arg_name): - is_set = False - try: - SageMakerTask.get_value_from_path(parameters, path[:-1])[path[-1]] = value - is_set = True - except KeyError as e: - message = f"Invalid path {path} for {arg_name}: {e}" - raise InvalidPathToPlaceholderParameter(message) - return is_set - - @staticmethod - def remove_treated_args(treated_args, args): - for treated_arg in treated_args: - try: - del args[treated_arg] - except KeyError as e: - pass - - def set_tags_config(self, tags, parameters, step_type): - if isinstance(tags, Placeholder): - 
# Replace with placeholder - path = placeholder_paths.get(step_type).get(Field.Tags.value) - if path: - self._set_placeholder(parameters, path, tags, Field.Tags.value) - else: - parameters['Tags'] = tags_dict_to_kv_list(tags) - - class TrainingStep(Task): """ @@ -576,7 +474,7 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta super(TuningStep, self).__init__(state_id, **kwargs) -class ProcessingStep(SageMakerTask): +class ProcessingStep(Task): """ Creates a Task State to execute a SageMaker Processing Job. @@ -588,7 +486,7 @@ class ProcessingStep(SageMakerTask): def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, experiment_config=None, container_arguments=None, container_entrypoint=None, kms_key_id=None, wait_for_completion=True, - tags=None, max_runtime_in_seconds=None, **kwargs): + tags=None, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. @@ -600,16 +498,16 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for the processing job. These can be specified as either path strings or :class:`~sagemaker.processing.ProcessingOutput` objects (default: None). - experiment_config (dict, optional): Specify the experiment config for the processing. (Default: None) - container_arguments ([str]): The arguments for a container used to run a processing job. - container_entrypoint ([str]): The entrypoint for a container used to run a processing job. - kms_key_id (str): The AWS Key Management Service (AWS KMS) key that Amazon SageMaker + experiment_config (dict or Placeholder, optional): Specify the experiment config for the processing. (Default: None) + container_arguments ([str] or Placeholder): The arguments for a container used to run a processing job. + container_entrypoint ([str] or Placeholder): The entrypoint for a container used to run a processing job. + kms_key_id (str or Placeholder): The AWS Key Management Service (AWS KMS) key that Amazon SageMaker uses to encrypt the processing job output. KmsKeyId can be an ID of a KMS key, ARN of a KMS key, alias of a KMS key, or alias of a KMS key. The KmsKeyId is applied to all outputs. wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the processing job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the processing job and proceed to the next step. (default: True) tags (list[dict] or Placeholder, optional): `List to tags `_ to associate with the resource. - max_runtime_in_seconds (int or Placeholder): Specifies the maximum runtime in seconds for the processing job + parameters(dict, optional): The value of this field becomes the effective input for the state. 
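With this merge-based approach, the caller expresses overrides directly in `CreateProcessingJob` request syntax instead of per-field kwargs. A hedged sketch of the new call shape (assuming an existing, configured `sklearn_processor`, as in the tests that follow):

```python
from stepfunctions.inputs import ExecutionInput
from stepfunctions.steps.sagemaker import ProcessingStep

execution_input = ExecutionInput(schema={'role': str, 'instance_count': int})

# Overrides are written as CreateProcessingJob request fields; anything given
# here wins over the values derived from the processor object.
parameters = {
    'RoleArn': execution_input['role'],
    'ProcessingResources': {
        'ClusterConfig': {'InstanceCount': execution_input['instance_count']}
    }
}

step = ProcessingStep(
    'Feature Transformation',
    sklearn_processor,  # assumed: a configured sagemaker.processing.Processor
    'MyProcessingJob',
    parameters=parameters
)
```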
""" if wait_for_completion: """ @@ -628,22 +526,26 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp SageMakerApi.CreateProcessingJob) if isinstance(job_name, str): - parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name) + processing_parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name) else: - parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id) + processing_parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id) if isinstance(job_name, Placeholder): - parameters['ProcessingJobName'] = job_name + processing_parameters['ProcessingJobName'] = job_name if experiment_config is not None: - parameters['ExperimentConfig'] = experiment_config - - if 'S3Operations' in parameters: - del parameters['S3Operations'] + processing_parameters['ExperimentConfig'] = experiment_config - if max_runtime_in_seconds: - parameters['StoppingCondition'] = ProcessingJob.prepare_stopping_condition(max_runtime_in_seconds) - - kwargs[Field.Parameters.value] = parameters + if tags: + processing_parameters['Tags'] = tags if isinstance(tags, Placeholder) else tags_dict_to_kv_list(tags) + + if 'S3Operations' in processing_parameters: + del processing_parameters['S3Operations'] + + if Field.Parameters.value in kwargs and isinstance(kwargs[Field.Parameters.value], dict): + # Update processing_parameters with input parameters + merge_dicts(processing_parameters, kwargs[Field.Parameters.value], "Processing Parameters", + "Input Parameters") - super(ProcessingStep, self).__init__(state_id, __class__.__name__, tags, **kwargs) + kwargs[Field.Parameters.value] = processing_parameters + super(ProcessingStep, self).__init__(state_id, **kwargs) diff --git a/src/stepfunctions/steps/utils.py b/src/stepfunctions/steps/utils.py index 6f44481..b45d107 100644 --- a/src/stepfunctions/steps/utils.py +++ b/src/stepfunctions/steps/utils.py @@ -14,6 +14,7 @@ import boto3 import logging +from stepfunctions.inputs import Placeholder logger = logging.getLogger('stepfunctions') @@ -45,3 +46,24 @@ def get_aws_partition(): return cur_partition return cur_partition + + +def merge_dicts(first, second, first_name, second_name): + """ + Merges first and second dictionaries into the first one. + Values in the first dict are updated with the values of the second one. 
+ """ + if all(isinstance(d, dict) for d in [first, second]): + for key, value in second.items(): + if key in first: + if isinstance(first[key], dict) and isinstance(second[key], dict): + merge_dicts(first[key], second[key], first_name, second_name) + elif first[key] is value: + pass + else: + logger.info( + f"{first_name} property: <{key}> with value: <{first[key]}>" + f" will be overwritten with value provided in {second_name} : <{value}>") + first[key] = second[key] + else: + first[key] = second[key] diff --git a/tests/integ/test_sagemaker_steps.py b/tests/integ/test_sagemaker_steps.py index 6400529..77bca39 100644 --- a/tests/integ/test_sagemaker_steps.py +++ b/tests/integ/test_sagemaker_steps.py @@ -31,7 +31,6 @@ from stepfunctions.inputs import ExecutionInput from stepfunctions.steps import Chain -from stepfunctions.steps.fields import Field from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep, TuningStep, ProcessingStep from stepfunctions.workflow import Workflow @@ -384,27 +383,41 @@ def test_processing_step_with_placeholders(sklearn_processor_fixture, sagemaker_ # Build workflow definition execution_input = ExecutionInput(schema={ - Field.ImageUri.value: str, - Field.InstanceCount.value: int, - Field.Entrypoint.value: str, - Field.Role.value: str, - Field.VolumeSizeInGB.value: int, - Field.MaxRuntimeInSeconds.value: int + 'image_uri': str, + 'instance_count': int, + 'entrypoint': str, + 'role': str, + 'volume_size_in_gb': int, + 'max_runtime_in_seconds': int, + 'container_arguments': [str], }) + parameters = { + 'AppSpecification': { + 'ContainerEntrypoint': execution_input['entrypoint'], + 'ImageUri': execution_input['image_uri'] + }, + 'ProcessingResources': { + 'ClusterConfig': { + 'InstanceCount': execution_input['instance_count'], + 'VolumeSizeInGB': execution_input['volume_size_in_gb'] + } + }, + 'RoleArn': execution_input['role'], + 'StoppingCondition': { + 'MaxRuntimeInSeconds': execution_input['max_runtime_in_seconds'] + } + } + job_name = generate_job_name() processing_step = ProcessingStep('create_processing_job_step', processor=sklearn_processor_fixture, job_name=job_name, inputs=inputs, outputs=outputs, - container_arguments=['--train-test-split-ratio', '0.2'], - container_entrypoint=execution_input[Field.Entrypoint.value], - image_uri=execution_input[Field.ImageUri.value], - instance_count=execution_input[Field.InstanceCount.value], - role=execution_input[Field.Role.value], - volume_size_in_gb=execution_input[Field.VolumeSizeInGB.value], - max_runtime_in_seconds=execution_input[Field.MaxRuntimeInSeconds.value] + container_arguments=execution_input['container_arguments'], + container_entrypoint=execution_input['entrypoint'], + parameters=parameters ) workflow_graph = Chain([processing_step]) @@ -418,12 +431,13 @@ def test_processing_step_with_placeholders(sklearn_processor_fixture, sagemaker_ ) execution_input = { - Field.ImageUri.value: '683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:0.20.0-cpu-py3', - Field.InstanceCount.value: 1, - Field.Entrypoint.value: ['python3', '/opt/ml/processing/input/code/preprocessor.py'], - Field.Role.value: sagemaker_role_arn, - Field.VolumeSizeInGB.value: 30, - Field.MaxRuntimeInSeconds.value: 500 + 'image_uri': '683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:0.20.0-cpu-py3', + 'instance_count': 1, + 'entrypoint': ['python3', '/opt/ml/processing/input/code/preprocessor.py'], + 'role': sagemaker_role_arn, + 'volume_size_in_gb': 30, + 
'max_runtime_in_seconds': 500, + 'container_arguments': ['--train-test-split-ratio', '0.2'] } # Execute workflow diff --git a/tests/unit/test_sagemaker_steps.py b/tests/unit/test_sagemaker_steps.py index 33da0ca..664a498 100644 --- a/tests/unit/test_sagemaker_steps.py +++ b/tests/unit/test_sagemaker_steps.py @@ -968,22 +968,47 @@ def test_processing_step_creation(sklearn_processor): def test_processing_step_creation_with_placeholders(sklearn_processor): execution_input = ExecutionInput(schema={ - Field.ImageUri.value: str, - Field.InstanceCount.value: int, - Field.Entrypoint.value: str, - Field.OutputKMSKey.value: str, - Field.Role.value: str, - Field.Env.value: str, - Field.VolumeSizeInGB.value: int, - Field.VolumeKMSKey.value: str, - Field.MaxRuntimeInSeconds.value: int, - Field.Tags.value: [{str: str}] + 'image_uri': str, + 'instance_count': int, + 'entrypoint': str, + 'output_kms_key': str, + 'role': str, + 'env': str, + 'volume_size_in_gb': int, + 'volume_kms_key': str, + 'max_runtime_in_seconds': int, + 'tags': [{str: str}], + 'container_arguments': [str] }) step_input = StepInput(schema={ - Field.InstanceType.value: str + 'instance_type': str }) + parameters = { + 'AppSpecification': { + 'ContainerEntrypoint': execution_input['entrypoint'], + 'ImageUri': execution_input['image_uri'] + }, + 'Environment': execution_input['env'], + 'ProcessingOutputConfig': { + 'KmsKeyId': execution_input['output_kms_key'] + }, + 'ProcessingResources': { + 'ClusterConfig': { + 'InstanceCount': execution_input['instance_count'], + 'InstanceType': step_input['instance_type'], + 'VolumeKmsKeyId': execution_input['volume_kms_key'], + 'VolumeSizeInGB': execution_input['volume_size_in_gb'] + } + }, + 'RoleArn': execution_input['role'], + 'StoppingCondition': { + 'MaxRuntimeInSeconds': execution_input['max_runtime_in_seconds'] + }, + 'Tags': execution_input['tags'] + } + inputs = [ProcessingInput(source='dataset.csv', destination='/opt/ml/processing/input')] outputs = [ ProcessingOutput(source='/opt/ml/processing/output/train'), @@ -994,24 +1019,18 @@ def test_processing_step_creation_with_placeholders(sklearn_processor): 'Feature Transformation', sklearn_processor, 'MyProcessingJob', - container_entrypoint=execution_input[Field.Entrypoint.value], - kms_key_id=execution_input[Field.OutputKMSKey.value], + container_entrypoint=execution_input['entrypoint'], + container_arguments=execution_input['container_arguments'], + kms_key_id=execution_input['output_kms_key'], inputs=inputs, outputs=outputs, - image_uri=execution_input[Field.ImageUri.value], - instance_count=execution_input[Field.InstanceCount.value], - instance_type=step_input[Field.InstanceType.value], - role=execution_input[Field.Role.value], - env=execution_input[Field.Env.value], - volume_size_in_gb=execution_input[Field.VolumeSizeInGB.value], - volume_kms_key=execution_input[Field.VolumeKMSKey.value], - max_runtime_in_seconds=execution_input[Field.MaxRuntimeInSeconds.value], - tags=execution_input[Field.Tags.value], + parameters=parameters ) assert step.to_dict() == { 'Type': 'Task', 'Parameters': { 'AppSpecification': { + 'ContainerArguments.$': "$$.Execution.Input['container_arguments']", 'ContainerEntrypoint.$': "$$.Execution.Input['entrypoint']", 'ImageUri.$': "$$.Execution.Input['image_uri']" }, From da99c92572f1529c500f0043459f8bbb269acef0 Mon Sep 17 00:00:00 2001 From: Carolyn Nguyen Date: Mon, 16 Aug 2021 18:10:46 -0700 Subject: [PATCH 15/20] Using == instead of is() --- src/stepfunctions/steps/utils.py | 2 +- 1 file changed, 1 insertion(+), 
1 deletion(-) diff --git a/src/stepfunctions/steps/utils.py b/src/stepfunctions/steps/utils.py index b45d107..e181072 100644 --- a/src/stepfunctions/steps/utils.py +++ b/src/stepfunctions/steps/utils.py @@ -58,7 +58,7 @@ def merge_dicts(first, second, first_name, second_name): if key in first: if isinstance(first[key], dict) and isinstance(second[key], dict): merge_dicts(first[key], second[key], first_name, second_name) - elif first[key] is value: + elif first[key] == value: pass else: logger.info( From 37b2422bc600532be93b6444b6ade32dc14d9795 Mon Sep 17 00:00:00 2001 From: Carolyn Nguyen Date: Mon, 16 Aug 2021 18:22:28 -0700 Subject: [PATCH 16/20] Removed unused InvalidPathToPlaceholderParameter exception --- src/stepfunctions/exceptions.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/stepfunctions/exceptions.py b/src/stepfunctions/exceptions.py index 56be3ea..f7af53c 100644 --- a/src/stepfunctions/exceptions.py +++ b/src/stepfunctions/exceptions.py @@ -24,8 +24,3 @@ class MissingRequiredParameter(Exception): class DuplicateStatesInChain(Exception): pass - -class InvalidPathToPlaceholderParameter(Exception): - - def __init__(self, message): - super(InvalidPathToPlaceholderParameter, self).__init__(message) From fd640abfd152f8ee3d90c3e0a304e188e6480661 Mon Sep 17 00:00:00 2001 From: Carolyn Nguyen Date: Wed, 18 Aug 2021 15:51:38 -0700 Subject: [PATCH 17/20] Added doc and renamed args --- src/stepfunctions/steps/sagemaker.py | 9 ++++---- src/stepfunctions/steps/utils.py | 32 +++++++++++++++++----------- tests/unit/test_steps_utils.py | 9 +++++++- 3 files changed, 31 insertions(+), 19 deletions(-) diff --git a/src/stepfunctions/steps/sagemaker.py b/src/stepfunctions/steps/sagemaker.py index 03ee401..95a31ab 100644 --- a/src/stepfunctions/steps/sagemaker.py +++ b/src/stepfunctions/steps/sagemaker.py @@ -478,10 +478,6 @@ class ProcessingStep(Task): """ Creates a Task State to execute a SageMaker Processing Job. - - The following properties can be passed down as kwargs to the sagemaker.processing.Processor to be used dynamically - in the processing job (compatible with Placeholders): role, image_uri, instance_count, instance_type, - volume_size_in_gb, volume_kms_key, output_kms_key """ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, experiment_config=None, @@ -507,7 +503,10 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp The KmsKeyId is applied to all outputs. wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the processing job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the processing job and proceed to the next step. (default: True) tags (list[dict] or Placeholder, optional): `List to tags `_ to associate with the resource. - parameters(dict, optional): The value of this field becomes the effective input for the state. + parameters(dict, optional): The value of this field becomes the request for the + `CreateProcessingJob`_ created by the processing step. + All parameters fields are compatible with `Placeholders`_. + Any value defined in the parameters argument will overwrite the ones defined in the other arguments, including properties that were previously defined in the processor. 
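The overwrite rule described in that docstring is implemented by the `merge_dicts` helper: the processor-derived request is the target and the user's `parameters` dict is the source, merged key by key rather than subtree by subtree. A sketch with illustrative values, using the four-argument signature as of this commit (the name arguments are dropped two commits later):

```python
from stepfunctions.steps.utils import merge_dicts

# Stand-ins for the processor-derived request and a user's parameters dict.
generated = {
    'ProcessingResources': {'ClusterConfig': {'InstanceType': 'ml.m5.xlarge',
                                              'InstanceCount': 1}},
    'RoleArn': 'arn:aws:iam::111122223333:role/FromProcessor'
}
overrides = {
    'ProcessingResources': {'ClusterConfig': {'InstanceCount': 2}},
    'RoleArn': 'arn:aws:iam::111122223333:role/FromParameters'
}

merge_dicts(generated, overrides, 'generated', 'overrides')

# Nested keys merge rather than clobber whole subtrees: InstanceType survives,
# while InstanceCount and RoleArn take the override values.
assert generated['ProcessingResources']['ClusterConfig'] == {
    'InstanceType': 'ml.m5.xlarge', 'InstanceCount': 2}
assert generated['RoleArn'] == 'arn:aws:iam::111122223333:role/FromParameters'
```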
""" if wait_for_completion: """ diff --git a/src/stepfunctions/steps/utils.py b/src/stepfunctions/steps/utils.py index e181072..4b81235 100644 --- a/src/stepfunctions/steps/utils.py +++ b/src/stepfunctions/steps/utils.py @@ -48,22 +48,28 @@ def get_aws_partition(): return cur_partition -def merge_dicts(first, second, first_name, second_name): +def merge_dicts(target, source, target_name, source_name): """ - Merges first and second dictionaries into the first one. - Values in the first dict are updated with the values of the second one. + Merges source dictionary into the target dictionary. + Values in the target dict are updated with the values of the source dict. + Args: + target (dict): Base dictionary into which source is merged + source (dict): Dictionary used to update target. If the same key is present in both dictionaries, source's value + will overwrite target's value for the corresponding key + target_name (str): Name of target dictionary used for logging purposes + source_name (str): Name of source dictionary used for logging purposes """ - if all(isinstance(d, dict) for d in [first, second]): - for key, value in second.items(): - if key in first: - if isinstance(first[key], dict) and isinstance(second[key], dict): - merge_dicts(first[key], second[key], first_name, second_name) - elif first[key] == value: + if isinstance(target, dict) and isinstance(source, dict): + for key, value in source.items(): + if key in target: + if isinstance(target[key], dict) and isinstance(source[key], dict): + merge_dicts(target[key], source[key], target_name, source_name) + elif target[key] == value: pass else: logger.info( - f"{first_name} property: <{key}> with value: <{first[key]}>" - f" will be overwritten with value provided in {second_name} : <{value}>") - first[key] = second[key] + f"{target_name} property: <{key}> with value: <{target[key]}>" + f" will be overwritten with value provided in {source_name} : <{value}>") + target[key] = source[key] else: - first[key] = second[key] + target[key] = source[key] diff --git a/tests/unit/test_steps_utils.py b/tests/unit/test_steps_utils.py index 6eb0885..2d99c0d 100644 --- a/tests/unit/test_steps_utils.py +++ b/tests/unit/test_steps_utils.py @@ -13,7 +13,7 @@ # Test if boto3 session can fetch correct aws partition info from test environment -from stepfunctions.steps.utils import get_aws_partition +from stepfunctions.steps.utils import get_aws_partition, merge_dicts from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn import boto3 from unittest.mock import patch @@ -51,3 +51,10 @@ def test_arn_builder_sagemaker_wait_completion(): IntegrationPattern.WaitForCompletion) assert arn == "arn:aws:states:::sagemaker:createTrainingJob.sync" + +def test_merge_dicts(): + d1 = {'a': {'aa': 1, 'bb': 2, 'cc': 3}, 'b': 1} + d2 = {'a': {'bb': {'aaa': 1, 'bbb': 2}}, 'b': 2, 'c': 3} + + merge_dicts(d1, d2, 'd1', 'd2') + assert d1 == {'a': {'aa': 1, 'bb': {'aaa': 1, 'bbb': 2}, 'cc': 3}, 'b': 2, 'c': 3} From 1dfa0e32c8e2e1ed62c900d8922d1368a96ba159 Mon Sep 17 00:00:00 2001 From: Carolyn Nguyen <83104894+ca-nguyen@users.noreply.github.com> Date: Thu, 19 Aug 2021 16:05:20 -0700 Subject: [PATCH 18/20] Update src/stepfunctions/steps/sagemaker.py parameters description Co-authored-by: Adam Wong <55506708+wong-a@users.noreply.github.com> --- src/stepfunctions/steps/sagemaker.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/stepfunctions/steps/sagemaker.py b/src/stepfunctions/steps/sagemaker.py index 
95a31ab..fd9685d 100644 --- a/src/stepfunctions/steps/sagemaker.py +++ b/src/stepfunctions/steps/sagemaker.py @@ -503,10 +503,9 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp The KmsKeyId is applied to all outputs. wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the processing job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the processing job and proceed to the next step. (default: True) tags (list[dict] or Placeholder, optional): `List to tags `_ to associate with the resource. - parameters(dict, optional): The value of this field becomes the request for the - `CreateProcessingJob`_ created by the processing step. - All parameters fields are compatible with `Placeholders`_. - Any value defined in the parameters argument will overwrite the ones defined in the other arguments, including properties that were previously defined in the processor. + parameters(dict, optional): The value of this field is merged with other arguments to become the request payload for SageMaker `CreateProcessingJob`_. + You can use `parameters` to override the value provided by other arguments and specify any field's value dynamically using `Placeholders`_. + """ if wait_for_completion: """ From 614378300342758323d64c1792e916787cb16407 Mon Sep 17 00:00:00 2001 From: Carolyn Nguyen Date: Thu, 19 Aug 2021 16:21:25 -0700 Subject: [PATCH 19/20] Removed dict name args to opt for more generic log message when overwriting dict values --- src/stepfunctions/steps/sagemaker.py | 3 +-- src/stepfunctions/steps/utils.py | 10 +++----- tests/unit/test_steps_utils.py | 38 ++++++++++++++++++++++++---- 3 files changed, 38 insertions(+), 13 deletions(-) diff --git a/src/stepfunctions/steps/sagemaker.py b/src/stepfunctions/steps/sagemaker.py index fd9685d..9530478 100644 --- a/src/stepfunctions/steps/sagemaker.py +++ b/src/stepfunctions/steps/sagemaker.py @@ -542,8 +542,7 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp if Field.Parameters.value in kwargs and isinstance(kwargs[Field.Parameters.value], dict): # Update processing_parameters with input parameters - merge_dicts(processing_parameters, kwargs[Field.Parameters.value], "Processing Parameters", - "Input Parameters") + merge_dicts(processing_parameters, kwargs[Field.Parameters.value]) kwargs[Field.Parameters.value] = processing_parameters super(ProcessingStep, self).__init__(state_id, **kwargs) diff --git a/src/stepfunctions/steps/utils.py b/src/stepfunctions/steps/utils.py index 4b81235..5c6861f 100644 --- a/src/stepfunctions/steps/utils.py +++ b/src/stepfunctions/steps/utils.py @@ -48,7 +48,7 @@ def get_aws_partition(): return cur_partition -def merge_dicts(target, source, target_name, source_name): +def merge_dicts(target, source): """ Merges source dictionary into the target dictionary. Values in the target dict are updated with the values of the source dict. @@ -56,20 +56,18 @@ def merge_dicts(target, source, target_name, source_name): target (dict): Base dictionary into which source is merged source (dict): Dictionary used to update target. 
If the same key is present in both dictionaries, source's value will overwrite target's value for the corresponding key - target_name (str): Name of target dictionary used for logging purposes - source_name (str): Name of source dictionary used for logging purposes """ if isinstance(target, dict) and isinstance(source, dict): for key, value in source.items(): if key in target: if isinstance(target[key], dict) and isinstance(source[key], dict): - merge_dicts(target[key], source[key], target_name, source_name) + merge_dicts(target[key], source[key]) elif target[key] == value: pass else: logger.info( - f"{target_name} property: <{key}> with value: <{target[key]}>" - f" will be overwritten with value provided in {source_name} : <{value}>") + f"Property: <{key}> with value: <{target[key]}>" + f" will be overwritten with provided value: <{value}>") target[key] = source[key] else: target[key] = source[key] diff --git a/tests/unit/test_steps_utils.py b/tests/unit/test_steps_utils.py index 2d99c0d..7e06e37 100644 --- a/tests/unit/test_steps_utils.py +++ b/tests/unit/test_steps_utils.py @@ -53,8 +53,36 @@ def test_arn_builder_sagemaker_wait_completion(): def test_merge_dicts(): - d1 = {'a': {'aa': 1, 'bb': 2, 'cc': 3}, 'b': 1} - d2 = {'a': {'bb': {'aaa': 1, 'bbb': 2}}, 'b': 2, 'c': 3} - - merge_dicts(d1, d2, 'd1', 'd2') - assert d1 == {'a': {'aa': 1, 'bb': {'aaa': 1, 'bbb': 2}, 'cc': 3}, 'b': 2, 'c': 3} + d1 = { + 'a': { + 'aa': 1, + 'bb': 2, + 'cc': 3 + }, + 'b': 1 + } + + d2 = { + 'a': { + 'bb': { + 'aaa': 1, + 'bbb': 2 + } + }, + 'b': 2, + 'c': 3 + } + + merge_dicts(d1, d2) + assert d1 == { + 'a': { + 'aa': 1, + 'bb': { + 'aaa': 1, + 'bbb': 2 + }, + 'cc': 3 + }, + 'b': 2, + 'c': 3 + } From ebc5e225b11d94ed08d73a281ca46af6cc0ba559 Mon Sep 17 00:00:00 2001 From: Carolyn Nguyen Date: Thu, 19 Aug 2021 18:55:22 -0700 Subject: [PATCH 20/20] Using fstring in test --- src/stepfunctions/exceptions.py | 3 +-- tests/integ/test_sagemaker_steps.py | 6 ++---- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/stepfunctions/exceptions.py b/src/stepfunctions/exceptions.py index f7af53c..7e9a4d7 100644 --- a/src/stepfunctions/exceptions.py +++ b/src/stepfunctions/exceptions.py @@ -22,5 +22,4 @@ class MissingRequiredParameter(Exception): class DuplicateStatesInChain(Exception): - pass - + pass \ No newline at end of file diff --git a/tests/integ/test_sagemaker_steps.py b/tests/integ/test_sagemaker_steps.py index 77bca39..f840302 100644 --- a/tests/integ/test_sagemaker_steps.py +++ b/tests/integ/test_sagemaker_steps.py @@ -358,7 +358,7 @@ def test_processing_step(sklearn_processor_fixture, sagemaker_session, sfn_clien def test_processing_step_with_placeholders(sklearn_processor_fixture, sagemaker_session, sfn_client, sfn_role_arn, sagemaker_role_arn): region = boto3.session.Session().region_name - input_data = 's3://sagemaker-sample-data-{}/processing/census/census-income.csv'.format(region) + input_data = f"s3://sagemaker-sample-data-{region}/processing/census/census-income.csv" input_s3 = sagemaker_session.upload_data( path=os.path.join(DATA_DIR, 'sklearn_processing'), @@ -366,7 +366,7 @@ def test_processing_step_with_placeholders(sklearn_processor_fixture, sagemaker_ key_prefix='integ-test-data/sklearn_processing/code' ) - output_s3 = 's3://' + sagemaker_session.default_bucket() + '/integ-test-data/sklearn_processing' + output_s3 = f"s3://{sagemaker_session.default_bucket()}/integ-test-data/sklearn_processing" inputs = [ ProcessingInput(source=input_data, destination='/opt/ml/processing/input', 
                         input_name='input-1'),
@@ -422,7 +422,6 @@ def test_processing_step_with_placeholders(sklearn_processor_fixture, sagemaker_
     workflow_graph = Chain([processing_step])
 
     with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
-        # Create workflow and check definition
         workflow = create_workflow_and_check_definition(
             workflow_graph=workflow_graph,
             workflow_name=unique_name_from_base("integ-test-processing-step-workflow"),
@@ -449,4 +448,3 @@ def test_processing_step_with_placeholders(sklearn_processor_fixture, sagemaker_
 
     # Cleanup
     state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
-    # End of Cleanup
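
For reference (not part of the patch series): a minimal usage sketch of
merge_dicts as it stands after PATCH 19. Nested dicts are merged recursively;
for a conflicting non-dict value the source side wins, and the overwrite is
logged at INFO level. The dictionary keys below are illustrative placeholders.

    from stepfunctions.steps.utils import merge_dicts

    # Illustrative payloads; the overlapping 'STAGE' key exercises the overwrite path.
    target = {'Environment': {'STAGE': 'dev', 'DEBUG': 'true'}, 'MaxRuntime': 3600}
    source = {'Environment': {'STAGE': 'prod'}, 'Tags': [{'Key': 'team', 'Value': 'ml'}]}

    merge_dicts(target, source)  # mutates target in place; logs the STAGE overwrite

    assert target == {
        'Environment': {'STAGE': 'prod', 'DEBUG': 'true'},
        'MaxRuntime': 3600,
        'Tags': [{'Key': 'team', 'Value': 'ml'}],
    }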
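
Also for reference, and equally hedged: a sketch of the parameters override
described in PATCH 18's docstring. Fields passed in parameters are merged into
the request built from the other arguments via merge_dicts, so they win on
conflicts. The processor settings and role ARN below are placeholders, and
configured AWS credentials plus a default region are assumed.

    from sagemaker.sklearn.processing import SKLearnProcessor
    from stepfunctions.steps.sagemaker import ProcessingStep

    # Placeholder processor configuration, not taken from this patch series.
    processor = SKLearnProcessor(
        framework_version='0.20.0',
        role='arn:aws:iam::123456789012:role/SageMakerRole',  # placeholder ARN
        instance_type='ml.m5.xlarge',
        instance_count=1,
    )

    # 'Environment' here overrides/extends the generated CreateProcessingJob request.
    step = ProcessingStep(
        'My processing step',
        processor=processor,
        job_name='my-processing-job',
        parameters={'Environment': {'STAGE': 'prod'}},
    )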