|
| 1 | +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"). You |
| 4 | +# may not use this file except in compliance with the License. A copy of |
| 5 | +# the License is located at |
| 6 | +# |
| 7 | +# http://aws.amazon.com/apache2.0/ |
| 8 | +# |
| 9 | +# or in the "license" file accompanying this file. This file is |
| 10 | +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF |
| 11 | +# ANY KIND, either express or implied. See the License for the specific |
| 12 | +# language governing permissions and limitations under the License. |
| 13 | + |
| 14 | +"""This module contains code related to forwarding SageMaker TrainingJob Metrics to MLflow.""" |
| 15 | + |
| 16 | +from __future__ import absolute_import |
| 17 | + |
| 18 | +import os |
| 19 | +import platform |
| 20 | +import re |
| 21 | +from typing import Set, Tuple, List, Dict, Generator |
| 22 | +import boto3 |
| 23 | +import mlflow |
| 24 | +from mlflow import MlflowClient |
| 25 | +from mlflow.entities import Metric, Param, RunTag |
| 26 | + |
| 27 | +from packaging import version |
| 28 | + |
| 29 | + |
def encode(name: str, existing_names: Set[str]) -> str:
    """Encode a string to comply with MLflow naming restrictions and ensure uniqueness.

    Characters outside MLflow's allowed set (alphanumerics, ``_ - . /`` and
    whitespace, plus ``:`` on Unix with MLflow >= 2.16.0) are replaced with
    ``_XX_`` where ``XX`` is the character's two-digit hex code, so the
    transformation is reversible via :func:`decode`.

    Args:
        name (str): The original string to be encoded.
        existing_names (Set[str]): Set of existing encoded names to avoid
            collisions. Mutated in place: the returned name is added to it.

    Returns:
        str: The encoded string if changes were necessary, otherwise the
            original string (possibly suffixed/truncated for uniqueness).
    """

    def encode_char(match):
        # e.g. ":" -> "_3a_"
        return f"_{ord(match.group(0)):02x}_"

    # Check if we're on Mac/Unix and using MLflow 2.16.0 or greater,
    # which additionally allows ':' in metric/param names.
    is_unix = platform.system() != "Windows"
    mlflow_version = version.parse(mlflow.__version__)
    allow_colon = is_unix and mlflow_version >= version.parse("2.16.0")

    if allow_colon:
        pattern = r"[^\w\-./:\s]"
    else:
        pattern = r"[^\w\-./\s]"

    encoded = re.sub(pattern, encode_char, name)

    # Max length is 250 for mlflow metric/param keys. Compare the FINAL,
    # truncated form against existing_names: the previous implementation
    # compared encoded[:240] against a set holding up-to-250-char entries,
    # so duplicate long names never matched and silently overwrote each
    # other in MLflow.
    candidate = encoded[:250]

    if candidate in existing_names:
        # Truncate to 240 to leave room for a "_<n>" suffix.
        base_name = encoded[:240]
        suffix = 1
        # Edge case where even with suffix space there is a collision
        # we will override one of the keys.
        while f"{base_name}_{suffix}" in existing_names:
            suffix += 1
        candidate = f"{base_name}_{suffix}"[:250]

    existing_names.add(candidate)
    return candidate
| 70 | + |
| 71 | + |
def decode(encoded_metric_name: str) -> str:
    """Decodes an encoded metric name by replacing hexadecimal representations with ASCII

    Reverses the encoding performed by ``encode``: every occurrence of the
    pattern "_XX_" — where XX is a two-digit lowercase hexadecimal code — is
    replaced with the character that code stands for.

    Args:
        encoded_metric_name (str): The encoded metric name to be decoded.

    Returns:
        str: The decoded metric name with hexadecimal codes replaced by their
        corresponding characters.

    Example:
        >>> decode("loss_3a_val")
        "loss:val"
    """
    hex_escape = re.compile(r"_([0-9a-f]{2})_")
    return hex_escape.sub(lambda m: chr(int(m.group(1), 16)), encoded_metric_name)
