@@ -92,6 +92,7 @@ def tensorflow_estimator():
92
92
93
93
estimator .sagemaker_session = MagicMock ()
94
94
estimator .sagemaker_session .boto_region_name = 'us-east-1'
95
+ estimator .sagemaker_session ._default_bucket = 'sagemaker'
95
96
96
97
return estimator
97
98
@@ -289,6 +290,38 @@ def test_get_expected_model(pca_estimator):
289
290
'End' : True
290
291
}
291
292
293
+ @patch ('botocore.client.BaseClient._make_api_call' , new = mock_boto_api_call )
294
+ def test_get_expected_model_with_framework_estimator (tensorflow_estimator ):
295
+ training_step = TrainingStep ('Training' ,
296
+ estimator = tensorflow_estimator ,
297
+ data = {'train' : 's3://sagemaker/train' },
298
+ job_name = 'tensorflow-job' ,
299
+ mini_batch_size = 1024
300
+ )
301
+ expected_model = training_step .get_expected_model ()
302
+ expected_model .entry_point = 'tf_train.py'
303
+ model_step = ModelStep ('Create model' , model = expected_model , model_name = 'tf-model' )
304
+ assert model_step .to_dict () == {
305
+ 'Type' : 'Task' ,
306
+ 'Parameters' : {
307
+ 'ExecutionRoleArn' : EXECUTION_ROLE ,
308
+ 'ModelName' : 'tf-model' ,
309
+ 'PrimaryContainer' : {
310
+ 'Environment' : {
311
+ 'SAGEMAKER_PROGRAM' : 'tf_train.py' ,
312
+ 'SAGEMAKER_SUBMIT_DIRECTORY' : 's3://sagemaker/tensorflow-job/source/sourcedir.tar.gz' ,
313
+ 'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS' : 'false' ,
314
+ 'SAGEMAKER_CONTAINER_LOG_LEVEL' : '20' ,
315
+ 'SAGEMAKER_REGION' : 'us-east-1' ,
316
+ },
317
+ 'Image' : expected_model .image ,
318
+ 'ModelDataUrl.$' : "$['ModelArtifacts']['S3ModelArtifacts']"
319
+ }
320
+ },
321
+ 'Resource' : 'arn:aws:states:::sagemaker:createModel' ,
322
+ 'End' : True
323
+ }
324
+
292
325
def test_model_step_creation (pca_model ):
293
326
step = ModelStep ('Create model' , model = pca_model , model_name = 'pca-model' )
294
327
assert step .to_dict () == {
0 commit comments