Skip to content

Add search for training job and fix slow tests #71

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 81 additions & 16 deletions src/smexperiments/api_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,85 @@ def __init__(self, code=None, message=None, metric_index=None, **kwargs):
super(BatchPutMetricsError, self).__init__(code=code, message=message, metric_index=metric_index, **kwargs)


class TrainingJobSearchResult(_base_types.ApiObject):
    """Summary model of a Training Job search result.

    Only a subset of the training job fields is declared explicitly on this
    class. Fields that are not declared here can still be passed to
    ``__init__`` as extra keyword arguments, which are forwarded to the base
    class initializer.

    Attributes:
        training_job_name (str): The name of the training job.
        training_job_arn (str): The Amazon Resource Name (ARN) of the training job.
        tuning_job_arn (str): The Amazon Resource Name (ARN) of the associated
            hyperparameter tuning job, if the training job was launched by a
            hyperparameter tuning job.
        labeling_job_arn (str): The Amazon Resource Name (ARN) of the labeling job.
        autoML_job_arn (str): The Amazon Resource Name (ARN) of the associated AutoML job.
        model_artifacts (dict): Information about the Amazon S3 location that is configured
            for storing model artifacts.
        training_job_status (str): The status of the training job.
        hyper_parameters (dict): Algorithm-specific parameters.
        algorithm_specification (dict): Information about the algorithm used for training,
            and algorithm metadata.
        input_data_config (dict): An array of Channel objects that describes each data
            input channel.
        output_data_config (dict): The S3 path where model artifacts that you configured
            when creating the job are stored. Amazon SageMaker creates subfolders for
            model artifacts.
        resource_config (dict): Resources, including ML compute instances and ML storage
            volumes, that are configured for model training.
        debug_hook_config (dict): Configuration information for the debug hook parameters,
            collection configuration, and storage paths.
        experiment_config (dict): Configuration that associates the training job with an
            experiment, trial, and trial component.
        debug_rule_config (dict): Information about the debug rule configuration.
    """

    # NOTE(review): the SageMaker ``TrainingJob`` shape names its fields
    # ``AutoMLJobArn`` and ``DebugRuleConfigurations``. Confirm that the
    # boto-to-attribute name conversion used by ``ApiObject.from_boto``
    # actually maps them onto ``autoML_job_arn`` and ``debug_rule_config``;
    # otherwise these two attributes will always stay ``None``.
    training_job_name = None
    training_job_arn = None
    tuning_job_arn = None
    labeling_job_arn = None
    autoML_job_arn = None
    model_artifacts = None
    training_job_status = None
    hyper_parameters = None
    algorithm_specification = None
    input_data_config = None
    output_data_config = None
    resource_config = None
    debug_hook_config = None
    experiment_config = None
    debug_rule_config = None
Comment on lines +299 to +313
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how did you construct this list of fields? what about some other fields like SecondaryStatus and FailureReason?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added the ones I think are most useful, but others should be covered by **kwargs

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems incomplete tho, we will get back all the fields in search regardless of what is in kwargs. i guess we can follow up later and add the remaining fields.


def __init__(
    self,
    training_job_arn=None,
    training_job_name=None,
    tuning_job_arn=None,
    labeling_job_arn=None,
    autoML_job_arn=None,
    model_artifacts=None,
    training_job_status=None,
    hyper_parameters=None,
    algorithm_specification=None,
    input_data_config=None,
    output_data_config=None,
    resource_config=None,
    debug_hook_config=None,
    experiment_config=None,
    debug_rule_config=None,
    **kwargs
):
    """Initialise the search result from its declared fields.

    Every declared field, followed by any extra keyword arguments, is
    forwarded unchanged to the base class initializer.

    Args:
        training_job_arn: Value for the ``training_job_arn`` attribute.
        training_job_name: Value for the ``training_job_name`` attribute.
        tuning_job_arn: Value for the ``tuning_job_arn`` attribute.
        labeling_job_arn: Value for the ``labeling_job_arn`` attribute.
        autoML_job_arn: Value for the ``autoML_job_arn`` attribute.
        model_artifacts: Value for the ``model_artifacts`` attribute.
        training_job_status: Value for the ``training_job_status`` attribute.
        hyper_parameters: Value for the ``hyper_parameters`` attribute.
        algorithm_specification: Value for the ``algorithm_specification`` attribute.
        input_data_config: Value for the ``input_data_config`` attribute.
        output_data_config: Value for the ``output_data_config`` attribute.
        resource_config: Value for the ``resource_config`` attribute.
        debug_hook_config: Value for the ``debug_hook_config`` attribute.
        experiment_config: Value for the ``experiment_config`` attribute.
        debug_rule_config: Value for the ``debug_rule_config`` attribute.
        **kwargs: Any further fields, passed through to the base class as-is.
    """
    # Collect the declared fields in the same order they were forwarded before.
    declared_fields = dict(
        training_job_arn=training_job_arn,
        training_job_name=training_job_name,
        tuning_job_arn=tuning_job_arn,
        labeling_job_arn=labeling_job_arn,
        autoML_job_arn=autoML_job_arn,
        model_artifacts=model_artifacts,
        training_job_status=training_job_status,
        hyper_parameters=hyper_parameters,
        algorithm_specification=algorithm_specification,
        input_data_config=input_data_config,
        output_data_config=output_data_config,
        resource_config=resource_config,
        debug_hook_config=debug_hook_config,
        experiment_config=experiment_config,
        debug_rule_config=debug_rule_config,
    )
    # Extra keywords can never collide with the declared names above, because
    # those names are bound by the explicit parameters of this signature.
    declared_fields.update(kwargs)
    super(TrainingJobSearchResult, self).__init__(**declared_fields)


