Skip to content

Commit 597fe1d

Browse files
committed
Use model_config to generate CreateModelStep parameters for models without instance type
1 parent f8bbfaf commit 597fe1d

File tree

2 files changed

+16
-14
lines changed

2 files changed

+16
-14
lines changed

src/stepfunctions/steps/sagemaker.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@ def get_expected_model(self, model_name=None):
157157
model.name = model_name
158158
else:
159159
model.name = self.job_name
160+
if self.estimator.environment is not None:
161+
model.env = self.estimator.environment
160162
model.model_data = self.output()["ModelArtifacts"]["S3ModelArtifacts"]
161163
return model
162164

@@ -284,20 +286,10 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No
284286
instance_type (str, optional): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'.
285287
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
286288
"""
287-
if isinstance(model, FrameworkModel):
289+
if isinstance(model, Model):
288290
parameters = model_config(model=model, instance_type=instance_type, role=model.role, image_uri=model.image_uri)
289291
if model_name:
290292
parameters['ModelName'] = model_name
291-
elif isinstance(model, Model):
292-
parameters = {
293-
'ExecutionRoleArn': model.role,
294-
'ModelName': model_name or model.name,
295-
'PrimaryContainer': {
296-
'Environment': {},
297-
'Image': model.image_uri,
298-
'ModelDataUrl': model.model_data
299-
}
300-
}
301293
else:
302294
raise ValueError("Expected 'model' parameter to be of type 'sagemaker.model.Model', but received type '{}'".format(type(model).__name__))
303295

tests/unit/test_sagemaker_steps.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,11 @@ def pca_estimator():
5050
role=EXECUTION_ROLE,
5151
instance_count=1,
5252
instance_type='ml.c4.xlarge',
53-
output_path=s3_output_location
53+
output_path=s3_output_location,
54+
environment={
55+
'JobName': "job_name",
56+
'ModelName': "model_name"
57+
}
5458
)
5559

5660
pca.set_hyperparameters(
@@ -489,7 +493,10 @@ def test_training_step_creation_with_model(pca_estimator):
489493
'ExecutionRoleArn': EXECUTION_ROLE,
490494
'ModelName.$': "$['TrainingJobName']",
491495
'PrimaryContainer': {
492-
'Environment': {},
496+
'Environment': {
497+
'JobName': 'job_name',
498+
'ModelName': 'model_name'
499+
},
493500
'Image': PCA_IMAGE,
494501
'ModelDataUrl.$': "$['ModelArtifacts']['S3ModelArtifacts']"
495502
}
@@ -757,7 +764,10 @@ def test_get_expected_model(pca_estimator):
757764
'ExecutionRoleArn': EXECUTION_ROLE,
758765
'ModelName': 'pca-model',
759766
'PrimaryContainer': {
760-
'Environment': {},
767+
'Environment': {
768+
'JobName': 'job_name',
769+
'ModelName': 'model_name'
770+
},
761771
'Image': expected_model.image_uri,
762772
'ModelDataUrl.$': "$['ModelArtifacts']['S3ModelArtifacts']"
763773
}

0 commit comments

Comments
 (0)