Skip to content

Commit 0324152

Browse files
author
pialidas
committed
PR comments addressed
1 parent baad38e commit 0324152

File tree

3 files changed

+28
-25
lines changed

3 files changed

+28
-25
lines changed

src/sagemaker/automl/automl.py

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# %load /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/sagemaker/automl/automl.py
21
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
32
#
43
# Licensed under the Apache License, Version 2.0 (the "License"). You
@@ -102,60 +101,64 @@ def fit(self, inputs=None, wait=True, logs=True, job_name=None):
102101
self.latest_auto_ml_job.wait(logs=logs)
103102

104103
@classmethod
105-
def attach(cls, job_name, sagemaker_session=None):
104+
def attach(cls, auto_ml_job_name, sagemaker_session=None):
106105
"""Attach to an existing AutoML job.
107106
107+
Creates and returns a AutoML bound to an existing automl job.
108+
108109
Args:
109-
job_name (str): AutoML job name
110+
auto_ml_job_name (str): AutoML job name
110111
sagemaker_session (sagemaker.session.Session): A SageMaker Session
111112
object, used for SageMaker interactions (default: None). If not
112-
specified, the one originally associated with the ``AutoML`` instance is used.:
113+
specified, the one originally associated with the ``AutoML`` instance is used.
113114
114115
Returns:
116+
sagemaker.automl.AutoML: A ``AutoML`` instance with the attached automl job.
115117
116118
"""
117119
sagemaker_session = sagemaker_session or Session()
118120

119-
_auto_ml_job_desc = sagemaker_session.describe_auto_ml_job(job_name)
121+
auto_ml_job_desc = sagemaker_session.describe_auto_ml_job(auto_ml_job_name)
120122
automl_job_tags = sagemaker_session.sagemaker_client.list_tags(
121-
ResourceArn=_auto_ml_job_desc["AutoMLJobArn"]
123+
ResourceArn=auto_ml_job_desc["AutoMLJobArn"]
122124
)["Tags"]
123125

124126
amlj = AutoML(
125-
role=_auto_ml_job_desc["RoleArn"],
126-
target_attribute_name=_auto_ml_job_desc["InputDataConfig"][0]["TargetAttributeName"],
127-
output_kms_key=_auto_ml_job_desc["OutputDataConfig"].get("KmsKeyId"),
128-
output_path=_auto_ml_job_desc["OutputDataConfig"]["S3OutputPath"],
129-
base_job_name=job_name,
130-
compression_type=_auto_ml_job_desc["InputDataConfig"][0].get("CompressionType"),
127+
role=auto_ml_job_desc["RoleArn"],
128+
target_attribute_name=auto_ml_job_desc["InputDataConfig"][0]["TargetAttributeName"],
129+
output_kms_key=auto_ml_job_desc["OutputDataConfig"].get("KmsKeyId"),
130+
output_path=auto_ml_job_desc["OutputDataConfig"]["S3OutputPath"],
131+
base_job_name=auto_ml_job_name,
132+
compression_type=auto_ml_job_desc["InputDataConfig"][0].get("CompressionType"),
131133
sagemaker_session=sagemaker_session,
132-
volume_kms_key=_auto_ml_job_desc.get("AutoMLJobConfig", {})
134+
volume_kms_key=auto_ml_job_desc.get("AutoMLJobConfig", {})
133135
.get("SecurityConfig", {})
134136
.get("VolumeKmsKeyId"),
135-
encrypt_inter_container_traffic=_auto_ml_job_desc.get("AutoMLJobConfig", {})
137+
encrypt_inter_container_traffic=auto_ml_job_desc.get("AutoMLJobConfig", {})
136138
.get("SecurityConfig", {})
137139
.get("EnableInterContainerTrafficEncryption", False),
138-
vpc_config=_auto_ml_job_desc.get("AutoMLJobConfig", {})
140+
vpc_config=auto_ml_job_desc.get("AutoMLJobConfig", {})
139141
.get("SecurityConfig", {})
140142
.get("VpcConfig"),
141-
problem_type=_auto_ml_job_desc.get("ProblemType"),
142-
max_candidates=_auto_ml_job_desc.get("AutoMLJobConfig", {})
143+
problem_type=auto_ml_job_desc.get("ProblemType"),
144+
max_candidates=auto_ml_job_desc.get("AutoMLJobConfig", {})
143145
.get("CompletionCriteria", {})
144146
.get("MaxCandidates"),
145-
max_runtime_per_training_job_in_seconds=_auto_ml_job_desc.get("AutoMLJobConfig", {})
147+
max_runtime_per_training_job_in_seconds=auto_ml_job_desc.get("AutoMLJobConfig", {})
146148
.get("CompletionCriteria", {})
147149
.get("MaxRuntimePerTrainingJobInSeconds"),
148-
total_job_runtime_in_seconds=_auto_ml_job_desc.get("AutoMLJobConfig", {})
150+
total_job_runtime_in_seconds=auto_ml_job_desc.get("AutoMLJobConfig", {})
149151
.get("CompletionCriteria", {})
150152
.get("MaxAutoMLJobRuntimeInSeconds"),
151-
job_objective=_auto_ml_job_desc.get("AutoMLJobObjective", {}).get("MetricName"),
152-
generate_candidate_definitions_only=_auto_ml_job_desc.get(
153+
job_objective=auto_ml_job_desc.get("AutoMLJobObjective", {}).get("MetricName"),
154+
generate_candidate_definitions_only=auto_ml_job_desc.get(
153155
"GenerateCandidateDefinitionsOnly", False
154156
),
155157
tags=automl_job_tags,
156158
)
157-
amlj.current_job_name = job_name
158-
amlj._auto_ml_job_desc = _auto_ml_job_desc
159+
amlj.current_job_name = auto_ml_job_name
160+
amlj.latest_auto_ml_job = auto_ml_job_name # pylint: disable=W0201
161+
amlj._auto_ml_job_desc = auto_ml_job_desc
159162
return amlj
160163

161164
def describe_auto_ml_job(self, job_name=None):

tests/integ/test_auto_ml.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def test_deploy_best_candidate(sagemaker_session, cpu_instance_type):
286286
def test_create_model_best_candidate(sagemaker_session, cpu_instance_type):
287287
auto_ml_utils.create_auto_ml_job_if_not_exist(sagemaker_session)
288288

289-
auto_ml = AutoML.attach(job_name=AUTO_ML_JOB_NAME, sagemaker_session=sagemaker_session)
289+
auto_ml = AutoML.attach(auto_ml_job_name=AUTO_ML_JOB_NAME, sagemaker_session=sagemaker_session)
290290
best_candidate = auto_ml.best_candidate()
291291

292292
with timeout(minutes=5):

tests/unit/sagemaker/automl/test_auto_ml.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -657,7 +657,7 @@ def test_create_model(sagemaker_session):
657657

658658

659659
def test_attach(sagemaker_session):
660-
aml = AutoML.attach(job_name=JOB_NAME_3, sagemaker_session=sagemaker_session)
660+
aml = AutoML.attach(auto_ml_job_name=JOB_NAME_3, sagemaker_session=sagemaker_session)
661661
assert aml.current_job_name == JOB_NAME_3
662662
assert aml.role == "mock_role_arn"
663663
assert aml.target_attribute_name == "y"

0 commit comments

Comments
 (0)