Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/smexperiments/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def create(cls, experiment_name, trial_name=None, sagemaker_boto_client=None, tr
def list(
cls,
experiment_name=None,
trial_component_name=None,
created_before=None,
created_after=None,
sort_by=None,
Expand All @@ -136,6 +137,8 @@ def list(
Args:
experiment_name (str, optional): Name of the experiment. If specified, only trials in
the experiment will be returned.
trial_component_name (str, optional): Name of the trial component. If specified, only
trials with this trial component name will be returned.
created_before (datetime.datetime, optional): Return trials created before this instant.
created_after (datetime.datetime, optional): Return trials created after this instant.
sort_by (str, optional): Which property to sort results by. One of 'Name',
Expand All @@ -153,6 +156,7 @@ def list(
api_types.TrialSummary.from_boto,
"TrialSummaries",
experiment_name=experiment_name,
trial_component_name=trial_component_name,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this depend on a new boto version? please verify min boto version specified in dependencies has this change

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from https://github.com/boto/boto3/blob/2a4efd21345597189fbe3cccaa4b5593da504cb1/.changes/1.10.20.json looks like last update on sagemaker client is happening in 1.10.20 version

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, updated

created_before=created_before,
created_after=created_after,
sort_by=sort_by,
Expand Down
17 changes: 17 additions & 0 deletions tests/integ/test_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,23 @@ def test_list(trials, sagemaker_boto_client):
assert trial_names_listed # sanity test


def test_list_with_trial_component(trials, trial_component_obj, sagemaker_boto_client):
trial_with_component = trials[0]
trial_with_component.add_trial_component(trial_component_obj)

trial_listed = [
s.trial_name
for s in trial.Trial.list(
trial_component_name=trial_component_obj.trial_component_name, sagemaker_boto_client=sagemaker_boto_client
)
]
assert len(trial_listed) == 1
assert trial_with_component.trial_name == trial_listed[0]
# clean up
trial_with_component.remove_trial_component(trial_component_obj)
assert trial_listed


def test_list_sort(trials, sagemaker_boto_client):
slack = datetime.timedelta(minutes=1)
now = datetime.datetime.now(datetime.timezone.utc)
Expand Down
17 changes: 17 additions & 0 deletions tests/unit/test_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,23 @@ def test_list_trials_with_experiment_name(sagemaker_boto_client, datetime_obj):
sagemaker_boto_client.list_trials.assert_called_with(ExperimentName="foo")


def test_list_trials_with_trial_component_name(sagemaker_boto_client, datetime_obj):
sagemaker_boto_client.list_trials.return_value = {
"TrialSummaries": [
{"TrialName": "trial-1", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj,},
{"TrialName": "trial-2", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj,},
]
}
expected = [
api_types.TrialSummary(trial_name="trial-1", creation_time=datetime_obj, last_modified_time=datetime_obj),
api_types.TrialSummary(trial_name="trial-2", creation_time=datetime_obj, last_modified_time=datetime_obj),
]
assert expected == list(
trial.Trial.list(trial_component_name="tc-foo", sagemaker_boto_client=sagemaker_boto_client)
)
sagemaker_boto_client.list_trials.assert_called_with(TrialComponentName="tc-foo")


def test_delete(sagemaker_boto_client):
obj = trial.Trial(sagemaker_boto_client, trial_name="foo")
sagemaker_boto_client.delete_trial.return_value = {}
Expand Down