-
Notifications
You must be signed in to change notification settings - Fork 36
Add search for training job and fix slow tests #71
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -272,6 +272,85 @@ def __init__(self, code=None, message=None, metric_index=None, **kwargs): | |
super(BatchPutMetricsError, self).__init__(code=code, message=message, metric_index=metric_index, **kwargs) | ||
|
||
|
||
class TrainingJobSearchResult(_base_types.ApiObject): | ||
"""Summary model of an Training Job search result. | ||
|
||
Attributes: | ||
training_job_name (str): The name of the training job. | ||
training_job_arn (str): The Amazon Resource Name (ARN) of the training job. | ||
tuning_job_arn (str): The Amazon Resource Name (ARN) of the associated. | ||
hyperparameter tuning job if the training job was launched by a hyperparameter tuning job. | ||
labeling_job_arn (str): The Amazon Resource Name (ARN) of the labeling job. | ||
autoML_job_arn (str): The Amazon Resource Name (ARN) of the job. | ||
model_artifacts (dict): Information about the Amazon S3 location that is configured for storing model artifacts. | ||
training_job_status (str): The status of the training job | ||
hyper_parameters (dict): Algorithm-specific parameters. | ||
algorithm_specification (dict): Information about the algorithm used for training, and algorithm metadata. | ||
input_data_config (dict): An array of Channel objects that describes each data input channel. | ||
output_data_config (dict): The S3 path where model artifacts that you configured when creating the job are | ||
stored. Amazon SageMaker creates subfolders for model artifacts. | ||
resource_config (dict): Resources, including ML compute instances and ML storage volumes, that are configured | ||
for model training. | ||
debug_hook_config (dict): Configuration information for the debug hook parameters, collection configuration, | ||
and storage paths. | ||
debug_rule_config (dict): Information about the debug rule configuration. | ||
""" | ||
|
||
training_job_name = None | ||
training_job_arn = None | ||
tuning_job_arn = None | ||
labeling_job_arn = None | ||
autoML_job_arn = None | ||
model_artifacts = None | ||
training_job_status = None | ||
hyper_parameters = None | ||
algorithm_specification = None | ||
input_data_config = None | ||
output_data_config = None | ||
resource_config = None | ||
debug_hook_config = None | ||
experiment_config = None | ||
debug_rule_config = None | ||
|
||
def __init__( | ||
self, | ||
training_job_arn=None, | ||
training_job_name=None, | ||
tuning_job_arn=None, | ||
labeling_job_arn=None, | ||
autoML_job_arn=None, | ||
model_artifacts=None, | ||
training_job_status=None, | ||
hyper_parameters=None, | ||
algorithm_specification=None, | ||
input_data_config=None, | ||
output_data_config=None, | ||
resource_config=None, | ||
debug_hook_config=None, | ||
experiment_config=None, | ||
debug_rule_config=None, | ||
**kwargs | ||
): | ||
super(TrainingJobSearchResult, self).__init__( | ||
training_job_arn=training_job_arn, | ||
training_job_name=training_job_name, | ||
tuning_job_arn=tuning_job_arn, | ||
labeling_job_arn=labeling_job_arn, | ||
autoML_job_arn=autoML_job_arn, | ||
model_artifacts=model_artifacts, | ||
training_job_status=training_job_status, | ||
hyper_parameters=hyper_parameters, | ||
algorithm_specification=algorithm_specification, | ||
input_data_config=input_data_config, | ||
output_data_config=output_data_config, | ||
resource_config=resource_config, | ||
debug_hook_config=debug_hook_config, | ||
experiment_config=experiment_config, | ||
debug_rule_config=debug_rule_config, | ||
**kwargs | ||
) | ||
|
||
|
||
class ExperimentSearchResult(_base_types.ApiObject): | ||
"""Summary model of an Experiment search result. | ||
|
||
|
@@ -369,12 +448,6 @@ class TrialComponentSearchResult(_base_types.ApiObject): | |
display_name = None | ||
source = None | ||
status = None | ||
start_time = None | ||
end_time = None | ||
creation_time = None | ||
created_by = None | ||
last_modified_time = None | ||
last_modified_by = None | ||
Comment on lines
-374
to
-377
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. these are in the response see DescribeExperiment.xml There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah, I removed those since those will be covered by **kwargs |
||
parameters = None | ||
input_artifacts = None | ||
output_artifacts = None | ||
|
@@ -392,35 +465,27 @@ def __init__( | |
display_name=None, | ||
source=None, | ||
status=None, | ||
creation_time=None, | ||
created_by=None, | ||
last_modified_time=None, | ||
last_modified_by=None, | ||
parameters=None, | ||
input_artifacts=None, | ||
output_artifacts=None, | ||
metrics=None, | ||
source_detail=None, | ||
tags=None, | ||
parents=None, | ||
**kwargs | ||
): | ||
super(TrialComponentSearchResult, self).__init__( | ||
trial_component_arn=trial_component_arn, | ||
trial_component_name=trial_component_name, | ||
display_name=display_name, | ||
source=source, | ||
status=status, | ||
start_time=start_time, | ||
end_time=end_time, | ||
creation_time=creation_time, | ||
created_by=created_by, | ||
last_modified_by=last_modified_by, | ||
last_modified_time=last_modified_time, | ||
parameters=parameters, | ||
input_artifacts=input_artifacts, | ||
output_artifacts=output_artifacts, | ||
metrics=metrics, | ||
source_detail=source_detail, | ||
tags=tags, | ||
parents=parents, | ||
**kwargs | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
|
||
"""Contains the SageMaker Training Job class.""" | ||
from smexperiments import _base_types, api_types | ||
|
||
|
||
class TrainingJob(_base_types.Record): | ||
@classmethod | ||
def search( | ||
cls, search_expression=None, sort_by=None, sort_order=None, max_results=None, sagemaker_boto_client=None, | ||
): | ||
""" | ||
Search Training Job. Returns SearchResults in the account matching the search criteria. | ||
|
||
Args: | ||
search_expression: (dict, optional): A Boolean conditional statement. Resource objects | ||
must satisfy this condition to be included in search results. You must provide at | ||
least one subexpression, filter, or nested filter. | ||
sort_by (str, optional): The name of the resource property used to sort the SearchResults. | ||
The default is LastModifiedTime | ||
sort_order (str, optional): How SearchResults are ordered. Valid values are Ascending or | ||
Descending . The default is Descending . | ||
max_results (int, optional): The maximum number of results to return in a SearchResponse. | ||
sagemaker_boto_client (SageMaker.Client, optional): Boto3 client for SageMaker. If not | ||
supplied, a default boto3 client will be used. | ||
|
||
Returns: | ||
collections.Iterator[SearchResult] : An iterator over search results matching the search criteria. | ||
""" | ||
return super(TrainingJob, cls)._search( | ||
search_resource="TrainingJob", | ||
search_item_factory=api_types.TrainingJobSearchResult.from_boto, | ||
search_expression=None if search_expression is None else search_expression.to_boto(), | ||
sort_by=sort_by, | ||
sort_order=sort_order, | ||
max_results=max_results, | ||
sagemaker_boto_client=sagemaker_boto_client, | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
# Copyright 019 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this should be done automatically, i created an issue There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. agree |
||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
|
@@ -21,7 +21,6 @@ | |
|
||
@pytest.mark.slow | ||
def test_track_from_training_job(sagemaker_boto_client, training_job_name): | ||
training_job_name = "smexperiments-integ-eca5c064-3a64-433e-a30a-2963338d71d8" | ||
get_job = lambda: sagemaker_boto_client.describe_training_job(TrainingJobName=training_job_name) | ||
tj = get_job() | ||
source_arn = tj["TrainingJobArn"] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
import pytest | ||
|
||
from smexperiments.training_job import TrainingJob | ||
from smexperiments.search_expression import SearchExpression, Filter, Operator | ||
from tests.helpers import retry | ||
|
||
|
||
@pytest.mark.slow | ||
def test_search(sagemaker_boto_client, training_job_name, docker_image): | ||
def validate(): | ||
training_job_searched = [] | ||
search_filter = Filter(name="TrainingJobName", operator=Operator.EQUALS, value=training_job_name) | ||
search_expression = SearchExpression(filters=[search_filter]) | ||
for s in TrainingJob.search( | ||
search_expression=search_expression, max_results=10, sagemaker_boto_client=sagemaker_boto_client, | ||
): | ||
training_job_searched.append(s) | ||
|
||
assert len(training_job_searched) == 1 | ||
assert training_job_searched[0].training_job_name == training_job_name | ||
assert training_job_searched[0].input_data_config[0]["ChannelName"] == "train" | ||
assert training_job_searched[0].algorithm_specification == { | ||
"TrainingImage": docker_image, | ||
"TrainingInputMode": "File", | ||
} | ||
assert training_job_searched[0].resource_config == { | ||
"InstanceType": "ml.m5.large", | ||
"InstanceCount": 1, | ||
"VolumeSizeInGB": 10, | ||
} | ||
assert training_job_searched[0].stopping_condition == {"MaxRuntimeInSeconds": 900} | ||
assert training_job_searched # sanity test | ||
|
||
retry(validate) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
import pytest | ||
import unittest.mock | ||
|
||
from smexperiments import training_job, api_types | ||
|
||
|
||
@pytest.fixture | ||
def sagemaker_boto_client(): | ||
return unittest.mock.Mock() | ||
|
||
|
||
def test_search(sagemaker_boto_client): | ||
sagemaker_boto_client.search.return_value = { | ||
"Results": [ | ||
{ | ||
"TrainingJob": { | ||
"TrainingJobName": "training-1", | ||
"TrainingJobArn": "arn::training-1", | ||
"HyperParameters": {"learning_rate": "0.1"}, | ||
} | ||
}, | ||
{ | ||
"TrainingJob": { | ||
"TrainingJobName": "training-2", | ||
"TrainingJobArn": "arn::training-2", | ||
"HyperParameters": {"learning_rate": "0.2"}, | ||
} | ||
}, | ||
] | ||
} | ||
expected = [ | ||
api_types.TrainingJobSearchResult( | ||
training_job_name="training-1", | ||
training_job_arn="arn::training-1", | ||
hyper_parameters={"learning_rate": "0.1"}, | ||
), | ||
api_types.TrainingJobSearchResult( | ||
training_job_name="training-2", | ||
training_job_arn="arn::training-2", | ||
hyper_parameters={"learning_rate": "0.2"}, | ||
), | ||
] | ||
assert expected == list(training_job.TrainingJob.search(sagemaker_boto_client=sagemaker_boto_client)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
how did you construct this list of fields? what about some other fields like SecondaryStatus and FailureReason?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added the ones I think are most useful, but others should be covered by **kwargs
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
seems incomplete tho, we will get back all the fields in search regardless of what is in kwargs. i guess we can follow up later and add the remaining fields.