Skip to content

Commit 09cab9c

Browse files
authored
Delete S3Operations from ModelStep (#16)
1 parent 6c2e2d4 commit 09cab9c

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

src/stepfunctions/steps/sagemaker.py

+3
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,9 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, **kwarg
186186
else:
187187
raise ValueError("Expected 'model' parameter to be of type 'sagemaker.model.Model', but received type '{}'".format(type(model).__name__))
188188

189+
if 'S3Operations' in parameters:
190+
del parameters['S3Operations']
191+
189192
kwargs[Field.Parameters.value] = parameters
190193
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createModel'
191194

tests/unit/test_sagemaker_steps.py

+33
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def tensorflow_estimator():
9292

9393
estimator.sagemaker_session = MagicMock()
9494
estimator.sagemaker_session.boto_region_name = 'us-east-1'
95+
estimator.sagemaker_session._default_bucket = 'sagemaker'
9596

9697
return estimator
9798

@@ -289,6 +290,38 @@ def test_get_expected_model(pca_estimator):
289290
'End': True
290291
}
291292

293+
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
294+
def test_get_expected_model_with_framework_estimator(tensorflow_estimator):
295+
training_step = TrainingStep('Training',
296+
estimator=tensorflow_estimator,
297+
data={'train': 's3://sagemaker/train'},
298+
job_name='tensorflow-job',
299+
mini_batch_size=1024
300+
)
301+
expected_model = training_step.get_expected_model()
302+
expected_model.entry_point = 'tf_train.py'
303+
model_step = ModelStep('Create model', model=expected_model, model_name='tf-model')
304+
assert model_step.to_dict() == {
305+
'Type': 'Task',
306+
'Parameters': {
307+
'ExecutionRoleArn': EXECUTION_ROLE,
308+
'ModelName': 'tf-model',
309+
'PrimaryContainer': {
310+
'Environment': {
311+
'SAGEMAKER_PROGRAM': 'tf_train.py',
312+
'SAGEMAKER_SUBMIT_DIRECTORY': 's3://sagemaker/tensorflow-job/source/sourcedir.tar.gz',
313+
'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false',
314+
'SAGEMAKER_CONTAINER_LOG_LEVEL': '20',
315+
'SAGEMAKER_REGION': 'us-east-1',
316+
},
317+
'Image': expected_model.image,
318+
'ModelDataUrl.$': "$['ModelArtifacts']['S3ModelArtifacts']"
319+
}
320+
},
321+
'Resource': 'arn:aws:states:::sagemaker:createModel',
322+
'End': True
323+
}
324+
292325
def test_model_step_creation(pca_model):
293326
step = ModelStep('Create model', model=pca_model, model_name='pca-model')
294327
assert step.to_dict() == {

0 commit comments

Comments
 (0)