fix: environment variables specified in Model or Estimator are not passed through to SageMaker ModelStep #160
```diff
@@ -157,6 +157,8 @@ def get_expected_model(self, model_name=None):
             model.name = model_name
         else:
             model.name = self.job_name
+        if self.estimator.environment:
+            model.env = self.estimator.environment
         model.model_data = self.output()["ModelArtifacts"]["S3ModelArtifacts"]
         return model
```
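The control flow this hunk adds can be sketched with minimal stand-in classes (hypothetical stubs for illustration only, not the real `sagemaker.estimator.Estimator` or `sagemaker.model.Model`):

```python
# Minimal sketch of how an estimator's `environment` dict is propagated to
# the expected model, mirroring the diff above. All classes are stand-ins.

class Estimator:
    def __init__(self, environment=None):
        self.environment = environment

class Model:
    def __init__(self):
        self.env = {}          # previously left empty, which dropped env vars
        self.model_data = None

class TrainingStep:
    def __init__(self, estimator, job_name):
        self.estimator = estimator
        self.job_name = job_name

    def get_expected_model(self, model_name=None):
        model = Model()
        model.name = model_name or self.job_name
        # The fix: carry the estimator's environment over to the model so the
        # downstream ModelStep can render it into PrimaryContainer.Environment.
        if self.estimator.environment:
            model.env = self.estimator.environment
        model.model_data = "$['ModelArtifacts']['S3ModelArtifacts']"
        return model

step = TrainingStep(Estimator(environment={'JobName': 'job_name'}), 'TrainingJob')
assert step.get_expected_model().env == {'JobName': 'job_name'}
```

An estimator created without `environment` leaves `model.env` as `{}`, which matches the behaviour before the fix.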
```diff
@@ -284,20 +286,10 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No
             instance_type (str, optional): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'.
             tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
         """
-        if isinstance(model, FrameworkModel):
+        if isinstance(model, Model):
             parameters = model_config(model=model, instance_type=instance_type, role=model.role, image_uri=model.image_uri)
             if model_name:
                 parameters['ModelName'] = model_name
-        elif isinstance(model, Model):
-            parameters = {
-                'ExecutionRoleArn': model.role,
-                'ModelName': model_name or model.name,
-                'PrimaryContainer': {
-                    'Environment': {},
-                    'Image': model.image_uri,
-                    'ModelDataUrl': model.model_data
-                }
-            }
         else:
             raise ValueError("Expected 'model' parameter to be of type 'sagemaker.model.Model', but received type '{}'".format(type(model).__name__))
```

Review discussion on the `isinstance(model, Model)` change:

Reviewer: Is this backwards compatible?

Author: Yes, this generates the same parameters as the old way of doing it, but takes the env variables into account as well.

Reviewer: Not quite. I don't have enough context to know what effect that has, but instantiating a ModelStep with the same model will not produce the same parameters as before.

Author: You are right - it shouldn't produce the same parameters. It is a good point to consider why those params were initially omitted, but I think it makes sense to have consistency between the parameters generated from a FrameworkModel and the ones generated from a Model.

Reviewer: Can we add the tests that would have caught this? Our coverage might be limited.

Reviewer: Do the commits where they were introduced provide that context? In any case, we need to preserve existing behaviour. Customers who upgrade without changing any of their code should be able to do so without unexpected mutations.

Author: Yes, I will add tests to ensure there are no regressions and/or breaking changes.

Author: It was included in the initial commit, so no insight into why they were omitted.

Discussion on the previously hardcoded `'Environment': {}` line:

Reviewer: The original change that this one supersedes (#84) only modified this line. The rationale/need for this change isn't quite captured in the issue this is resolving. Although it's supported, we need to get into why we are making this change. As it modifies your control flow, I also suggest adding tests for all model types.

Author: I opted for this change since I saw it was supported in the sagemaker sdk since the issue was opened, but I agree that we should not introduce breaking changes - even more so if the issue does not capture the need to. I'll revert the changes, go for the solution that was proposed in #84, and add tests to validate that behaviour is the same.

Author: We have tests that use a FrameworkModel (ex: test_training_step_creation_with_framework()) and others that use a Model (ex: test_training_step_creation_with_model()), but none that pass both.

Reviewer: I didn't mean in a single test (is that even possible?). I was just verifying that both parts of the control flow are being tested. I would assume that if tests already existed, one of them would have broken with the attempt to introduce a breaking change. If that didn't happen, I think it surfaces a gap in testing, and we should use this opportunity to plug it.

Author: I added a test in the latest commit that confirms we are consistent.
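The backwards-compatibility concern above is easiest to see in isolation. The sketch below (with `model_step_parameters` and `FakeModel` as hypothetical stand-ins, not SDK code) shows the #84-style fix: keep the hand-built dict for plain `Model` objects and forward `model.env` instead of hardcoding `{}`:

```python
# Sketch of the parameters a ModelStep renders for a plain Model; the field
# names mirror the dict removed in the diff above.

def model_step_parameters(model, model_name=None):
    """Build CreateModel parameters, forwarding the model's env vars."""
    return {
        'ExecutionRoleArn': model.role,
        'ModelName': model_name or model.name,
        'PrimaryContainer': {
            # Previously hardcoded to {}; the minimal fix forwards model.env.
            'Environment': model.env or {},
            'Image': model.image_uri,
            'ModelDataUrl': model.model_data,
        },
    }

class FakeModel:
    role = 'arn:aws:iam::123456789012:role/SageMakerRole'
    name = 'model_name'
    env = {'JobName': 'job_name'}
    image_uri = 'pca-image'
    model_data = 's3://sagemaker/models/model.tar.gz'

params = model_step_parameters(FakeModel())
assert params['PrimaryContainer']['Environment'] == {'JobName': 'job_name'}
```

A model with no env vars still yields `'Environment': {}`, so existing users see identical output, which is exactly the "no unexpected mutations" property the reviewers asked to preserve.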
```diff
@@ -68,6 +68,37 @@ def pca_estimator():
     return pca


+@pytest.fixture
+def pca_estimator_with_env():
+    s3_output_location = 's3://sagemaker/models'
+
+    pca = sagemaker.estimator.Estimator(
+        PCA_IMAGE,
+        role=EXECUTION_ROLE,
+        instance_count=1,
+        instance_type='ml.c4.xlarge',
+        output_path=s3_output_location,
+        environment={
+            'JobName': "job_name",
+            'ModelName': "model_name"
+        }
+    )
+
+    pca.set_hyperparameters(
+        feature_dim=50000,
+        num_components=10,
+        subtract_mean=True,
+        algorithm_mode='randomized',
+        mini_batch_size=200
+    )
+
+    pca.sagemaker_session = MagicMock()
+    pca.sagemaker_session.boto_region_name = 'us-east-1'
+    pca.sagemaker_session._default_bucket = 'sagemaker'
+
+    return pca
+
+
 @pytest.fixture
 def pca_estimator_with_debug_hook():
     s3_output_location = 's3://sagemaker/models'
```
```diff
@@ -498,6 +529,63 @@ def test_training_step_creation_with_model(pca_estimator):
     }


+@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
+@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
+def test_training_step_creation_with_model_with_env(pca_estimator_with_env):
+    training_step = TrainingStep('Training', estimator=pca_estimator_with_env, job_name='TrainingJob')
+    model_step = ModelStep('Training - Save Model', training_step.get_expected_model(model_name=training_step.output()['TrainingJobName']))
+    training_step.next(model_step)
+    assert training_step.to_dict() == {
+        'Type': 'Task',
+        'Parameters': {
+            'AlgorithmSpecification': {
+                'TrainingImage': PCA_IMAGE,
+                'TrainingInputMode': 'File'
+            },
+            'OutputDataConfig': {
+                'S3OutputPath': 's3://sagemaker/models'
+            },
+            'StoppingCondition': {
+                'MaxRuntimeInSeconds': 86400
+            },
+            'ResourceConfig': {
+                'InstanceCount': 1,
+                'InstanceType': 'ml.c4.xlarge',
+                'VolumeSizeInGB': 30
+            },
+            'RoleArn': EXECUTION_ROLE,
+            'HyperParameters': {
+                'feature_dim': '50000',
+                'num_components': '10',
+                'subtract_mean': 'True',
+                'algorithm_mode': 'randomized',
+                'mini_batch_size': '200'
+            },
+            'TrainingJobName': 'TrainingJob'
+        },
+        'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
+        'Next': 'Training - Save Model'
+    }
+
+    assert model_step.to_dict() == {
+        'Type': 'Task',
+        'Resource': 'arn:aws:states:::sagemaker:createModel',
+        'Parameters': {
+            'ExecutionRoleArn': EXECUTION_ROLE,
+            'ModelName.$': "$['TrainingJobName']",
+            'PrimaryContainer': {
+                'Environment': {
+                    'JobName': 'job_name',
+                    'ModelName': 'model_name'
+                },
+                'Image': PCA_IMAGE,
+                'ModelDataUrl.$': "$['ModelArtifacts']['S3ModelArtifacts']"
+            }
+        },
+        'End': True
+    }
+
+
 @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
 @patch.object(boto3.session.Session, 'region_name', 'us-east-1')
 def test_training_step_creation_with_framework(tensorflow_estimator):
```

Review discussion on this test:

Author: This test is using a new model with a defined env, vpc_config and image_config.

Reviewer: (Not blocking, but please follow this guidance for future contributions.) You're following the existing conventions here, but let's try to structure the tests as documentation so they are easier to read in the future.

Author: ACK
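The `'ModelName.$'` and `'ModelDataUrl.$'` keys in the expected output follow the Step Functions convention that a parameter key ending in `.$` takes its value from a JSONPath applied to the state input at runtime, rather than being passed literally. A toy resolver (illustrative only; `resolve_parameters` is not an SDK function, and the real service supports full JSONPath, not just bracketed keys) shows the substitution:

```python
import re

def resolve_parameters(parameters, state_input):
    """Toy resolver for Step Functions-style '.$' parameter keys.

    Only handles flat paths like $['a'] and $['a']['b']; illustration only.
    """
    resolved = {}
    for key, value in parameters.items():
        if key.endswith('.$'):
            # Walk each ['part'] of the path into the state input.
            result = state_input
            for part in re.findall(r"\['([^']+)'\]", value):
                result = result[part]
            resolved[key[:-2]] = result  # strip the '.$' suffix
        elif isinstance(value, dict):
            resolved[key] = resolve_parameters(value, state_input)
        else:
            resolved[key] = value
    return resolved

params = {
    'ModelName.$': "$['TrainingJobName']",
    'PrimaryContainer': {'ModelDataUrl.$': "$['ModelArtifacts']['S3ModelArtifacts']"},
}
out = resolve_parameters(params, {
    'TrainingJobName': 'TrainingJob',
    'ModelArtifacts': {'S3ModelArtifacts': 's3://bucket/model.tar.gz'},
})
assert out == {'ModelName': 'TrainingJob',
               'PrimaryContainer': {'ModelDataUrl': 's3://bucket/model.tar.gz'}}
```

This is why the test asserts on the `.$` keys rather than concrete values: the model name and artifact URL are only known once the training step's output flows into the model step.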
Discussion on the scope of this PR:

Reviewer: The CR description should also capture why we're doing this. It's not strictly related to the bug reported, but it is solving an issue. Typically it's better to fix things separately, but when there are multiple fixes, they need to be captured in the commit summary, along with the rationale and any testing that was performed.

Author: Updated the description with the uncovered bug and what was done to fix it. I am open to moving this fix to a separate PR - I'll go with whatever reviewers prefer since the review process has already begun. What do you think? @shivlaks @wong-a

Reviewer: Given where we currently are, I'm not strongly opinionated, because it feels more practical to go with it. It feels like it might just be busy work splitting them up in the current state. Generally, and going forward, I think we should keep changes contained and specific to the bug they address / feature they introduce for a few reasons. Having said that, I would probably still split it up if I were the one doing it.

Reviewer: If it's already implemented and works, let's keep it in here, since it's a small enough change. Just update the PR description accordingly.

Author: Done :)