Skip to content

Commit 9af182d

Browse files
authored
Merge pull request #84 from yzhu0/modifyTCUpdateParam
Enable Removal of Parameters and Artifacts under TrialComponent update
2 parents a5e8f78 + 0ee91c9 commit 9af182d

File tree

3 files changed

+40
-7
lines changed

3 files changed

+40
-7
lines changed

src/smexperiments/trial_component.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ class TrialComponent(_base_types.Record):
4040
input_artiacts (dict): Dictionary of input artifacts.
4141
output_artiacts (dict): Dictionary of output artifacts.
4242
metrics (obj): Aggregated metrics for the job.
43+
parameters_to_remove (list): The hyperparameters to remove from the component.
44+
input_artifacts_to_remove (list): The input artifacts to remove from the component.
45+
output_artifacts_to_remove (list): The output artifacts to remove from the component.
4346
"""
4447

4548
trial_component_name = None
@@ -57,6 +60,9 @@ class TrialComponent(_base_types.Record):
5760
input_artifacts = None
5861
output_artifacts = None
5962
metrics = None
63+
parameters_to_remove = None
64+
input_artifacts_to_remove = None
65+
output_artifacts_to_remove = None
6066

6167
_boto_load_method = "describe_trial_component"
6268
_boto_create_method = "create_trial_component"
@@ -81,6 +87,9 @@ class TrialComponent(_base_types.Record):
8187
"parameters",
8288
"input_artifacts",
8389
"output_artifacts",
90+
"parameters_to_remove",
91+
"input_artifacts_to_remove",
92+
"output_artifacts_to_remove",
8493
]
8594
_boto_delete_members = ["trial_component_name"]
8695

tests/integ/test_trial_component.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,17 @@ def test_save(trial_component_obj, sagemaker_boto_client):
3030
trial_component_obj.end_time = datetime.datetime.now(datetime.timezone.utc)
3131
trial_component_obj.parameters = {"foo": "bar", "whizz": 100.1}
3232
trial_component_obj.input_artifacts = {
33-
"snizz": api_types.TrialComponentArtifact(value="s3:/foo/bar", media_type="text/plain")
33+
"snizz": api_types.TrialComponentArtifact(value="s3:/foo/bar", media_type="text/plain"),
34+
"snizz1": api_types.TrialComponentArtifact(value="s3:/foo/bar2", media_type="text/plain2"),
3435
}
3536
trial_component_obj.output_artifacts = {
36-
"fly": api_types.TrialComponentArtifact(value="s3:/sky/far", media_type="away/tomorrow")
37+
"fly": api_types.TrialComponentArtifact(value="s3:/sky/far", media_type="away/tomorrow"),
38+
"fly2": api_types.TrialComponentArtifact(value="s3:/sky/far2", media_type="away/tomorrow2"),
3739
}
40+
trial_component_obj.parameters_to_remove = ["foo"]
41+
trial_component_obj.input_artifacts_to_remove = ["snizz"]
42+
trial_component_obj.output_artifacts_to_remove = ["fly2"]
43+
3844
trial_component_obj.save()
3945

4046
loaded = trial_component.TrialComponent.load(
@@ -47,9 +53,13 @@ def test_save(trial_component_obj, sagemaker_boto_client):
4753
assert trial_component_obj.start_time - loaded.start_time < datetime.timedelta(seconds=1)
4854
assert trial_component_obj.end_time - loaded.end_time < datetime.timedelta(seconds=1)
4955

50-
assert trial_component_obj.parameters == loaded.parameters
51-
assert trial_component_obj.input_artifacts == loaded.input_artifacts
52-
assert trial_component_obj.output_artifacts == loaded.output_artifacts
56+
assert loaded.parameters == {"whizz": 100.1}
57+
assert loaded.input_artifacts == {
58+
"snizz1": api_types.TrialComponentArtifact(value="s3:/foo/bar2", media_type="text/plain2")
59+
}
60+
assert loaded.output_artifacts == {
61+
"fly": api_types.TrialComponentArtifact(value="s3:/sky/far", media_type="away/tomorrow")
62+
}
5363

5464

5565
def test_load(trial_component_obj, sagemaker_boto_client):

tests/unit/test_trial_component.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,10 +213,24 @@ def test_search(sagemaker_boto_client):
213213

214214

215215
def test_save(sagemaker_boto_client):
216-
obj = trial_component.TrialComponent(sagemaker_boto_client, trial_component_name="foo", display_name="bar")
216+
obj = trial_component.TrialComponent(
217+
sagemaker_boto_client,
218+
trial_component_name="foo",
219+
display_name="bar",
220+
parameters_to_remove=["E"],
221+
input_artifacts_to_remove=["F"],
222+
output_artifacts_to_remove=["G"],
223+
)
217224
sagemaker_boto_client.update_trial_component.return_value = {}
218225
obj.save()
219-
sagemaker_boto_client.update_trial_component.assert_called_with(TrialComponentName="foo", DisplayName="bar")
226+
227+
sagemaker_boto_client.update_trial_component.assert_called_with(
228+
TrialComponentName="foo",
229+
DisplayName="bar",
230+
ParametersToRemove=["E"],
231+
InputArtifactsToRemove=["F"],
232+
OutputArtifactsToRemove=["G"],
233+
)
220234

221235

222236
def test_delete(sagemaker_boto_client):

0 commit comments

Comments
 (0)