From 09720b4a02431237c8eca4f1947f8f70cebb9692 Mon Sep 17 00:00:00 2001 From: Shunjia Ding Date: Fri, 3 Jan 2020 14:03:38 -0800 Subject: [PATCH] Delete S3Operations from ModelStep --- src/stepfunctions/steps/sagemaker.py | 3 +++ tests/unit/test_sagemaker_steps.py | 33 ++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/src/stepfunctions/steps/sagemaker.py b/src/stepfunctions/steps/sagemaker.py index d24d59f..cc12723 100644 --- a/src/stepfunctions/steps/sagemaker.py +++ b/src/stepfunctions/steps/sagemaker.py @@ -186,6 +186,9 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, **kwarg else: raise ValueError("Expected 'model' parameter to be of type 'sagemaker.model.Model', but received type '{}'".format(type(model).__name__)) + if 'S3Operations' in parameters: + del parameters['S3Operations'] + kwargs[Field.Parameters.value] = parameters kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createModel' diff --git a/tests/unit/test_sagemaker_steps.py b/tests/unit/test_sagemaker_steps.py index 4c7238b..980a910 100644 --- a/tests/unit/test_sagemaker_steps.py +++ b/tests/unit/test_sagemaker_steps.py @@ -92,6 +92,7 @@ def tensorflow_estimator(): estimator.sagemaker_session = MagicMock() estimator.sagemaker_session.boto_region_name = 'us-east-1' + estimator.sagemaker_session._default_bucket = 'sagemaker' return estimator @@ -289,6 +290,38 @@ def test_get_expected_model(pca_estimator): 'End': True } +@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) +def test_get_expected_model_with_framework_estimator(tensorflow_estimator): + training_step = TrainingStep('Training', + estimator=tensorflow_estimator, + data={'train': 's3://sagemaker/train'}, + job_name='tensorflow-job', + mini_batch_size=1024 + ) + expected_model = training_step.get_expected_model() + expected_model.entry_point = 'tf_train.py' + model_step = ModelStep('Create model', model=expected_model, model_name='tf-model') + assert model_step.to_dict() == { + 'Type': 'Task', + 'Parameters': { + 'ExecutionRoleArn': EXECUTION_ROLE, + 'ModelName': 'tf-model', + 'PrimaryContainer': { + 'Environment': { + 'SAGEMAKER_PROGRAM': 'tf_train.py', + 'SAGEMAKER_SUBMIT_DIRECTORY': 's3://sagemaker/tensorflow-job/source/sourcedir.tar.gz', + 'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false', + 'SAGEMAKER_CONTAINER_LOG_LEVEL': '20', + 'SAGEMAKER_REGION': 'us-east-1', + }, + 'Image': expected_model.image, + 'ModelDataUrl.$': "$['ModelArtifacts']['S3ModelArtifacts']" + } + }, + 'Resource': 'arn:aws:states:::sagemaker:createModel', + 'End': True + } + def test_model_step_creation(pca_model): step = ModelStep('Create model', model=pca_model, model_name='pca-model') assert step.to_dict() == {