Commit 22c1f96

[perf, trainer, training_utils] fix: Try to monitor with mlflow up to 3 times, and avoid duplicate key processing in each step. (verl-project#5548)
### What does this PR do?

1. MLFlow creation may fail: use a `try ...` block to retry up to `3` times, and skip it if still failing (the user can still track with console, tensorboard, etc.).
2. MLFlow sanitizes every metric key on each step, while the result can be cached at minor memory cost.

#### 2. After PR

```
$ grep "hostname\|MLflow" sliuxl-dsr1_llama_8b-verl-82457.out
hostname=ip-10-4-130-58
WARNING: MLflow initialization attempt 1/3 failed: API request to https://us-west-2.experiments.sagemaker.aws/api/2.0/mlflow/experiments/get-by-name failed with exception An error occurred (ExpiredToken) when calling the AssumeRole operation: The security token included in the request is expired
WARNING: MLflow initialization attempt 2/3 failed: API request to https://us-west-2.experiments.sagemaker.aws/api/2.0/mlflow/experiments/get-by-name failed with exception An error occurred (ExpiredToken) when calling the AssumeRole operation: The security token included in the request is expired
WARNING: MLflow initialization attempt 3/3 failed: API request to https://us-west-2.experiments.sagemaker.aws/api/2.0/mlflow/experiments/get-by-name failed with exception An error occurred (ExpiredToken) when calling the AssumeRole operation: The security token included in the request is expired
WARNING: All MLflow initialization attempts failed. Proceeding without MLflow tracking.
```

#### 1. Before PR

```
$ grep "botocore.exceptions.ClientError: An error occurred (ExpiredToken) when calling the AssumeRole operation:" ../../*/*/*.out | awk -F ":" '{print $1}' | sort | xargs grep hostname | sort | uniq -c
2 ../../sliuxl-averl--dsr1_llama_8b--gsm8k/deepseek_r1_distill_llama_08b_sft_config.aws_pbtxt--82416-sft-loss-02-j-03-seq04k-ns01-sp01-bs07-pad1--20260306.113347/sliuxl-dsr1_llama_8b-verl-82416.out:hostname=ip-10-4-130-58
2 ../../sliuxl-averl--dsr1_llama_8b--gsm8k/deepseek_r1_distill_llama_08b_sft_config.aws_pbtxt--82419-sft-loss-02-j-06-seq08k-ns01-sp01-bs03-pad1--20260306.113402/sliuxl-dsr1_llama_8b-verl-82419.out:hostname=ip-10-4-130-58
2 ../../sliuxl-averl--dsr1_llama_8b--gsm8k/deepseek_r1_distill_llama_08b_sft_config.aws_pbtxt--82420-sft-loss-02-j-07-seq16k-ns01-sp01-bs01-pad1--20260306.113407/sliuxl-dsr1_llama_8b-verl-82420.out:hostname=ip-10-4-130-58
2 ../../sliuxl-averl--dsr1_llama_8b--gsm8k/deepseek_r1_distill_llama_08b_sft_config.aws_pbtxt--82429-sft-base-02-j-04-seq08k-ns01-sp01-bs01-pad1--20260306.121819/sliuxl-dsr1_llama_8b-verl-82429.out:hostname=ip-10-4-130-58
2 ../../sliuxl-averl--dsr1_llama_8b--gsm8k/deepseek_r1_distill_llama_08b_sft_config.aws_pbtxt--82431-sft-base-02-j-06-seq08k-ns01-sp01-bs03-pad1--20260306.121829/sliuxl-dsr1_llama_8b-verl-82431.out:hostname=ip-10-4-130-58
2 ../../sliuxl-averl--dsr1_llama_8b--gsm8k_64/deepseek_r1_distill_llama_08b_sft_config.aws_pbtxt--82438-sft-loss-03-j-04-seq08k-ns01-sp01-bs01-pad1--20260306.124113/sliuxl-dsr1_llama_8b-verl-82438.out:hostname=ip-10-4-130-58
2 ../../sliuxl-averl--dsr1_llama_8b--gsm8k_64/deepseek_r1_distill_llama_08b_sft_config.aws_pbtxt--82441-sft-loss-03-j-07-seq16k-ns01-sp01-bs01-pad1--20260306.124128/sliuxl-dsr1_llama_8b-verl-82441.out:hostname=ip-10-4-130-58
```

Actual error:

```
470 [rank0]:   File "/usr/local/lib/python3.12/site-packages/hydra/_internal/utils.py", line 394, in _run_hydra
471 [rank0]:     _run_app(
472 [rank0]:   File "/usr/local/lib/python3.12/site-packages/hydra/_internal/utils.py", line 457, in _run_app
473 [rank0]:     run_and_report(
474 [rank0]:   File "/usr/local/lib/python3.12/site-packages/hydra/_internal/utils.py", line 223, in run_and_report
475 [rank0]:     raise ex
476 [rank0]:   File "/usr/local/lib/python3.12/site-packages/hydra/_internal/utils.py", line 220, in run_and_report
477 [rank0]:     return func()
478 [rank0]:            ^^^^^^
479 [rank0]:   File "/usr/local/lib/python3.12/site-packages/hydra/_internal/utils.py", line 458, in <lambda>
480 [rank0]:     lambda: hydra.run(
481 [rank0]:             ^^^^^^^^^^
482 [rank0]:   File "/usr/local/lib/python3.12/site-packages/hydra/_internal/hydra.py", line 132, in run
483 [rank0]:     _ = ret.return_value
484 [rank0]:         ^^^^^^^^^^^^^^^^
485 [rank0]:   File "/usr/local/lib/python3.12/site-packages/hydra/core/utils.py", line 260, in return_value
486 [rank0]:     raise self._return_value
487 [rank0]:   File "/usr/local/lib/python3.12/site-packages/hydra/core/utils.py", line 186, in run_job
488 [rank0]:     ret.return_value = task_function(task_cfg)
489 [rank0]:                        ^^^^^^^^^^^^^^^^^^^^^^^
490 [rank0]:   File "/AWSBedrockVerl/verl/trainer/fsdp_sft_trainer.py", line 1007, in main
491 [rank0]:     run_sft(config)
492 [rank0]:   File "/AWSBedrockVerl/verl/trainer/fsdp_sft_trainer.py", line 997, in run_sft
493 [rank0]:     trainer.fit()
494 [rank0]:   File "/AWSBedrockVerl/verl/trainer/fsdp_sft_trainer.py", line 831, in fit
495 [rank0]:     tracking = Tracking(
496 [rank0]:                ^^^^^^^^^
497 [rank0]:   File "/AWSBedrockVerl/verl/utils/tracking.py", line 122, in __init__
498 [rank0]:     experiment = mlflow.set_experiment(project_name)
499 [rank0]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
500 [rank0]:   File "/usr/local/lib/python3.12/site-packages/mlflow/tracking/fluent.py", line 194, in set_experiment
501 [rank0]:     experiment = client.get_experiment_by_name(experiment_name)
502 [rank0]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
503 [rank0]:   File "/usr/local/lib/python3.12/site-packages/mlflow/tracking/_tracking_service/client.py", line 283, in get_experiment_by_name
504 [rank0]:     return self.store.get_experiment_by_name(name)
505 [rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
506 [rank0]:   File "/usr/local/lib/python3.12/site-packages/mlflow/store/tracking/rest_store.py", line 965, in get_experiment_by_name
507 [rank0]:     response_proto = self._call_endpoint(GetExperimentByName, req_body)
508 [rank0]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
509 [rank0]:   File "/usr/local/lib/python3.12/site-packages/mlflow/store/tracking/rest_store.py", line 233, in _call_endpoint
510 [rank0]:     return call_endpoint(
511 [rank0]:            ^^^^^^^^^^^^^^
512 [rank0]:   File "/usr/local/lib/python3.12/site-packages/mlflow/utils/rest_utils.py", line 622, in call_endpoint
513 [rank0]:     response = http_request(**call_kwargs)
514 [rank0]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
515 [rank0]:   File "/usr/local/lib/python3.12/site-packages/mlflow/utils/rest_utils.py", line 281, in http_request
516 [rank0]:     raise MlflowException(f"API request to {url} failed with exception {e}")
517 [rank0]: mlflow.exceptions.MlflowException: API request to https://us-west-2.experiments.sagemaker.aws/api/2.0/mlflow/experiments/get-by-name failed with exception An error occurred (ExpiredToken) when calling the AssumeRole operation: The security token included in the request is expired
518 [rank0]:[W306 20:20:52.622856896 ProcessGroupNCCL.cpp:1538] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
```

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`, `fully_async`, `one_step_off`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

Validated with actual E2E training experiments on `H100`

### API and Usage Example

N.A.

### Design & Code Changes

N.A.: Fixing existing issue and enhancing robustness / performance only

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
- [x] If your PR is related to the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.
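The bounded-retry behavior described in item 1 above can be sketched as a small standalone helper. This is an illustrative sketch only: the PR itself inlines the loop in `Tracking.__init__`, and the name `call_with_retries` is hypothetical.

```python
import logging
import time

logger = logging.getLogger(__name__)

MAX_ATTEMPTS = 3
SLEEP_SECONDS = 5


def call_with_retries(fn, max_attempts=MAX_ATTEMPTS, sleep_seconds=SLEEP_SECONDS):
    """Run fn(); on any exception, warn and retry after a short sleep.

    Returns fn()'s result on success, or None when every attempt fails,
    mirroring "proceed without MLflow tracking" instead of crashing the job.
    """
    for attempt in range(1, max_attempts + 1):
        try:
            return fn()
        except Exception as e:  # deliberately broad, as in the PR
            logger.warning("Initialization attempt %d/%d failed: %s", attempt, max_attempts, e)
            if attempt < max_attempts:
                time.sleep(sleep_seconds)
    logger.warning("All initialization attempts failed. Proceeding without this backend.")
    return None
```

A transient failure (e.g. a token that refreshes between attempts) succeeds on a later try; a persistent one exhausts the attempts and returns `None` so the caller can simply skip the backend.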
1 parent 3bbdece commit 22c1f96

1 file changed

Lines changed: 55 additions & 30 deletions

File tree

verl/utils/tracking.py

```diff
@@ -17,6 +17,7 @@
 
 import dataclasses
 import json
+import logging
 import os
 from enum import Enum
 from functools import partial
@@ -25,6 +26,11 @@
 
 import orjson
 
+logger = logging.getLogger(__name__)
+
+MLFLOW_MAX_ATTEMPTS = 3
+MLFLOW_SLEEP_SECONDS = 5
+
 
 class Tracking:
     """A unified tracking interface for logging experiment data to multiple backends.
```
```diff
@@ -82,25 +88,38 @@ def __init__(self, project_name, experiment_name, default_backend: str | list[st
 
         if "mlflow" in default_backend:
             import os
+            import time
 
             import mlflow
 
-            MLFLOW_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "sqlite:////tmp/mlruns.db")
-            mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
-
-            # Some cloud providers like Azure ML or Databricks automatically set MLFLOW_RUN_ID
-            # If set, attach to the existing run instead of creating a new one
-            run_id = os.environ.get("MLFLOW_RUN_ID")
-            if run_id:
-                mlflow.start_run(run_id=run_id)
-            else:
-                # Project_name is actually experiment_name in MLFlow
-                # If experiment does not exist, will create a new experiment
-                experiment = mlflow.set_experiment(project_name)
-                mlflow.start_run(experiment_id=experiment.experiment_id, run_name=experiment_name)
-
-            mlflow.log_params(_compute_mlflow_params_from_objects(config))
-            self.logger["mlflow"] = _MlflowLoggingAdapter()
+            for _mlflow_attempt in range(1, MLFLOW_MAX_ATTEMPTS + 1):
+                try:
+                    MLFLOW_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "sqlite:////tmp/mlruns.db")
+                    logger.info("Using MLFlow tracking URI: %s", MLFLOW_TRACKING_URI)
+                    mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
+
+                    # Some cloud providers like Azure ML or Databricks automatically set MLFLOW_RUN_ID
+                    # If set, attach to the existing run instead of creating a new one
+                    run_id = os.environ.get("MLFLOW_RUN_ID")
+                    if run_id:
+                        mlflow.start_run(run_id=run_id)
+                    else:
+                        # Project_name is actually experiment_name in MLFlow
+                        # If experiment does not exist, will create a new experiment
+                        experiment = mlflow.set_experiment(project_name)
+                        mlflow.start_run(experiment_id=experiment.experiment_id, run_name=experiment_name)
+
+                    mlflow.log_params(_compute_mlflow_params_from_objects(config))
+                    self.logger["mlflow"] = _MlflowLoggingAdapter()
+                    break  # Success
+                except Exception as e:
+                    logger.warning(
+                        "MLflow initialization attempt %d/%d failed: %s", _mlflow_attempt, MLFLOW_MAX_ATTEMPTS, e
+                    )
+                    if _mlflow_attempt < MLFLOW_MAX_ATTEMPTS:
+                        time.sleep(MLFLOW_SLEEP_SECONDS)
+                    else:
+                        logger.warning("All MLflow initialization attempts failed. Proceeding without MLflow tracking.")
 
         if "swanlab" in default_backend:
             import os
```
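When every attempt fails, `self.logger["mlflow"]` is simply never set, so the rest of the run logs to whichever backends did initialize. A toy stand-in (the `MiniTracking` and `ListAdapter` names are hypothetical, not the real verl classes) illustrates that skip-on-failure behavior:

```python
import logging

logger = logging.getLogger(__name__)


class MiniTracking:
    """Toy stand-in for verl's Tracking: backends that fail to
    initialize are skipped so training can proceed with the rest."""

    def __init__(self, backend_factories):
        # backend_factories: mapping of name -> zero-arg callable returning an adapter
        self.logger = {}
        for name, factory in backend_factories.items():
            try:
                self.logger[name] = factory()
            except Exception as e:
                logger.warning("Backend %r unavailable, skipping: %s", name, e)

    def log(self, data, step):
        # Fan out to every backend that survived initialization
        for adapter in self.logger.values():
            adapter.log(data=data, step=step)


class ListAdapter:
    """Records every call, standing in for a console/tensorboard adapter."""

    def __init__(self):
        self.records = []

    def log(self, data, step):
        self.records.append((step, data))


def broken_backend():
    # Simulates the MLflow client raising on an expired credential
    raise ConnectionError("ExpiredToken")
```

Constructing `MiniTracking({"console": ListAdapter, "mlflow": broken_backend})` yields a tracker whose `logger` dict contains only the console adapter, and `log()` still works.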
```diff
@@ -280,6 +299,8 @@ def __init__(self):
         import re
 
         self.logger = logging.getLogger(__name__)
+        # Suppress noisy "Found credentials from IAM Role" on every MLflow request
+        logging.getLogger("botocore.credentials").setLevel(logging.WARNING)
         # MLflow metric key validation logic:
         # https://github.com/mlflow/mlflow/blob/master/mlflow/utils/validation.py#L157C12-L157C44
         # Only characters allowed: slashes, alphanumerics, underscores, periods, dashes, colons,
```
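The two added lines above rely on standard `logging` behavior: raising a named logger's level to WARNING filters its INFO records without touching other loggers. A minimal sketch (the `silence_info` wrapper is illustrative, not part of the PR):

```python
import logging


def silence_info(logger_name: str) -> logging.Logger:
    """Raise a noisy third-party logger to WARNING so its per-request
    INFO chatter (e.g. botocore's "Found credentials from IAM Role"
    message on every MLflow REST call) is dropped, while warnings
    and errors still propagate normally."""
    noisy = logging.getLogger(logger_name)
    noisy.setLevel(logging.WARNING)
    return noisy
```

Because logger levels are hierarchical, this silences `botocore.credentials` without muting the application's own loggers or other `botocore` submodules.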
```diff
@@ -288,24 +309,28 @@ def __init__(self):
             r"[^/\w.\- :]"
         )  # Allowed: slashes, alphanumerics, underscores, periods, dashes, colons, and spaces.
         self._consecutive_slashes_pattern = re.compile(r"/+")
+        self._sanitized_key_cache = {}
+
+    def _sanitize_key(self, key):
+        if key in self._sanitized_key_cache:
+            return self._sanitized_key_cache[key] or key
+        # First replace @ with _at_ for backward compatibility
+        sanitized = key.replace("@", "_at_")
+        # Replace consecutive slashes with a single slash (MLflow treats them as file paths)
+        sanitized = self._consecutive_slashes_pattern.sub("/", sanitized)
+        # Then replace any other invalid characters with _
+        sanitized = self._invalid_chars_pattern.sub("_", sanitized)
+        if sanitized == key:
+            self._sanitized_key_cache[key] = None
+        else:
+            self.logger.warning("[MLflow] Metric key '%s' sanitized to '%s' due to invalid characters.", key, sanitized)
+            self._sanitized_key_cache[key] = sanitized
+        return sanitized
 
     def log(self, data, step):
         import mlflow
 
-        def sanitize_key(key):
-            # First replace @ with _at_ for backward compatibility
-            sanitized = key.replace("@", "_at_")
-            # Replace consecutive slashes with a single slash (MLflow treats them as file paths)
-            sanitized = self._consecutive_slashes_pattern.sub("/", sanitized)
-            # Then replace any other invalid characters with _
-            sanitized = self._invalid_chars_pattern.sub("_", sanitized)
-            if sanitized != key:
-                self.logger.warning(
-                    "[MLflow] Metric key '%s' sanitized to '%s' due to invalid characters.", key, sanitized
-                )
-            return sanitized
-
-        results = {sanitize_key(k): v for k, v in data.items()}
+        results = {self._sanitize_key(k): v for k, v in data.items()}
         mlflow.log_metrics(metrics=results, step=step)
```
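Since metric keys repeat on every training step, caching the sanitized form pays for itself after the first step. The pattern can be demonstrated standalone with a hypothetical `KeySanitizer` class (the PR implements this as `_MlflowLoggingAdapter._sanitize_key`; storing `None` for already-clean keys avoids keeping a second copy of each string):

```python
import re


class KeySanitizer:
    """Cached metric-key sanitizer, sketching the logic in the diff above."""

    # Allowed: slashes, alphanumerics, underscores, periods, dashes, colons, spaces
    _invalid_chars = re.compile(r"[^/\w.\- :]")
    _consecutive_slashes = re.compile(r"/+")

    def __init__(self):
        # key -> sanitized form, or None when the key was already clean
        self._cache = {}

    def sanitize(self, key: str) -> str:
        if key in self._cache:
            # Cache hit: regex work is skipped entirely on repeat steps
            return self._cache[key] or key
        sanitized = key.replace("@", "_at_")  # backward compatibility
        sanitized = self._consecutive_slashes.sub("/", sanitized)
        sanitized = self._invalid_chars.sub("_", sanitized)
        self._cache[key] = None if sanitized == key else sanitized
        return sanitized
```

For example, `sanitize("loss@step")` yields `loss_at_step`, `sanitize("train//reward")` collapses to `train/reward`, and a clean key like `val/acc top-1` passes through unchanged with only a `None` stored in the cache.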