| 100 | + |
| 101 | + |
def get_training_job_details(job_arn: str) -> dict:
    """Retrieve details of a SageMaker training job.

    The job name is the last path segment of the ARN; it is extracted and
    passed to the DescribeTrainingJob API.

    Args:
        job_arn (str): The ARN of the SageMaker training job.

    Returns:
        dict: A dictionary containing the details of the training job.

    Raises:
        boto3.exceptions.Boto3Error: If there's an issue with the AWS API call.
    """
    training_job_name = job_arn.rsplit("/", 1)[-1]
    return boto3.client("sagemaker").describe_training_job(
        TrainingJobName=training_job_name
    )
| 117 | + |
| 118 | + |
def create_metric_queries(job_arn: str, metric_definitions: list) -> list:
    """Create metric queries for SageMaker metrics.

    Each metric definition becomes one BatchGetMetrics query that averages
    the metric over one-minute periods along a timestamp axis.

    Args:
        job_arn (str): The ARN of the SageMaker training job.
        metric_definitions (list): List of metric definitions from the training job.

    Returns:
        list: A list of dictionaries, each representing a metric query.
    """
    return [
        {
            "MetricName": definition["Name"],
            "XAxisType": "Timestamp",
            "MetricStat": "Avg",
            "Period": "OneMinute",
            "ResourceArn": job_arn,
        }
        for definition in metric_definitions
    ]
| 140 | + |
| 141 | + |
def get_metric_data(metric_queries: list) -> dict:
    """Retrieve metric data from SageMaker.

    Args:
        metric_queries (list): A list of metric queries.

    Returns:
        dict: A dictionary containing the metric data results.

    Raises:
        boto3.exceptions.Boto3Error: If there's an issue with the AWS API call.
    """
    metrics_client = boto3.client("sagemaker-metrics")
    return metrics_client.batch_get_metrics(MetricQueries=metric_queries)
| 157 | + |
| 158 | + |
def prepare_mlflow_metrics(
    metric_queries: list, metric_results: list
) -> Tuple[List[Metric], Dict[str, str]]:
    """Prepare metrics for MLflow logging, encoding metric names if necessary.

    Queries and results are paired positionally; only results whose status is
    "Complete" are converted. Each (timestamp, value) pair becomes one MLflow
    Metric datapoint, with its position used as the step.

    Args:
        metric_queries (list): The original metric queries sent to SageMaker.
        metric_results (list): The metric results from SageMaker batch_get_metrics.

    Returns:
        Tuple[List[Metric], Dict[str, str]]:
            - A list of Metric objects with encoded names (if necessary)
            - A mapping of encoded to original names for metrics (only for encoded metrics)
    """
    prepared: List[Metric] = []
    name_mapping: Dict[str, str] = {}
    seen_names: Set[str] = set()

    for query, result in zip(metric_queries, metric_results):
        if result["Status"] != "Complete":
            continue

        original_name = query["MetricName"]
        safe_name = encode(original_name, seen_names)
        name_mapping[safe_name] = original_name

        datapoints = zip(result["XAxisValues"], result["MetricValues"])
        prepared.extend(
            Metric(key=safe_name, value=val, timestamp=ts, step=idx)
            for idx, (ts, val) in enumerate(datapoints)
        )

    return prepared, name_mapping
| 190 | + |
| 191 | + |
def prepare_mlflow_params(hyperparameters: Dict[str, str]) -> Tuple[List[Param], Dict[str, str]]:
    """Prepare hyperparameters for MLflow logging, encoding parameter names if necessary.

    Args:
        hyperparameters (Dict[str, str]): The hyperparameters from the SageMaker job.

    Returns:
        Tuple[List[Param], Dict[str, str]]:
            - A list of Param objects with encoded names (if necessary)
            - A mapping of encoded to original names for
              hyperparameters (only for encoded parameters)
    """
    seen_names: Set[str] = set()
    name_mapping: Dict[str, str] = {}
    prepared: List[Param] = []

    for original_key, raw_value in hyperparameters.items():
        safe_key = encode(original_key, seen_names)
        name_mapping[safe_key] = original_key
        prepared.append(Param(safe_key, str(raw_value)))

    return prepared, name_mapping
