Commit 01e18c3

Authored by ca-nguyen with shivlaks and wong-a

feat: Support placeholders for processing step (#155)

* Support placeholders for processor parameters in processing step
* Merge sagemaker generated parameters with placeholder compatible parameters received in args

Co-authored-by: Shiv Lakshminarayan <[email protected]>
Co-authored-by: Adam Wong <[email protected]>

1 parent 6b62bf7 commit 01e18c3
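In user terms, the change lets the `parameters` argument of `ProcessingStep` carry Step Functions placeholders that are merged into the request payload generated from the processor. Below is a minimal usage sketch; the processor configuration, role ARN, and input path are illustrative stand-ins, not values taken from this commit.

    from sagemaker.processing import ProcessingInput
    from sagemaker.sklearn.processing import SKLearnProcessor
    from stepfunctions.inputs import ExecutionInput
    from stepfunctions.steps.sagemaker import ProcessingStep

    # Values supplied when the state machine is executed, not when it is defined
    execution_input = ExecutionInput(schema={'instance_count': int, 'role': str})

    processor = SKLearnProcessor(framework_version='0.20.0',
                                 role='arn:aws:iam::123456789012:role/SageMakerRole',  # illustrative role ARN
                                 instance_type='ml.m5.xlarge',
                                 instance_count=1)

    step = ProcessingStep(
        'preprocess',
        processor=processor,
        job_name='my-processing-job',
        inputs=[ProcessingInput(source='dataset.csv', destination='/opt/ml/processing/input')],
        # Fields given here override the corresponding fields generated from the processor
        parameters={
            'ProcessingResources': {
                'ClusterConfig': {'InstanceCount': execution_input['instance_count']}
            },
            'RoleArn': execution_input['role']
        }
    )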

File tree: 6 files changed (+321, -21 lines)

src/stepfunctions/steps/fields.py (-1)

@@ -59,7 +59,6 @@ class Field(Enum):
     HeartbeatSeconds = 'heartbeat_seconds'
     HeartbeatSecondsPath = 'heartbeat_seconds_path'

-
     # Retry and catch fields
     ErrorEquals = 'error_equals'
     IntervalSeconds = 'interval_seconds'

src/stepfunctions/steps/sagemaker.py (+27, -18)

@@ -19,7 +19,7 @@
 from stepfunctions.inputs import Placeholder
 from stepfunctions.steps.states import Task
 from stepfunctions.steps.fields import Field
-from stepfunctions.steps.utils import tags_dict_to_kv_list
+from stepfunctions.steps.utils import merge_dicts, tags_dict_to_kv_list
 from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn

 from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config, processing_config
@@ -30,6 +30,7 @@

 SAGEMAKER_SERVICE_NAME = "sagemaker"

+
 class SageMakerApi(Enum):
     CreateTrainingJob = "createTrainingJob"
     CreateTransformJob = "createTransformJob"
@@ -479,7 +480,9 @@ class ProcessingStep(Task):
     Creates a Task State to execute a SageMaker Processing Job.
     """

-    def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, experiment_config=None, container_arguments=None, container_entrypoint=None, kms_key_id=None, wait_for_completion=True, tags=None, **kwargs):
+    def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, experiment_config=None,
+                 container_arguments=None, container_entrypoint=None, kms_key_id=None, wait_for_completion=True,
+                 tags=None, **kwargs):
         """
         Args:
             state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
@@ -491,15 +494,18 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp
             outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
                 the processing job. These can be specified as either path strings or
                 :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
-            experiment_config (dict, optional): Specify the experiment config for the processing. (Default: None)
-            container_arguments ([str]): The arguments for a container used to run a processing job.
-            container_entrypoint ([str]): The entrypoint for a container used to run a processing job.
-            kms_key_id (str): The AWS Key Management Service (AWS KMS) key that Amazon SageMaker
+            experiment_config (dict or Placeholder, optional): Specify the experiment config for the processing. (Default: None)
+            container_arguments ([str] or Placeholder): The arguments for a container used to run a processing job.
+            container_entrypoint ([str] or Placeholder): The entrypoint for a container used to run a processing job.
+            kms_key_id (str or Placeholder): The AWS Key Management Service (AWS KMS) key that Amazon SageMaker
                 uses to encrypt the processing job output. KmsKeyId can be an ID of a KMS key,
                 ARN of a KMS key, alias of a KMS key, or alias of a KMS key.
                 The KmsKeyId is applied to all outputs.
             wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the processing job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the processing job and proceed to the next step. (default: True)
-            tags (list[dict], optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
+            tags (list[dict] or Placeholder, optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
+            parameters (dict, optional): The value of this field is merged with other arguments to become the request payload for SageMaker `CreateProcessingJob <https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateProcessingJob.html>`_.
+                You can use `parameters` to override the value provided by other arguments and specify any field's value dynamically using `Placeholders <https://aws-step-functions-data-science-sdk.readthedocs.io/en/stable/placeholders.html?highlight=placeholder#stepfunctions.inputs.Placeholder>`_.
+
         """
         if wait_for_completion:
             """
@@ -518,22 +524,25 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp
                                                         SageMakerApi.CreateProcessingJob)

         if isinstance(job_name, str):
-            parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name)
+            processing_parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name)
         else:
-            parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id)
+            processing_parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id)

         if isinstance(job_name, Placeholder):
-            parameters['ProcessingJobName'] = job_name
+            processing_parameters['ProcessingJobName'] = job_name

         if experiment_config is not None:
-            parameters['ExperimentConfig'] = experiment_config
-
+            processing_parameters['ExperimentConfig'] = experiment_config
+
         if tags:
-            parameters['Tags'] = tags_dict_to_kv_list(tags)
-
-        if 'S3Operations' in parameters:
-            del parameters['S3Operations']
-
-        kwargs[Field.Parameters.value] = parameters
+            processing_parameters['Tags'] = tags if isinstance(tags, Placeholder) else tags_dict_to_kv_list(tags)
+
+        if 'S3Operations' in processing_parameters:
+            del processing_parameters['S3Operations']
+
+        if Field.Parameters.value in kwargs and isinstance(kwargs[Field.Parameters.value], dict):
+            # Update processing_parameters with input parameters
+            merge_dicts(processing_parameters, kwargs[Field.Parameters.value])

+        kwargs[Field.Parameters.value] = processing_parameters
         super(ProcessingStep, self).__init__(state_id, **kwargs)
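
To make the merge order concrete, here is a small standalone sketch of what happens inside `__init__` when a caller passes `parameters`: fields produced by `processing_config` land in `processing_parameters`, and the caller's dict is merged on top via `merge_dicts`, so the caller's value (here a placeholder) wins. The literal values below are made up for illustration.

    from stepfunctions.inputs import ExecutionInput
    from stepfunctions.steps.utils import merge_dicts

    execution_input = ExecutionInput(schema={'role': str})

    # Roughly what processing_config generates from the processor (illustrative values)
    processing_parameters = {
        'ProcessingJobName': 'my-processing-job',
        'RoleArn': 'arn:aws:iam::123456789012:role/RoleFromProcessor',
    }

    # Roughly what a caller passes through the new `parameters` keyword
    caller_parameters = {'RoleArn': execution_input['role']}

    merge_dicts(processing_parameters, caller_parameters)
    # processing_parameters['RoleArn'] is now the execution-input placeholder, so the
    # role is resolved at execution time instead of being fixed at definition time.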

src/stepfunctions/steps/utils.py (+26)

@@ -14,6 +14,7 @@

 import boto3
 import logging
+from stepfunctions.inputs import Placeholder

 logger = logging.getLogger('stepfunctions')

@@ -45,3 +46,28 @@ def get_aws_partition():
             return cur_partition

     return cur_partition
+
+
+def merge_dicts(target, source):
+    """
+    Merges source dictionary into the target dictionary.
+    Values in the target dict are updated with the values of the source dict.
+    Args:
+        target (dict): Base dictionary into which source is merged
+        source (dict): Dictionary used to update target. If the same key is present in both dictionaries, source's value
+            will overwrite target's value for the corresponding key
+    """
+    if isinstance(target, dict) and isinstance(source, dict):
+        for key, value in source.items():
+            if key in target:
+                if isinstance(target[key], dict) and isinstance(source[key], dict):
+                    merge_dicts(target[key], source[key])
+                elif target[key] == value:
+                    pass
+                else:
+                    logger.info(
+                        f"Property: <{key}> with value: <{target[key]}>"
+                        f" will be overwritten with provided value: <{value}>")
+                    target[key] = source[key]
+            else:
+                target[key] = source[key]
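
A quick, self-contained illustration of the merge semantics (throwaway values, not taken from the tests): nested dicts are merged key by key, new keys are simply added, and an existing scalar is overwritten with an INFO log.

    from stepfunctions.steps.utils import merge_dicts

    target = {
        'ClusterConfig': {'InstanceCount': 1, 'VolumeSizeInGB': 30},
        'RoleArn': 'role-a',
    }
    source = {
        'ClusterConfig': {'InstanceCount': 2},   # nested dict: merged key by key
        'Environment': {'STAGE': 'test'},        # new key: simply added
    }

    merge_dicts(target, source)
    assert target == {
        'ClusterConfig': {'InstanceCount': 2, 'VolumeSizeInGB': 30},
        'RoleArn': 'role-a',
        'Environment': {'STAGE': 'test'},
    }
    # Overwriting InstanceCount (1 -> 2) also emits an INFO log announcing the override.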

tests/integ/test_sagemaker_steps.py (+96)

@@ -29,6 +29,7 @@
 from sagemaker.tuner import HyperparameterTuner
 from sagemaker.processing import ProcessingInput, ProcessingOutput

+from stepfunctions.inputs import ExecutionInput
 from stepfunctions.steps import Chain
 from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep, TuningStep, ProcessingStep
 from stepfunctions.workflow import Workflow
@@ -352,3 +353,98 @@ def test_processing_step(sklearn_processor_fixture, sagemaker_session, sfn_clien
         # Cleanup
         state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
         # End of Cleanup
+
+
+def test_processing_step_with_placeholders(sklearn_processor_fixture, sagemaker_session, sfn_client, sfn_role_arn,
+                                           sagemaker_role_arn):
+    region = boto3.session.Session().region_name
+    input_data = f"s3://sagemaker-sample-data-{region}/processing/census/census-income.csv"
+
+    input_s3 = sagemaker_session.upload_data(
+        path=os.path.join(DATA_DIR, 'sklearn_processing'),
+        bucket=sagemaker_session.default_bucket(),
+        key_prefix='integ-test-data/sklearn_processing/code'
+    )
+
+    output_s3 = f"s3://{sagemaker_session.default_bucket()}/integ-test-data/sklearn_processing"
+
+    inputs = [
+        ProcessingInput(source=input_data, destination='/opt/ml/processing/input', input_name='input-1'),
+        ProcessingInput(source=input_s3 + '/preprocessor.py', destination='/opt/ml/processing/input/code',
+                        input_name='code'),
+    ]
+
+    outputs = [
+        ProcessingOutput(source='/opt/ml/processing/train', destination=output_s3 + '/train_data',
+                         output_name='train_data'),
+        ProcessingOutput(source='/opt/ml/processing/test', destination=output_s3 + '/test_data',
+                         output_name='test_data'),
+    ]
+
+    # Build workflow definition
+    execution_input = ExecutionInput(schema={
+        'image_uri': str,
+        'instance_count': int,
+        'entrypoint': str,
+        'role': str,
+        'volume_size_in_gb': int,
+        'max_runtime_in_seconds': int,
+        'container_arguments': [str],
+    })
+
+    parameters = {
+        'AppSpecification': {
+            'ContainerEntrypoint': execution_input['entrypoint'],
+            'ImageUri': execution_input['image_uri']
+        },
+        'ProcessingResources': {
+            'ClusterConfig': {
+                'InstanceCount': execution_input['instance_count'],
+                'VolumeSizeInGB': execution_input['volume_size_in_gb']
+            }
+        },
+        'RoleArn': execution_input['role'],
+        'StoppingCondition': {
+            'MaxRuntimeInSeconds': execution_input['max_runtime_in_seconds']
+        }
+    }
+
+    job_name = generate_job_name()
+    processing_step = ProcessingStep('create_processing_job_step',
+                                     processor=sklearn_processor_fixture,
+                                     job_name=job_name,
+                                     inputs=inputs,
+                                     outputs=outputs,
+                                     container_arguments=execution_input['container_arguments'],
+                                     container_entrypoint=execution_input['entrypoint'],
+                                     parameters=parameters
+                                     )
+    workflow_graph = Chain([processing_step])
+
+    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
+        workflow = create_workflow_and_check_definition(
+            workflow_graph=workflow_graph,
+            workflow_name=unique_name_from_base("integ-test-processing-step-workflow"),
+            sfn_client=sfn_client,
+            sfn_role_arn=sfn_role_arn
+        )
+
+        execution_input = {
+            'image_uri': '683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:0.20.0-cpu-py3',
+            'instance_count': 1,
+            'entrypoint': ['python3', '/opt/ml/processing/input/code/preprocessor.py'],
+            'role': sagemaker_role_arn,
+            'volume_size_in_gb': 30,
+            'max_runtime_in_seconds': 500,
+            'container_arguments': ['--train-test-split-ratio', '0.2']
+        }
+
+        # Execute workflow
+        execution = workflow.execute(inputs=execution_input)
+        execution_output = execution.get_output(wait=True)
+
+        # Check workflow output
+        assert execution_output.get("ProcessingJobStatus") == "Completed"
+
+        # Cleanup
+        state_machine_delete_wait(sfn_client, workflow.state_machine_arn)

tests/unit/test_sagemaker_steps.py (+136, -1)

@@ -27,7 +27,9 @@

 from unittest.mock import MagicMock, patch
 from stepfunctions.inputs import ExecutionInput, StepInput
-from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep, ProcessingStep
+from stepfunctions.steps.fields import Field
+from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep,\
+    ProcessingStep
 from stepfunctions.steps.sagemaker import tuning_config

 from tests.unit.utils import mock_boto_api_call
@@ -962,3 +964,136 @@ def test_processing_step_creation(sklearn_processor):
         'Resource': 'arn:aws:states:::sagemaker:createProcessingJob.sync',
         'End': True
     }
+
+
+def test_processing_step_creation_with_placeholders(sklearn_processor):
+    execution_input = ExecutionInput(schema={
+        'image_uri': str,
+        'instance_count': int,
+        'entrypoint': str,
+        'output_kms_key': str,
+        'role': str,
+        'env': str,
+        'volume_size_in_gb': int,
+        'volume_kms_key': str,
+        'max_runtime_in_seconds': int,
+        'tags': [{str: str}],
+        'container_arguments': [str]
+    })
+
+    step_input = StepInput(schema={
+        'instance_type': str
+    })
+
+    parameters = {
+        'AppSpecification': {
+            'ContainerEntrypoint': execution_input['entrypoint'],
+            'ImageUri': execution_input['image_uri']
+        },
+        'Environment': execution_input['env'],
+        'ProcessingOutputConfig': {
+            'KmsKeyId': execution_input['output_kms_key']
+        },
+        'ProcessingResources': {
+            'ClusterConfig': {
+                'InstanceCount': execution_input['instance_count'],
+                'InstanceType': step_input['instance_type'],
+                'VolumeKmsKeyId': execution_input['volume_kms_key'],
+                'VolumeSizeInGB': execution_input['volume_size_in_gb']
+            }
+        },
+        'RoleArn': execution_input['role'],
+        'StoppingCondition': {
+            'MaxRuntimeInSeconds': execution_input['max_runtime_in_seconds']
+        },
+        'Tags': execution_input['tags']
+    }
+
+    inputs = [ProcessingInput(source='dataset.csv', destination='/opt/ml/processing/input')]
+    outputs = [
+        ProcessingOutput(source='/opt/ml/processing/output/train'),
+        ProcessingOutput(source='/opt/ml/processing/output/validation'),
+        ProcessingOutput(source='/opt/ml/processing/output/test')
+    ]
+    step = ProcessingStep(
+        'Feature Transformation',
+        sklearn_processor,
+        'MyProcessingJob',
+        container_entrypoint=execution_input['entrypoint'],
+        container_arguments=execution_input['container_arguments'],
+        kms_key_id=execution_input['output_kms_key'],
+        inputs=inputs,
+        outputs=outputs,
+        parameters=parameters
+    )
+    assert step.to_dict() == {
+        'Type': 'Task',
+        'Parameters': {
+            'AppSpecification': {
+                'ContainerArguments.$': "$$.Execution.Input['container_arguments']",
+                'ContainerEntrypoint.$': "$$.Execution.Input['entrypoint']",
+                'ImageUri.$': "$$.Execution.Input['image_uri']"
+            },
+            'Environment.$': "$$.Execution.Input['env']",
+            'ProcessingInputs': [
+                {
+                    'InputName': None,
+                    'AppManaged': False,
+                    'S3Input': {
+                        'LocalPath': '/opt/ml/processing/input',
+                        'S3CompressionType': 'None',
+                        'S3DataDistributionType': 'FullyReplicated',
+                        'S3DataType': 'S3Prefix',
+                        'S3InputMode': 'File',
+                        'S3Uri': 'dataset.csv'
+                    }
+                }
+            ],
+            'ProcessingOutputConfig': {
+                'KmsKeyId.$': "$$.Execution.Input['output_kms_key']",
+                'Outputs': [
+                    {
+                        'OutputName': None,
+                        'AppManaged': False,
+                        'S3Output': {
+                            'LocalPath': '/opt/ml/processing/output/train',
+                            'S3UploadMode': 'EndOfJob',
+                            'S3Uri': None
+                        }
+                    },
+                    {
+                        'OutputName': None,
+                        'AppManaged': False,
+                        'S3Output': {
+                            'LocalPath': '/opt/ml/processing/output/validation',
+                            'S3UploadMode': 'EndOfJob',
+                            'S3Uri': None
+                        }
+                    },
+                    {
+                        'OutputName': None,
+                        'AppManaged': False,
+                        'S3Output': {
+                            'LocalPath': '/opt/ml/processing/output/test',
+                            'S3UploadMode': 'EndOfJob',
+                            'S3Uri': None
+                        }
+                    }
+                ]
+            },
+            'ProcessingResources': {
+                'ClusterConfig': {
+                    'InstanceCount.$': "$$.Execution.Input['instance_count']",
+                    'InstanceType.$': "$['instance_type']",
+                    'VolumeKmsKeyId.$': "$$.Execution.Input['volume_kms_key']",
+                    'VolumeSizeInGB.$': "$$.Execution.Input['volume_size_in_gb']"
+                }
+            },
+            'ProcessingJobName': 'MyProcessingJob',
+            'RoleArn.$': "$$.Execution.Input['role']",
+            'Tags.$': "$$.Execution.Input['tags']",
+            'StoppingCondition': {'MaxRuntimeInSeconds.$': "$$.Execution.Input['max_runtime_in_seconds']"},
+        },
+        'Resource': 'arn:aws:states:::sagemaker:createProcessingJob.sync',
+        'End': True
+    }
