diff --git a/src/smexperiments/api_types.py b/src/smexperiments/api_types.py index dfbacef..6b455a7 100644 --- a/src/smexperiments/api_types.py +++ b/src/smexperiments/api_types.py @@ -272,6 +272,85 @@ def __init__(self, code=None, message=None, metric_index=None, **kwargs): super(BatchPutMetricsError, self).__init__(code=code, message=message, metric_index=metric_index, **kwargs) +class TrainingJobSearchResult(_base_types.ApiObject): + """Summary model of an Training Job search result. + + Attributes: + training_job_name (str): The name of the training job. + training_job_arn (str): The Amazon Resource Name (ARN) of the training job. + tuning_job_arn (str): The Amazon Resource Name (ARN) of the associated. + hyperparameter tuning job if the training job was launched by a hyperparameter tuning job. + labeling_job_arn (str): The Amazon Resource Name (ARN) of the labeling job. + autoML_job_arn (str): The Amazon Resource Name (ARN) of the job. + model_artifacts (dict): Information about the Amazon S3 location that is configured for storing model artifacts. + training_job_status (str): The status of the training job + hyper_parameters (dict): Algorithm-specific parameters. + algorithm_specification (dict): Information about the algorithm used for training, and algorithm metadata. + input_data_config (dict): An array of Channel objects that describes each data input channel. + output_data_config (dict): The S3 path where model artifacts that you configured when creating the job are + stored. Amazon SageMaker creates subfolders for model artifacts. + resource_config (dict): Resources, including ML compute instances and ML storage volumes, that are configured + for model training. + debug_hook_config (dict): Configuration information for the debug hook parameters, collection configuration, + and storage paths. + debug_rule_config (dict): Information about the debug rule configuration. + """ + + training_job_name = None + training_job_arn = None + tuning_job_arn = None + labeling_job_arn = None + autoML_job_arn = None + model_artifacts = None + training_job_status = None + hyper_parameters = None + algorithm_specification = None + input_data_config = None + output_data_config = None + resource_config = None + debug_hook_config = None + experiment_config = None + debug_rule_config = None + + def __init__( + self, + training_job_arn=None, + training_job_name=None, + tuning_job_arn=None, + labeling_job_arn=None, + autoML_job_arn=None, + model_artifacts=None, + training_job_status=None, + hyper_parameters=None, + algorithm_specification=None, + input_data_config=None, + output_data_config=None, + resource_config=None, + debug_hook_config=None, + experiment_config=None, + debug_rule_config=None, + **kwargs + ): + super(TrainingJobSearchResult, self).__init__( + training_job_arn=training_job_arn, + training_job_name=training_job_name, + tuning_job_arn=tuning_job_arn, + labeling_job_arn=labeling_job_arn, + autoML_job_arn=autoML_job_arn, + model_artifacts=model_artifacts, + training_job_status=training_job_status, + hyper_parameters=hyper_parameters, + algorithm_specification=algorithm_specification, + input_data_config=input_data_config, + output_data_config=output_data_config, + resource_config=resource_config, + debug_hook_config=debug_hook_config, + experiment_config=experiment_config, + debug_rule_config=debug_rule_config, + **kwargs + ) + + class ExperimentSearchResult(_base_types.ApiObject): """Summary model of an Experiment search result. @@ -369,12 +448,6 @@ class TrialComponentSearchResult(_base_types.ApiObject): display_name = None source = None status = None - start_time = None - end_time = None - creation_time = None - created_by = None - last_modified_time = None - last_modified_by = None parameters = None input_artifacts = None output_artifacts = None @@ -392,10 +465,6 @@ def __init__( display_name=None, source=None, status=None, - creation_time=None, - created_by=None, - last_modified_time=None, - last_modified_by=None, parameters=None, input_artifacts=None, output_artifacts=None, @@ -403,6 +472,7 @@ def __init__( source_detail=None, tags=None, parents=None, + **kwargs ): super(TrialComponentSearchResult, self).__init__( trial_component_arn=trial_component_arn, @@ -410,12 +480,6 @@ def __init__( display_name=display_name, source=source, status=status, - start_time=start_time, - end_time=end_time, - creation_time=creation_time, - created_by=created_by, - last_modified_by=last_modified_by, - last_modified_time=last_modified_time, parameters=parameters, input_artifacts=input_artifacts, output_artifacts=output_artifacts, @@ -423,4 +487,5 @@ def __init__( source_detail=source_detail, tags=tags, parents=parents, + **kwargs ) diff --git a/src/smexperiments/training_job.py b/src/smexperiments/training_job.py new file mode 100644 index 0000000..c213648 --- /dev/null +++ b/src/smexperiments/training_job.py @@ -0,0 +1,49 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +"""Contains the SageMaker Training Job class.""" +from smexperiments import _base_types, api_types + + +class TrainingJob(_base_types.Record): + @classmethod + def search( + cls, search_expression=None, sort_by=None, sort_order=None, max_results=None, sagemaker_boto_client=None, + ): + """ + Search Training Job. Returns SearchResults in the account matching the search criteria. + + Args: + search_expression: (dict, optional): A Boolean conditional statement. Resource objects + must satisfy this condition to be included in search results. You must provide at + least one subexpression, filter, or nested filter. + sort_by (str, optional): The name of the resource property used to sort the SearchResults. + The default is LastModifiedTime + sort_order (str, optional): How SearchResults are ordered. Valid values are Ascending or + Descending . The default is Descending . + max_results (int, optional): The maximum number of results to return in a SearchResponse. + sagemaker_boto_client (SageMaker.Client, optional): Boto3 client for SageMaker. If not + supplied, a default boto3 client will be used. + + Returns: + collections.Iterator[SearchResult] : An iterator over search results matching the search criteria. + """ + return super(TrainingJob, cls)._search( + search_resource="TrainingJob", + search_item_factory=api_types.TrainingJobSearchResult.from_boto, + search_expression=None if search_expression is None else search_expression.to_boto(), + sort_by=sort_by, + sort_order=sort_order, + max_results=max_results, + sagemaker_boto_client=sagemaker_boto_client, + ) diff --git a/tests/conftest.py b/tests/conftest.py index b0c54ae..6fc1a8f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -264,11 +264,7 @@ def training_job_name(sagemaker_boto_client, training_role_arn, docker_image, tr "DataSource": {"S3DataSource": {"S3Uri": training_s3_uri, "S3DataType": "S3Prefix"}}, } ], - AlgorithmSpecification={ - "TrainingImage": docker_image, - "TrainingInputMode": "File", - "EnableSageMakerMetricsTimeSeries": True, - }, + AlgorithmSpecification={"TrainingImage": docker_image, "TrainingInputMode": "File",}, RoleArn=training_role_arn, ResourceConfig={"InstanceType": "ml.m5.large", "InstanceCount": 1, "VolumeSizeInGB": 10}, StoppingCondition={"MaxRuntimeInSeconds": 900}, diff --git a/tests/integ/test_track_from_processing_job.py b/tests/integ/test_track_from_processing_job.py index e200c24..73d1826 100644 --- a/tests/integ/test_track_from_processing_job.py +++ b/tests/integ/test_track_from_processing_job.py @@ -20,12 +20,11 @@ @pytest.mark.slow def test_track_from_processing_job(sagemaker_boto_client, processing_job_name): - get_job = lambda: sagemaker_boto_client.describe_processing_job(ProcessingJobName=processing_job_name) processing_job = get_job() source_arn = processing_job["ProcessingJobArn"] - # wait_for_job(processing_job_name, get_job, "ProcessingJobStatus") + wait_for_job(processing_job_name, get_job, "ProcessingJobStatus") print(processing_job) if "ProcessingStartTime" in processing_job: diff --git a/tests/integ/test_track_from_training_job.py b/tests/integ/test_track_from_training_job.py index 448647d..0a762b1 100644 --- a/tests/integ/test_track_from_training_job.py +++ b/tests/integ/test_track_from_training_job.py @@ -1,4 +1,4 @@ -# Copyright 019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of @@ -21,7 +21,6 @@ @pytest.mark.slow def test_track_from_training_job(sagemaker_boto_client, training_job_name): - training_job_name = "smexperiments-integ-eca5c064-3a64-433e-a30a-2963338d71d8" get_job = lambda: sagemaker_boto_client.describe_training_job(TrainingJobName=training_job_name) tj = get_job() source_arn = tj["TrainingJobArn"] diff --git a/tests/integ/test_training_job.py b/tests/integ/test_training_job.py new file mode 100644 index 0000000..6dd7fcb --- /dev/null +++ b/tests/integ/test_training_job.py @@ -0,0 +1,46 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import pytest + +from smexperiments.training_job import TrainingJob +from smexperiments.search_expression import SearchExpression, Filter, Operator +from tests.helpers import retry + + +@pytest.mark.slow +def test_search(sagemaker_boto_client, training_job_name, docker_image): + def validate(): + training_job_searched = [] + search_filter = Filter(name="TrainingJobName", operator=Operator.EQUALS, value=training_job_name) + search_expression = SearchExpression(filters=[search_filter]) + for s in TrainingJob.search( + search_expression=search_expression, max_results=10, sagemaker_boto_client=sagemaker_boto_client, + ): + training_job_searched.append(s) + + assert len(training_job_searched) == 1 + assert training_job_searched[0].training_job_name == training_job_name + assert training_job_searched[0].input_data_config[0]["ChannelName"] == "train" + assert training_job_searched[0].algorithm_specification == { + "TrainingImage": docker_image, + "TrainingInputMode": "File", + } + assert training_job_searched[0].resource_config == { + "InstanceType": "ml.m5.large", + "InstanceCount": 1, + "VolumeSizeInGB": 10, + } + assert training_job_searched[0].stopping_condition == {"MaxRuntimeInSeconds": 900} + assert training_job_searched # sanity test + + retry(validate) diff --git a/tests/unit/test_training_job.py b/tests/unit/test_training_job.py new file mode 100644 index 0000000..a2e31a2 --- /dev/null +++ b/tests/unit/test_training_job.py @@ -0,0 +1,55 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import pytest +import unittest.mock + +from smexperiments import training_job, api_types + + +@pytest.fixture +def sagemaker_boto_client(): + return unittest.mock.Mock() + + +def test_search(sagemaker_boto_client): + sagemaker_boto_client.search.return_value = { + "Results": [ + { + "TrainingJob": { + "TrainingJobName": "training-1", + "TrainingJobArn": "arn::training-1", + "HyperParameters": {"learning_rate": "0.1"}, + } + }, + { + "TrainingJob": { + "TrainingJobName": "training-2", + "TrainingJobArn": "arn::training-2", + "HyperParameters": {"learning_rate": "0.2"}, + } + }, + ] + } + expected = [ + api_types.TrainingJobSearchResult( + training_job_name="training-1", + training_job_arn="arn::training-1", + hyper_parameters={"learning_rate": "0.1"}, + ), + api_types.TrainingJobSearchResult( + training_job_name="training-2", + training_job_arn="arn::training-2", + hyper_parameters={"learning_rate": "0.2"}, + ), + ] + assert expected == list(training_job.TrainingJob.search(sagemaker_boto_client=sagemaker_boto_client))