| 214 | + |
| 215 | + |
def batch_items(items: list, batch_size: int) -> Generator:
    """Yield successive batch_size chunks from items.

    Args:
        items (list): The list of items to be batched.
        batch_size (int): The size of each batch.

    Yields:
        list: A batch of items (the final batch may be shorter).
    """
    start = 0
    total = len(items)
    while start < total:
        yield items[start : start + batch_size]
        start += batch_size
| 228 | + |
| 229 | + |
def log_to_mlflow(metrics: list, params: list, tags: dict) -> None:
    """Log metrics, parameters, and tags to MLflow.

    Resolves the experiment named by the MLFLOW_EXPERIMENT_NAME environment
    variable (falling back to "Default"), creating it if it does not exist,
    then writes everything into a fresh run in batches of 1000 and terminates
    the run.

    Args:
        metrics (list): List of metrics to log.
        params (list): List of parameters to log.
        tags (dict): Dictionary of tags to set.

    Raises:
        mlflow.exceptions.MlflowException: If there's an issue with MLflow logging.
    """
    client = MlflowClient()

    experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME")
    if experiment_name is None or not experiment_name.strip():
        experiment_name = "Default"
        print("MLFLOW_EXPERIMENT_NAME not set. Using Default")

    experiment = client.get_experiment_by_name(experiment_name)
    experiment_id = (
        client.create_experiment(experiment_name)
        if experiment is None
        else experiment.experiment_id
    )

    run_id = client.create_run(experiment_id).info.run_id

    # MLflow's log_batch caps each call at 1000 entities, so chunk everything.
    for metric_batch in batch_items(metrics, 1000):
        client.log_batch(run_id, metrics=metric_batch)

    for param_batch in batch_items(params, 1000):
        client.log_batch(run_id, params=param_batch)

    for tag_batch in batch_items(list(tags.items()), 1000):
        client.log_batch(run_id, tags=[RunTag(k, str(v)) for k, v in tag_batch])

    client.set_terminated(run_id)
| 269 | + |
| 270 | + |
def log_sagemaker_job_to_mlflow(training_job_arn: str) -> None:
    """Retrieve SageMaker metrics and hyperparameters and log them to MLflow.

    Args:
        training_job_arn (str): The ARN of the SageMaker training job.

    Raises:
        Exception: If there's any error during the process.
    """
    # Get training job details
    mlflow.set_tracking_uri(os.getenv("MLFLOW_TRACKING_URI"))
    job_details = get_training_job_details(training_job_arn)

    # Extract hyperparameters and metric definitions.
    # DescribeTrainingJob omits these keys when the job has none, so default
    # to empty instead of raising KeyError.
    hyperparameters = job_details.get("HyperParameters", {})
    metric_definitions = job_details["AlgorithmSpecification"].get("MetricDefinitions", [])

    mlflow_metrics: List[Metric] = []
    metric_name_mapping: Dict[str, str] = {}
    if metric_definitions:
        # Create and get metric queries; skipped entirely when there are no
        # metric definitions so we never call batch_get_metrics with an
        # empty query list.
        metric_queries = create_metric_queries(job_details["TrainingJobArn"], metric_definitions)
        metric_data = get_metric_data(metric_queries)

        # Create a mapping of encoded to original metric names
        # Prepare data for MLflow
        mlflow_metrics, metric_name_mapping = prepare_mlflow_metrics(
            metric_queries, metric_data["MetricQueryResults"]
        )

    # Create a mapping of encoded to original hyperparameter names
    # Prepare data for MLflow
    mlflow_params, param_name_mapping = prepare_mlflow_params(hyperparameters)

    mlflow_tags = {
        "training_job_arn": training_job_arn,
        "metric_name_mapping": str(metric_name_mapping),
        "param_name_mapping": str(param_name_mapping),
    }

    # Log to MLflow
    log_to_mlflow(mlflow_metrics, mlflow_params, mlflow_tags)
    print(f"Logged {len(mlflow_metrics)} metric datapoints to MLflow")
    print(f"Logged {len(mlflow_params)} hyperparameters to MLflow")
0 commit comments