Conversation

@szaher (Contributor) commented Oct 3, 2025

Problem:
The previous implementation assigned default values directly to torchrun arguments. Per torchrun's documented behavior, explicitly supplying an argument prevents torchrun from reading the corresponding environment variable. Consequently, configurations provided by orchestration systems like Kubeflow (which inject variables like RANK and WORLD_SIZE) were being ignored.

Solution:
This commit refactors the argument handling logic to prioritize environment variables. The code now checks for the presence of the most common torchrun environment variables. If an environment variable is set, its value is used for the argument; otherwise, a default is applied.

This change aligns Training Hub's behavior with torchrun's intended design, making it compatible with standard distributed training environments.

Reference: https://docs.pytorch.org/docs/stable/elastic/run.html#environment-variables
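
To make the precedence concrete, here is a minimal sketch of the env-before-default pattern described above (illustrative only; the helper name and the fallback defaults are assumptions, not the exact Training Hub code):

import os

def resolve_torchrun_default(env_vars, default):
    """Return the first listed environment variable that is set, otherwise the default."""
    for name in env_vars:
        value = os.getenv(name)
        if value is not None:
            return value
    return default

# Defaults only apply when the launcher (e.g. Kubeflow) did not inject anything.
nnodes = resolve_torchrun_default(["PET_NNODES"], "1")
node_rank = resolve_torchrun_default(["PET_NODE_RANK"], "0")
nproc_per_node = resolve_torchrun_default(["LOCAL_WORLD_SIZE", "PET_NPROC_PER_NODE"], "1")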

Summary by CodeRabbit

  • New Features

    • Distributed training now auto-builds torchrun launch parameters from arguments or environment, including explicit master address and port; supports 'auto'/'gpu' for process count.
  • Refactor

    • Launch parameter handling centralized and normalized with clear precedence, validation, and conflict warnings/errors.
  • Breaking Changes

    • Multi-node params (process count and rendezvous ID) accept strings or integers; new optional master_addr and master_port parameters added.

…figuration

Signed-off-by: Saad Zaher <[email protected]>

@coderabbitai (bot) commented Oct 3, 2025

Walkthrough

SFT and OSFT now use a new training_hub.utils.get_torchrun_params helper to resolve and validate torchrun/distributed parameters (nproc_per_node, nnodes, node_rank, rdzv_id, rdzv_endpoint, master_addr, master_port) from args and environment. Public signatures widened to accept string|int unions and master_addr/master_port are propagated through training.
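
A hypothetical call-site sketch based on this walkthrough (the exact argument names and return keys are assumptions, not the verified API):

from training_hub.utils import get_torchrun_params

resolved = get_torchrun_params({
    "nproc_per_node": "auto",  # an explicitly passed value
    "nnodes": None,            # unset values fall back to PET_* env vars, then defaults
    "node_rank": None,
    "rdzv_id": None,
    "rdzv_endpoint": None,
    "master_addr": None,
    "master_port": None,
})
# 'resolved' holds normalized values used to construct TorchrunArgs.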

Changes

Cohort / File(s) | Summary

New utility: torchrun param resolution
src/training_hub/utils.py
  Add get_torchrun_params(args: dict) to read env & args, validate and normalize torchrun/distributed params, enforce precedence/conflict rules (including PET_-prefixed vars), handle 'auto'/'gpu' for nproc_per_node, and return a resolved dict used by the training launch.

SFT: centralize torchrun params, widen typing, forward masters
src/training_hub/algorithms/sft.py
  Import training_hub.utils, expand torchrun_keys, call get_torchrun_params(...) instead of manual defaults, construct TorchrunArgs from the resolved params, and widen the nproc_per_node/rdzv_id types to `str | int`.

OSFT & MiniTrainer: centralize torchrun params, new master fields, widen typing
src/training_hub/algorithms/osft.py
  Import Literal and get_torchrun_params; widen nproc_per_node to `Literal['auto', 'gpu'] | int`; add master_addr and master_port fields and propagate them through training.

Sequence Diagram(s)

sequenceDiagram
    autonumber
    actor User
    participant Env as Environment
    participant API as SFT/OSFT API
    participant Utils as get_torchrun_params()
    participant Torchrun as TorchrunArgs
    participant Runner as backend / run_training

    User->>API: call train(...) (may include nproc_per_node, rdzv_id, master_addr/port)
    API->>Utils: get_torchrun_params({torchrun_keys from args/env})
    Utils->>Env: read WORLD_SIZE/LOCAL_WORLD_SIZE/RANK, PET_* variants, master and rdzv envs
    Env-->>Utils: env values
    Utils-->>API: resolved torchrun params (source: args|env), or error on conflicts
    API->>API: merge/validate additional user torchrun_params (if provided)
    API->>Torchrun: construct TorchrunArgs(final params)
    API->>Runner: run_training(torch_args=TorchrunArgs, train_args=...)
    Runner-->>User: training job launched / execution started

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Poem

I hop through envvars, sniff each line,
I merge your args with defaults, tidy and fine.
Masters and ports now travel with care,
Rendezvous settles, launches dance in the air.
Hop—deploy—training hums; the rabbit gives a flair. 🐇✨

Pre-merge checks

✅ Passed checks (3 passed)
  • Description Check: ✅ Passed. Check skipped - CodeRabbit’s high-level summary is enabled.
  • Title Check: ✅ Passed. The title clearly and concisely summarizes the primary change of making Training Hub use torchrun environment variables for defaults, matching the code modifications and PR objectives.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 87.50%, which is sufficient. The required threshold is 80.00%.

Signed-off-by: Saad Zaher <[email protected]>

@coderabbitai (bot) left a comment

Actionable comments posted: 3

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between fc2175d and 0c873f5.

📒 Files selected for processing (1)
  • src/training_hub/algorithms/sft.py (2 hunks)

@RobotSail (Collaborator) left a comment

Thank you for the PR @szaher ! Could you please review the few comments on this PR? Otherwise, this looks good to merge.

Apply suggestion by rabbitcode to use new variables while keeping old
variables which are the only ones supported by kubeflow trainer

Signed-off-by: Saad Zaher <[email protected]>

# Set up torchrun arguments with single-node defaults (except nproc_per_node)
torchrun_defaults = {
    'nproc_per_node': os.getenv("LOCAL_WORLD_SIZE", os.getenv("PET_NPROC_PER_NODE", "1")),

@RobotSail (Collaborator) commented:

How come this is being used as a string?

@szaher (Contributor, author) commented Oct 3, 2025:

@RobotSail pytorch accepts values for nproc_per_node like auto, gpu, cpu, etc.

Reference: https://github.com/pytorch/pytorch/blob/main/torch/distributed/run.py#L77-L88

This is also what the kubeflow trainer would inject by default (auto). The list of values the kubeflow trainer supports is at https://github.com/kubeflow/trainer/blob/master/pkg/runtime/framework/plugins/torch/torch.go#L68-L72, plus any int value provided by the user.

@szaher (Contributor, author) commented:

This is dependent on PR instructlab/training#661

@RobotSail (Collaborator) commented:

I see, thank you for explaining that! Seems like we should also update osft.py and its equivalent interface in api_train.py as well.

@szaher (Contributor, author) commented:

If auto is set, check that the accelerator is available and is CUDA, then set it to gpu for the backend.
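
A minimal sketch of that logic (the helper name is assumed, not the verified Training Hub implementation):

import torch

def normalize_nproc_per_node(value):
    """Map 'auto' to a backend-friendly value; pass everything else through."""
    if value == "auto":
        # If a CUDA accelerator is available, hand the backend "gpu";
        # otherwise fall back to a single process.
        return "gpu" if torch.cuda.is_available() else 1
    return value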


 # Torchrun parameters for multi-node support
-nproc_per_node: int | None = None,
+nproc_per_node: str | None = None,

@RobotSail (Collaborator) commented:

Since you already know the valid string that this could be, I recommend being explicit rather than allowing any string type:

Suggested change
-nproc_per_node: str | None = None,
+nproc_per_node: Literal["gpu"] | None = None,

@szaher (Contributor, author) commented:

nproc_per_node can be a string or an int; the string values can be ["gpu", "auto", "xpu", "cpu", etc.], in addition to any int value. If we use Literal it will always expect a fixed list of values (in this case only gpu).
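
As a hedged sketch of the widened annotation being discussed (Python 3.10+ union syntax; the exact spelling in the repo may differ):

from typing import Literal

def train(nproc_per_node: Literal["auto", "gpu"] | int | None = None) -> None:
    # A Literal of the known keywords combined with int keeps the keyword values
    # explicit while still allowing any integer process count.
    ...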

@coderabbitai (bot) left a comment

Actionable comments posted: 3

🧹 Nitpick comments (1)
src/training_hub/algorithms/osft.py (1)

124-124: Clarify docstring to mention valid string value.

The docstring doesn't explain that nproc_per_node can accept the string "gpu" for automatic GPU detection. This could confuse users about the valid string values.

Update the docstring to clarify valid values:

-            nproc_per_node (str): Number of processes (GPUs) per node for distributed training.
+            nproc_per_node (Literal["gpu"] | int | None): Number of processes per node for distributed training. 
+                Use "gpu" for automatic detection or an integer to specify explicitly.
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between fca6305 and 69c6ec3.

📒 Files selected for processing (2)
  • src/training_hub/algorithms/osft.py (4 hunks)
  • src/training_hub/algorithms/sft.py (5 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/training_hub/algorithms/sft.py

@szaher (Contributor, author) commented Oct 6, 2025

@RobotSail @Maxusmusti do we want to support only gpu for this repo? It won't support other values like CPU or XPU, etc.?

@RobotSail (Collaborator) commented:

@szaher It's backend-dependent. The current backends we're using here for SFT and OSFT do not support either of those, so they cannot be used. If other backends are added in the future which do support those options, then they can have a different implementation.

@Maxusmusti (Contributor) commented:

I'll also add that it is somewhat algorithm-dependent as well. For example, when adding LoRA/QLoRA, CPU options make sense. But for full fine-tuning and some reinforcement learning methods, it may be infeasible to run on CPU with any decently sized models.

# Set up torchrun arguments with single-node defaults (except nproc_per_node)
torchrun_defaults = {
    'nproc_per_node': os.getenv("LOCAL_WORLD_SIZE", os.getenv("PET_NPROC_PER_NODE", "1")),
    'nnodes': int(os.getenv("WORLD_SIZE", os.getenv("PET_NNODES", "1"))),

@szaher (Contributor, author) commented:

use PET_NNODES only

torchrun_defaults = {
    'nproc_per_node': os.getenv("LOCAL_WORLD_SIZE", os.getenv("PET_NPROC_PER_NODE", "1")),
    'nnodes': int(os.getenv("WORLD_SIZE", os.getenv("PET_NNODES", "1"))),
    'node_rank': int(os.getenv("PET_NODE_RANK", os.getenv("RANK", "0"))),

@szaher (Contributor, author) commented:

drop RANK and keep only PET_NODE_RANK
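
Taken together with the "use PET_NNODES only" note above, a sketch of the suggested PET_-only resolution (illustrative, not the merged code):

import os

torchrun_defaults = {
    'nproc_per_node': os.getenv("LOCAL_WORLD_SIZE", os.getenv("PET_NPROC_PER_NODE", "1")),
    'nnodes': int(os.getenv("PET_NNODES", "1")),        # WORLD_SIZE fallback dropped per the note above
    'node_rank': int(os.getenv("PET_NODE_RANK", "0")),  # RANK fallback dropped per this comment
}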

@coderabbitai (bot) left a comment

Actionable comments posted: 3

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 69c6ec3 and e309cf7.

📒 Files selected for processing (3)
  • src/training_hub/algorithms/osft.py (7 hunks)
  • src/training_hub/algorithms/sft.py (5 hunks)
  • src/training_hub/utils.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/training_hub/algorithms/osft.py (1)
src/training_hub/utils.py (1)
  • get_torchrun_params (34-101)
src/training_hub/algorithms/sft.py (2)
src/training_hub/algorithms/__init__.py (3)
  • Algorithm (6-22)
  • Backend (24-30)
  • AlgorithmRegistry (33-79)
src/training_hub/utils.py (1)
  • get_torchrun_params (34-101)
🪛 Ruff (0.14.0)
src/training_hub/utils.py

58-58: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


58-58: Avoid specifying long messages outside the exception class

(TRY003)


62-62: Prefer TypeError exception for invalid type

(TRY004)


62-62: Avoid specifying long messages outside the exception class

(TRY003)


99-99: Avoid specifying long messages outside the exception class

(TRY003)

RobotSail and others added 2 commits October 14, 2025 02:54

@coderabbitai (bot) left a comment

Actionable comments posted: 3

🧹 Nitpick comments (5)
src/training_hub/utils.py (3)

94-115: Tighten nproc_per_node validation (type and value).

  • Raise TypeError for wrong types.
  • Reject non-positive integers.

Apply this diff:

 def validate_nproc_per_node(value):
     """Validate and normalize nproc_per_node."""
-    if not isinstance(value, (int, str)):
-        raise ValueError(f"nproc_per_node must be 'auto', 'gpu', or an integer, got type {type(value).__name__}")
+    if not isinstance(value, (int, str)):
+        raise TypeError("nproc_per_node must be 'auto', 'gpu', or an integer")
     if isinstance(value, int):
-        return value
+        if value < 1:
+            raise ValueError("nproc_per_node must be >= 1")
+        return value
@@
-    if value_lower not in ['auto', 'gpu'] and not value_lower.isdigit():
-        raise ValueError(f"nproc_per_node must be 'auto', 'gpu', or an integer, got: {value!r}")
+    if value_lower not in ['auto', 'gpu'] and not value_lower.isdigit():
+        raise ValueError("nproc_per_node must be 'auto', 'gpu', or a positive integer")

58-64: Optional: add fallbacks to standard torchrun envs (LOCAL_WORLD_SIZE, WORLD_SIZE, NODE_RANK).

Today only PET_* is read (except master vars). If the goal is compatibility with common launchers (kubeflow/torchrun), consider:

  • nproc_per_node ← PET_NPROC_PER_NODE or LOCAL_WORLD_SIZE
  • node_rank ← PET_NODE_RANK or NODE_RANK
  • nnodes ← PET_NNODES, else compute WORLD_SIZE/LOCAL_WORLD_SIZE

Would you like to support these fallbacks? Example:

 def get_env_value(param_name):
     """Get environment variable value with fallback logic."""
-    if param_name in ['master_addr', 'master_port']:
+    if param_name in ['master_addr', 'master_port']:
         # try both PET_ and non-PET_ versions
         return os.getenv(f'PET_{param_name.upper()}') or os.getenv(param_name.upper())
-    return os.getenv(f'PET_{param_name.upper()}')
+    if param_name == 'nproc_per_node':
+        return os.getenv('PET_NPROC_PER_NODE') or os.getenv('LOCAL_WORLD_SIZE')
+    if param_name == 'node_rank':
+        return os.getenv('PET_NODE_RANK') or os.getenv('NODE_RANK')
+    if param_name == 'nnodes':
+        pet = os.getenv('PET_NNODES')
+        if pet:
+            return pet
+        ws, lws = os.getenv('WORLD_SIZE'), os.getenv('LOCAL_WORLD_SIZE')
+        if ws and lws:
+            try:
+                return str(int(int(ws) // int(lws)))
+            except ValueError:
+                return None
+    return os.getenv(f'PET_{param_name.upper()}')

Note: Keep PET_* precedence if that’s your policy.


171-175: Set stacklevel on warnings for correct caller context.

Add stacklevel=2 to warnings.warn(...).

Apply this diff:

-            warnings.warn(
+            warnings.warn(
                 f"Both {master_addr_ref}={master_addr_val!r} and {rdzv_endpoint_ref}={rdzv_endpoint_val!r} are set. "
                 f"Using {master_addr_ref} due to higher precedence. Ignoring {rdzv_endpoint_ref}.",
-                UserWarning
+                UserWarning,
+                stacklevel=2,
             )
@@
-            warnings.warn(
+            warnings.warn(
                 f"Both {rdzv_endpoint_ref}={rdzv_endpoint_val!r} and {master_addr_ref}={master_addr_val!r} are set. "
                 f"Using {rdzv_endpoint_ref} due to higher precedence. Ignoring {master_addr_ref}.",
-                UserWarning
+                UserWarning,
+                stacklevel=2,
             )

Based on static analysis hints.

Also applies to: 179-183

src/training_hub/algorithms/sft.py (1)

1-1: Remove unused import.

os is not used in this module.

Apply this diff:

-import os
src/training_hub/algorithms/osft.py (1)

241-262: Literal type checking isn’t enforced.

_check_type treats unknown typing constructs as valid, so nproc_per_node: Literal['auto','gpu'] | int won’t reject bad strings. Handle Literal explicitly.

Apply this diff:

 from typing import Literal, get_origin, get_args, Union
+import typing
@@
         # Handle Union types (including X | None syntax)
         if origin is Union:
             return any(self._check_type(value, arg) for arg in args)
+
+        # Handle Literal types
+        if origin is typing.Literal:
+            return any(value == lit for lit in args)

This keeps invalid strings from slipping through before backend validation.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e309cf7 and c7a7bcd.

📒 Files selected for processing (3)
  • src/training_hub/algorithms/osft.py (9 hunks)
  • src/training_hub/algorithms/sft.py (9 hunks)
  • src/training_hub/utils.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/training_hub/algorithms/sft.py (1)
src/training_hub/utils.py (1)
  • get_torchrun_params (32-207)
src/training_hub/algorithms/osft.py (1)
src/training_hub/utils.py (1)
  • get_torchrun_params (32-207)
🪛 Ruff (0.14.0)
src/training_hub/utils.py

75-78: Avoid specifying long messages outside the exception class

(TRY003)


86-86: Unnecessary key check before dictionary access

Replace with dict.get

(RUF019)


97-97: Prefer TypeError exception for invalid type

(TRY004)


97-97: Avoid specifying long messages outside the exception class

(TRY003)


103-103: Avoid specifying long messages outside the exception class

(TRY003)


114-114: Avoid specifying long messages outside the exception class

(TRY003)


164-168: Avoid specifying long messages outside the exception class

(TRY003)


171-171: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


179-179: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)

🔇 Additional comments (1)
src/training_hub/algorithms/osft.py (1)

126-134: Keep docstring consistent with type change.

Ensure the master_port description stays as int.

No code change needed if you apply the signature fix above; just confirm docs align.

@coderabbitai (bot) left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
src/training_hub/utils.py (1)

187-187: Add stacklevel=2 to warnings.

The warning calls lack an explicit stacklevel argument, so the warning will point to this helper function rather than the caller's code, making it harder for users to identify where the conflict originated.

Apply this diff:

-            warnings.warn(
+            warnings.warn(
                 f"Both {master_addr_ref}={master_addr_val!r} and {rdzv_endpoint_ref}={rdzv_endpoint_val!r} are set. "
                 f"Using {master_addr_ref} due to higher precedence. Ignoring {rdzv_endpoint_ref}.",
-                UserWarning
+                UserWarning,
+                stacklevel=2
             )

Also applies to: 195-195

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c7a7bcd and 1dddebd.

📒 Files selected for processing (3)
  • src/training_hub/algorithms/osft.py (9 hunks)
  • src/training_hub/algorithms/sft.py (10 hunks)
  • src/training_hub/utils.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/training_hub/algorithms/osft.py (2)
src/training_hub/algorithms/__init__.py (1)
  • Algorithm (6-22)
src/training_hub/utils.py (1)
  • get_torchrun_params (34-228)
src/training_hub/algorithms/sft.py (1)
src/training_hub/utils.py (1)
  • get_torchrun_params (34-228)
🪛 Ruff (0.14.0)
src/training_hub/utils.py

74-77: Avoid specifying long messages outside the exception class

(TRY003)


104-104: Prefer TypeError exception for invalid type

(TRY004)


104-104: Avoid specifying long messages outside the exception class

(TRY003)


110-110: Avoid specifying long messages outside the exception class

(TRY003)


118-118: Avoid specifying long messages outside the exception class

(TRY003)


155-155: Avoid specifying long messages outside the exception class

(TRY003)


180-184: Avoid specifying long messages outside the exception class

(TRY003)


187-187: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


195-195: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


221-221: Avoid specifying long messages outside the exception class

(TRY003)

@RobotSail RobotSail merged commit e324a75 into Red-Hat-AI-Innovation-Team:main Oct 14, 2025
4 checks passed