class ExperimentSearchResult(_base_types.ApiObject):
"""Summary model of an Experiment search result.

Expand Down Expand Up @@ -369,12 +448,6 @@ class TrialComponentSearchResult(_base_types.ApiObject):
display_name = None
source = None
status = None
start_time = None
end_time = None
creation_time = None
created_by = None
last_modified_time = None
last_modified_by = None
Comment on lines -374 to -377
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these are in the response see DescribeExperiment.xml

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, I removed those since those will be covered by **kwargs

parameters = None
input_artifacts = None
output_artifacts = None
Expand All @@ -392,35 +465,27 @@ def __init__(
display_name=None,
source=None,
status=None,
creation_time=None,
created_by=None,
last_modified_time=None,
last_modified_by=None,
parameters=None,
input_artifacts=None,
output_artifacts=None,
metrics=None,
source_detail=None,
tags=None,
parents=None,
**kwargs
):
super(TrialComponentSearchResult, self).__init__(
trial_component_arn=trial_component_arn,
trial_component_name=trial_component_name,
display_name=display_name,
source=source,
status=status,
start_time=start_time,
end_time=end_time,
creation_time=creation_time,
created_by=created_by,
last_modified_by=last_modified_by,
last_modified_time=last_modified_time,
parameters=parameters,
input_artifacts=input_artifacts,
output_artifacts=output_artifacts,
metrics=metrics,
source_detail=source_detail,
tags=tags,
parents=parents,
**kwargs
)
49 changes: 49 additions & 0 deletions src/smexperiments/training_job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

"""Contains the SageMaker Training Job class."""
from smexperiments import _base_types, api_types


class TrainingJob(_base_types.Record):
    """Record type exposing search over SageMaker training jobs."""

    @classmethod
    def search(
        cls,
        search_expression=None,
        sort_by=None,
        sort_order=None,
        max_results=None,
        sagemaker_boto_client=None,
    ):
        """Search training jobs in the account that match the search criteria.

        Args:
            search_expression (SearchExpression, optional): A Boolean conditional statement.
                Resource objects must satisfy this condition to be included in search
                results. You must provide at least one subexpression, filter, or nested
                filter. The object is converted with its ``to_boto()`` method before the
                request is made; when omitted, no search expression is sent.
            sort_by (str, optional): The name of the resource property used to sort the
                search results. The default is LastModifiedTime.
            sort_order (str, optional): How search results are ordered. Valid values are
                Ascending or Descending. The default is Descending.
            max_results (int, optional): The maximum number of results to return in a
                search response.
            sagemaker_boto_client (SageMaker.Client, optional): Boto3 client for SageMaker.
                If not supplied, a default boto3 client will be used.

        Returns:
            collections.Iterator[TrainingJobSearchResult]: An iterator over the training
            job search results matching the search criteria.
        """
        # Only convert when an expression was actually given; ``None`` is passed through.
        boto_search_expression = None
        if search_expression is not None:
            boto_search_expression = search_expression.to_boto()

        return super(TrainingJob, cls)._search(
            search_resource="TrainingJob",
            search_item_factory=api_types.TrainingJobSearchResult.from_boto,
            search_expression=boto_search_expression,
            sort_by=sort_by,
            sort_order=sort_order,
            max_results=max_results,
            sagemaker_boto_client=sagemaker_boto_client,
        )
6 changes: 1 addition & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,11 +264,7 @@ def training_job_name(sagemaker_boto_client, training_role_arn, docker_image, tr
"DataSource": {"S3DataSource": {"S3Uri": training_s3_uri, "S3DataType": "S3Prefix"}},
}
],
AlgorithmSpecification={
"TrainingImage": docker_image,
"TrainingInputMode": "File",
"EnableSageMakerMetricsTimeSeries": True,
},
AlgorithmSpecification={"TrainingImage": docker_image, "TrainingInputMode": "File",},
RoleArn=training_role_arn,
ResourceConfig={"InstanceType": "ml.m5.large", "InstanceCount": 1, "VolumeSizeInGB": 10},
StoppingCondition={"MaxRuntimeInSeconds": 900},
Expand Down
3 changes: 1 addition & 2 deletions tests/integ/test_track_from_processing_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@

@pytest.mark.slow
def test_track_from_processing_job(sagemaker_boto_client, processing_job_name):

get_job = lambda: sagemaker_boto_client.describe_processing_job(ProcessingJobName=processing_job_name)
processing_job = get_job()

source_arn = processing_job["ProcessingJobArn"]
# wait_for_job(processing_job_name, get_job, "ProcessingJobStatus")
wait_for_job(processing_job_name, get_job, "ProcessingJobStatus")

print(processing_job)
if "ProcessingStartTime" in processing_job:
Expand Down
3 changes: 1 addition & 2 deletions tests/integ/test_track_from_training_job.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be done automatically, i created an issue

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agree

#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
Expand All @@ -21,7 +21,6 @@

@pytest.mark.slow
def test_track_from_training_job(sagemaker_boto_client, training_job_name):
training_job_name = "smexperiments-integ-eca5c064-3a64-433e-a30a-2963338d71d8"
get_job = lambda: sagemaker_boto_client.describe_training_job(TrainingJobName=training_job_name)
tj = get_job()
source_arn = tj["TrainingJobArn"]
Expand Down
46 changes: 46 additions & 0 deletions tests/integ/test_training_job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import pytest

from smexperiments.training_job import TrainingJob
from smexperiments.search_expression import SearchExpression, Filter, Operator
from tests.helpers import retry


@pytest.mark.slow
def test_search(sagemaker_boto_client, training_job_name, docker_image):
    """Search for the fixture training job by name and check the returned fields.

    The checks live in an inner function that is handed to ``retry``.
    """

    def validate():
        name_filter = Filter(name="TrainingJobName", operator=Operator.EQUALS, value=training_job_name)
        expression = SearchExpression(filters=[name_filter])

        # Drain the search iterator so the number of hits can be asserted.
        found = list(
            TrainingJob.search(
                search_expression=expression,
                max_results=10,
                sagemaker_boto_client=sagemaker_boto_client,
            )
        )

        assert len(found) == 1
        job = found[0]
        assert job.training_job_name == training_job_name
        assert job.input_data_config[0]["ChannelName"] == "train"

        expected_algorithm_specification = {
            "TrainingImage": docker_image,
            "TrainingInputMode": "File",
        }
        assert job.algorithm_specification == expected_algorithm_specification

        expected_resource_config = {
            "InstanceType": "ml.m5.large",
            "InstanceCount": 1,
            "VolumeSizeInGB": 10,
        }
        assert job.resource_config == expected_resource_config

        assert job.stopping_condition == {"MaxRuntimeInSeconds": 900}
        assert found  # sanity test

    retry(validate)
55 changes: 55 additions & 0 deletions tests/unit/test_training_job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import pytest
import unittest.mock

from smexperiments import training_job, api_types


@pytest.fixture
def sagemaker_boto_client():
    """Provide a fresh ``unittest.mock.Mock`` used in place of the SageMaker boto client."""
    mock_client = unittest.mock.Mock()
    return mock_client


def test_search(sagemaker_boto_client):
    """Check that ``TrainingJob.search`` yields one result object per item returned by the client."""
    # (name, arn, hyper parameters) for each training job in the mocked response.
    jobs = [
        ("training-1", "arn::training-1", {"learning_rate": "0.1"}),
        ("training-2", "arn::training-2", {"learning_rate": "0.2"}),
    ]

    sagemaker_boto_client.search.return_value = {
        "Results": [
            {
                "TrainingJob": {
                    "TrainingJobName": name,
                    "TrainingJobArn": arn,
                    # Copy so the response and the expectation never share a dict.
                    "HyperParameters": dict(hyper_parameters),
                }
            }
            for name, arn, hyper_parameters in jobs
        ]
    }

    expected = [
        api_types.TrainingJobSearchResult(
            training_job_name=name,
            training_job_arn=arn,
            hyper_parameters=dict(hyper_parameters),
        )
        for name, arn, hyper_parameters in jobs
    ]

    actual = list(training_job.TrainingJob.search(sagemaker_boto_client=sagemaker_boto_client))
    assert expected == actual