Skip to content

Commit d31e397

Browse files
authored
Feat: Support logging of MLFlow metrics when network_isolation mode is enabled (#4880)
* feature: Support logging of MLFlow metrics when network_isolation mode is enabled * Fix pylint + doc check
1 parent e626647 commit d31e397

File tree

5 files changed

+627
-0
lines changed

5 files changed

+627
-0
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ dependencies = [
4848
"PyYAML~=6.0",
4949
"requests",
5050
"sagemaker-core>=1.0.0,<2.0.0",
51+
"sagemaker-mlflow",
5152
"schema",
5253
"smdebug_rulesconfig==1.0.1",
5354
"tblib>=1.7.0,<4",

src/sagemaker/estimator.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@
107107
from sagemaker.workflow.parameters import ParameterString
108108
from sagemaker.workflow.pipeline_context import PipelineSession, runnable_by_pipeline
109109

110+
from sagemaker.mlflow.forward_sagemaker_metrics import log_sagemaker_job_to_mlflow
111+
110112
logger = logging.getLogger(__name__)
111113

112114

@@ -1366,8 +1368,14 @@ def fit(
13661368
experiment_config = check_and_get_run_experiment_config(experiment_config)
13671369
self.latest_training_job = _TrainingJob.start_new(self, inputs, experiment_config)
13681370
self.jobs.append(self.latest_training_job)
1371+
forward_to_mlflow_tracking_server = False
1372+
if os.environ.get("MLFLOW_TRACKING_URI") and self.enable_network_isolation():
1373+
wait = True
1374+
forward_to_mlflow_tracking_server = True
13691375
if wait:
13701376
self.latest_training_job.wait(logs=logs)
1377+
if forward_to_mlflow_tracking_server:
1378+
log_sagemaker_job_to_mlflow(self.latest_training_job.name)
13711379

13721380
def _compilation_job_name(self):
13731381
"""Placeholder docstring"""
Lines changed: 311 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,311 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
14+
"""This module contains code related to forwarding SageMaker TrainingJob Metrics to MLflow."""
15+
16+
from __future__ import absolute_import
17+
18+
import os
19+
import platform
20+
import re
21+
from typing import Set, Tuple, List, Dict, Generator
22+
import boto3
23+
import mlflow
24+
from mlflow import MlflowClient
25+
from mlflow.entities import Metric, Param, RunTag
26+
27+
from packaging import version
28+
29+
30+
def encode(name: str, existing_names: Set[str]) -> str:
    """Encode *name* so it satisfies MLflow naming rules and stays unique.

    Characters outside MLflow's allowed set are replaced with ``_XX_``
    (two-digit lowercase hex of the character's code point).  When the
    encoded value would collide with a previously produced name, a
    numeric ``_<n>`` suffix is appended.

    Args:
        name (str): The original string to be encoded.
        existing_names (Set[str]): Encoded names produced so far; this set
            is mutated to include the value returned by this call.

    Returns:
        str: The encoded (possibly unchanged) MLflow-safe, unique name.
    """
    # Colons are legal in names only on non-Windows hosts running
    # MLflow 2.16.0 or newer.
    colon_ok = platform.system() != "Windows" and version.parse(
        mlflow.__version__
    ) >= version.parse("2.16.0")
    forbidden = r"[^\w\-./:\s]" if colon_ok else r"[^\w\-./\s]"

    encoded = re.sub(forbidden, lambda m: f"_{ord(m.group(0)):02x}_", name)

    # Reserve headroom for a "_<n>" suffix; MLflow caps metric/param
    # keys at 250 characters.
    stem = encoded[:240]
    if stem in existing_names:
        counter = 1
        # Edge case: with enough collisions even the suffixed form may
        # collide after truncation, in which case one of the keys is
        # overridden.
        while f"{stem}_{counter}" in existing_names:
            counter += 1
        encoded = f"{stem}_{counter}"

    encoded = encoded[:250]
    existing_names.add(encoded)
    return encoded
70+
71+
72+
def decode(encoded_metric_name: str) -> str:
    """Reverse :func:`encode` by turning ``_XX_`` hex escapes back into characters.

    Every occurrence of ``_XX_`` — where ``XX`` is a two-digit lowercase
    hexadecimal number — is replaced with the character whose code point
    is ``XX``.  Text without such escapes is returned unchanged.

    Args:
        encoded_metric_name (str): The encoded metric name to be decoded.

    Returns:
        str: The decoded metric name with hexadecimal codes replaced by
            their corresponding characters.

    Example:
        >>> decode("loss_3a_val")
        'loss:val'
    """
    return re.sub(
        r"_([0-9a-f]{2})_",
        lambda match: chr(int(match.group(1), 16)),
        encoded_metric_name,
    )
100+
101+
102+
def get_training_job_details(job_arn: str) -> dict:
    """Fetch the DescribeTrainingJob response for a SageMaker training job.

    Args:
        job_arn (str): The ARN of the SageMaker training job; the job name
            is taken from the final ``/``-separated segment of the ARN.

    Returns:
        dict: The ``describe_training_job`` response for the job.

    Raises:
        boto3.exceptions.Boto3Error: If there's an issue with the AWS API call.
    """
    # The training job name is the last path component of the ARN.
    job_name = job_arn.rsplit("/", 1)[-1]
    return boto3.client("sagemaker").describe_training_job(TrainingJobName=job_name)
117+
118+
119+
def create_metric_queries(job_arn: str, metric_definitions: list) -> list:
    """Build one SageMaker metrics query per metric definition.

    Each query requests the per-minute average of a metric over time for
    the given training job.

    Args:
        job_arn (str): The ARN of the SageMaker training job.
        metric_definitions (list): Metric definitions (each containing a
            ``Name``) from the training job description.

    Returns:
        list: Query dicts suitable for ``batch_get_metrics``.
    """
    return [
        {
            "MetricName": definition["Name"],
            "XAxisType": "Timestamp",
            "MetricStat": "Avg",
            "Period": "OneMinute",
            "ResourceArn": job_arn,
        }
        for definition in metric_definitions
    ]
140+
141+
142+
def get_metric_data(metric_queries: list) -> dict:
    """Fetch metric values from the SageMaker Metrics service.

    Args:
        metric_queries (list): Queries as produced by
            :func:`create_metric_queries`.

    Returns:
        dict: The ``batch_get_metrics`` response containing
            ``MetricQueryResults``.

    Raises:
        boto3.exceptions.Boto3Error: If there's an issue with the AWS API call.
    """
    client = boto3.client("sagemaker-metrics")
    return client.batch_get_metrics(MetricQueries=metric_queries)
157+
158+
159+
def prepare_mlflow_metrics(
    metric_queries: list, metric_results: list
) -> Tuple[List[Metric], Dict[str, str]]:
    """Convert SageMaker metric results into MLflow ``Metric`` objects.

    Results whose status is not ``"Complete"`` are skipped.  Metric names
    are passed through :func:`encode` so they satisfy MLflow naming rules,
    and the encoded -> original mapping is recorded for every metric that
    produced data.

    Args:
        metric_queries (list): The original metric queries sent to SageMaker.
        metric_results (list): The results from ``batch_get_metrics``,
            positionally aligned with *metric_queries*.

    Returns:
        Tuple[List[Metric], Dict[str, str]]:
            - MLflow ``Metric`` objects, one per datapoint.
            - Mapping of encoded metric name to original name.
    """
    metrics: List[Metric] = []
    name_mapping: Dict[str, str] = {}
    seen: Set[str] = set()

    for query, result in zip(metric_queries, metric_results):
        if result["Status"] != "Complete":
            continue
        original_name = query["MetricName"]
        safe_name = encode(original_name, seen)
        name_mapping[safe_name] = original_name
        # One Metric per (timestamp, value) pair; step is the datapoint index.
        metrics.extend(
            Metric(key=safe_name, value=val, timestamp=ts, step=idx)
            for idx, (ts, val) in enumerate(
                zip(result["XAxisValues"], result["MetricValues"])
            )
        )

    return metrics, name_mapping
190+
191+
192+
def prepare_mlflow_params(hyperparameters: Dict[str, str]) -> Tuple[List[Param], Dict[str, str]]:
    """Convert SageMaker hyperparameters into MLflow ``Param`` objects.

    Parameter names are passed through :func:`encode` so they satisfy
    MLflow naming rules; the encoded -> original mapping is recorded for
    every hyperparameter.  Values are stringified.

    Args:
        hyperparameters (Dict[str, str]): The hyperparameters from the
            SageMaker job.

    Returns:
        Tuple[List[Param], Dict[str, str]]:
            - MLflow ``Param`` objects with encoded names.
            - Mapping of encoded parameter name to original name.
    """
    params: List[Param] = []
    name_mapping: Dict[str, str] = {}
    seen: Set[str] = set()

    for original_key, raw_value in hyperparameters.items():
        safe_key = encode(original_key, seen)
        name_mapping[safe_key] = original_key
        params.append(Param(safe_key, str(raw_value)))

    return params, name_mapping
214+
215+
216+
def batch_items(items: list, batch_size: int) -> Generator:
    """Yield consecutive slices of *items*, each at most *batch_size* long.

    Args:
        items (list): The list of items to be batched.
        batch_size (int): The maximum size of each batch.

    Yields:
        list: The next chunk of *items*; the final chunk may be shorter.
    """
    total = len(items)
    for start in range(0, total, batch_size):
        yield items[start : start + batch_size]
228+
229+
230+
def log_to_mlflow(metrics: list, params: list, tags: dict) -> None:
    """Write metrics, parameters, and tags to a fresh MLflow run.

    The run is created under the experiment named by the
    ``MLFLOW_EXPERIMENT_NAME`` environment variable (falling back to
    ``"Default"`` when unset or blank), creating the experiment if it does
    not exist yet.  Payloads are sent in chunks of 1000 items — presumably
    the MLflow batch-logging limit; confirm against ``log_batch`` docs —
    and the run is terminated afterwards.

    Args:
        metrics (list): List of metrics to log.
        params (list): List of parameters to log.
        tags (dict): Dictionary of tags to set.

    Raises:
        mlflow.exceptions.MlflowException: If there's an issue with MLflow logging.
    """
    client = MlflowClient()

    experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME")
    if not experiment_name or not experiment_name.strip():
        experiment_name = "Default"
        print("MLFLOW_EXPERIMENT_NAME not set. Using Default")

    experiment = client.get_experiment_by_name(experiment_name)
    experiment_id = (
        client.create_experiment(experiment_name)
        if experiment is None
        else experiment.experiment_id
    )

    run_id = client.create_run(experiment_id).info.run_id

    for metric_chunk in batch_items(metrics, 1000):
        client.log_batch(run_id, metrics=metric_chunk)

    for param_chunk in batch_items(params, 1000):
        client.log_batch(run_id, params=param_chunk)

    for tag_chunk in batch_items(list(tags.items()), 1000):
        client.log_batch(
            run_id, tags=[RunTag(key, str(value)) for key, value in tag_chunk]
        )

    client.set_terminated(run_id)
269+
270+
271+
def log_sagemaker_job_to_mlflow(training_job_arn: str) -> None:
    """Retrieve SageMaker job metrics and hyperparameters and log them to MLflow.

    The MLflow tracking server is taken from the ``MLFLOW_TRACKING_URI``
    environment variable.  Metric and hyperparameter names are encoded to
    satisfy MLflow naming rules; the encoded -> original name mappings are
    attached to the run as tags alongside the training job ARN.

    Args:
        training_job_arn (str): The ARN of the SageMaker training job.

    Raises:
        Exception: If there's any error during the process.
    """
    mlflow.set_tracking_uri(os.getenv("MLFLOW_TRACKING_URI"))
    job_details = get_training_job_details(training_job_arn)

    # DescribeTrainingJob omits these keys when the job defines no
    # hyperparameters / metric definitions, so default to empty instead
    # of raising KeyError.
    hyperparameters = job_details.get("HyperParameters", {})
    metric_definitions = job_details["AlgorithmSpecification"].get(
        "MetricDefinitions", []
    )

    # Query SageMaker Metrics for every defined metric.  Skip the API call
    # entirely when there is nothing to query — batch_get_metrics rejects
    # an empty query list.
    metric_queries = create_metric_queries(job_details["TrainingJobArn"], metric_definitions)
    if metric_queries:
        query_results = get_metric_data(metric_queries)["MetricQueryResults"]
    else:
        query_results = []

    # Prepare MLflow payloads, remembering encoded -> original name maps.
    mlflow_metrics, metric_name_mapping = prepare_mlflow_metrics(
        metric_queries, query_results
    )
    mlflow_params, param_name_mapping = prepare_mlflow_params(hyperparameters)

    mlflow_tags = {
        "training_job_arn": training_job_arn,
        "metric_name_mapping": str(metric_name_mapping),
        "param_name_mapping": str(param_name_mapping),
    }

    # Log to MLflow
    log_to_mlflow(mlflow_metrics, mlflow_params, mlflow_tags)
    print(f"Logged {len(mlflow_metrics)} metric datapoints to MLflow")
    print(f"Logged {len(mlflow_params)} hyperparameters to MLflow")

0 commit comments

Comments
 (0)