|
1 |
| -# %load /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/sagemaker/automl/automl.py |
2 | 1 | # Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
3 | 2 | #
|
4 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You
|
@@ -102,60 +101,64 @@ def fit(self, inputs=None, wait=True, logs=True, job_name=None):
|
102 | 101 | self.latest_auto_ml_job.wait(logs=logs)
|
103 | 102 |
|
104 | 103 | @classmethod
|
105 |
| - def attach(cls, job_name, sagemaker_session=None): |
| 104 | + def attach(cls, auto_ml_job_name, sagemaker_session=None): |
106 | 105 | """Attach to an existing AutoML job.
|
107 | 106 |
|
| 107 | + Creates and returns a AutoML bound to an existing automl job. |
| 108 | +
|
108 | 109 | Args:
|
109 |
| - job_name (str): AutoML job name |
| 110 | + auto_ml_job_name (str): AutoML job name |
110 | 111 | sagemaker_session (sagemaker.session.Session): A SageMaker Session
|
111 | 112 | object, used for SageMaker interactions (default: None). If not
|
112 |
| - specified, the one originally associated with the ``AutoML`` instance is used.: |
| 113 | + specified, the one originally associated with the ``AutoML`` instance is used. |
113 | 114 |
|
114 | 115 | Returns:
|
| 116 | + sagemaker.automl.AutoML: A ``AutoML`` instance with the attached automl job. |
115 | 117 |
|
116 | 118 | """
|
117 | 119 | sagemaker_session = sagemaker_session or Session()
|
118 | 120 |
|
119 |
| - _auto_ml_job_desc = sagemaker_session.describe_auto_ml_job(job_name) |
| 121 | + auto_ml_job_desc = sagemaker_session.describe_auto_ml_job(auto_ml_job_name) |
120 | 122 | automl_job_tags = sagemaker_session.sagemaker_client.list_tags(
|
121 |
| - ResourceArn=_auto_ml_job_desc["AutoMLJobArn"] |
| 123 | + ResourceArn=auto_ml_job_desc["AutoMLJobArn"] |
122 | 124 | )["Tags"]
|
123 | 125 |
|
124 | 126 | amlj = AutoML(
|
125 |
| - role=_auto_ml_job_desc["RoleArn"], |
126 |
| - target_attribute_name=_auto_ml_job_desc["InputDataConfig"][0]["TargetAttributeName"], |
127 |
| - output_kms_key=_auto_ml_job_desc["OutputDataConfig"].get("KmsKeyId"), |
128 |
| - output_path=_auto_ml_job_desc["OutputDataConfig"]["S3OutputPath"], |
129 |
| - base_job_name=job_name, |
130 |
| - compression_type=_auto_ml_job_desc["InputDataConfig"][0].get("CompressionType"), |
| 127 | + role=auto_ml_job_desc["RoleArn"], |
| 128 | + target_attribute_name=auto_ml_job_desc["InputDataConfig"][0]["TargetAttributeName"], |
| 129 | + output_kms_key=auto_ml_job_desc["OutputDataConfig"].get("KmsKeyId"), |
| 130 | + output_path=auto_ml_job_desc["OutputDataConfig"]["S3OutputPath"], |
| 131 | + base_job_name=auto_ml_job_name, |
| 132 | + compression_type=auto_ml_job_desc["InputDataConfig"][0].get("CompressionType"), |
131 | 133 | sagemaker_session=sagemaker_session,
|
132 |
| - volume_kms_key=_auto_ml_job_desc.get("AutoMLJobConfig", {}) |
| 134 | + volume_kms_key=auto_ml_job_desc.get("AutoMLJobConfig", {}) |
133 | 135 | .get("SecurityConfig", {})
|
134 | 136 | .get("VolumeKmsKeyId"),
|
135 |
| - encrypt_inter_container_traffic=_auto_ml_job_desc.get("AutoMLJobConfig", {}) |
| 137 | + encrypt_inter_container_traffic=auto_ml_job_desc.get("AutoMLJobConfig", {}) |
136 | 138 | .get("SecurityConfig", {})
|
137 | 139 | .get("EnableInterContainerTrafficEncryption", False),
|
138 |
| - vpc_config=_auto_ml_job_desc.get("AutoMLJobConfig", {}) |
| 140 | + vpc_config=auto_ml_job_desc.get("AutoMLJobConfig", {}) |
139 | 141 | .get("SecurityConfig", {})
|
140 | 142 | .get("VpcConfig"),
|
141 |
| - problem_type=_auto_ml_job_desc.get("ProblemType"), |
142 |
| - max_candidates=_auto_ml_job_desc.get("AutoMLJobConfig", {}) |
| 143 | + problem_type=auto_ml_job_desc.get("ProblemType"), |
| 144 | + max_candidates=auto_ml_job_desc.get("AutoMLJobConfig", {}) |
143 | 145 | .get("CompletionCriteria", {})
|
144 | 146 | .get("MaxCandidates"),
|
145 |
| - max_runtime_per_training_job_in_seconds=_auto_ml_job_desc.get("AutoMLJobConfig", {}) |
| 147 | + max_runtime_per_training_job_in_seconds=auto_ml_job_desc.get("AutoMLJobConfig", {}) |
146 | 148 | .get("CompletionCriteria", {})
|
147 | 149 | .get("MaxRuntimePerTrainingJobInSeconds"),
|
148 |
| - total_job_runtime_in_seconds=_auto_ml_job_desc.get("AutoMLJobConfig", {}) |
| 150 | + total_job_runtime_in_seconds=auto_ml_job_desc.get("AutoMLJobConfig", {}) |
149 | 151 | .get("CompletionCriteria", {})
|
150 | 152 | .get("MaxAutoMLJobRuntimeInSeconds"),
|
151 |
| - job_objective=_auto_ml_job_desc.get("AutoMLJobObjective", {}).get("MetricName"), |
152 |
| - generate_candidate_definitions_only=_auto_ml_job_desc.get( |
| 153 | + job_objective=auto_ml_job_desc.get("AutoMLJobObjective", {}).get("MetricName"), |
| 154 | + generate_candidate_definitions_only=auto_ml_job_desc.get( |
153 | 155 | "GenerateCandidateDefinitionsOnly", False
|
154 | 156 | ),
|
155 | 157 | tags=automl_job_tags,
|
156 | 158 | )
|
157 |
| - amlj.current_job_name = job_name |
158 |
| - amlj._auto_ml_job_desc = _auto_ml_job_desc |
| 159 | + amlj.current_job_name = auto_ml_job_name |
| 160 | + amlj.latest_auto_ml_job = auto_ml_job_name # pylint: disable=W0201 |
| 161 | + amlj._auto_ml_job_desc = auto_ml_job_desc |
159 | 162 | return amlj
|
160 | 163 |
|
161 | 164 | def describe_auto_ml_job(self, job_name=None):
|
|
0 commit comments