feat(traininghub): Use torchrun environment variables for default configuration #13
Conversation
…figuration

Problem: The previous implementation assigned default values directly to torchrun arguments. According to PyTorch's behavior, this prevents torchrun from reading its corresponding environment variables. Consequently, configurations provided by orchestration systems like Kubeflow (which inject variables like RANK and WORLD_SIZE) were being ignored.

Solution: This commit refactors the argument handling logic to prioritize environment variables. The code now checks for the presence of the most common torchrun environment variables. If an environment variable is set, its value is used for the argument; otherwise, a default is applied.

This change aligns Training Hub's behavior with torchrun's intended design, making it compatible with standard distributed training environments.

Reference: https://docs.pytorch.org/docs/stable/elastic/run.html#environment-variables

Signed-off-by: Saad Zaher <[email protected]>
Walkthrough

SFT and OSFT now use a new training_hub.utils.get_torchrun_params helper to resolve and validate torchrun/distributed parameters (nproc_per_node, nnodes, node_rank, rdzv_id, rdzv_endpoint, master_addr, master_port) from args and environment. Public signatures are widened to accept string | int unions, and master_addr/master_port are propagated through training.
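For orientation, here is a minimal sketch of the env-var-first resolution the walkthrough describes. The helper name resolve_param, the precedence order, and the defaults are illustrative only and are not the actual get_torchrun_params implementation:

```python
import os

def resolve_param(explicit_value, env_names, default):
    """Prefer an explicitly passed argument, then the first set env var, then a default."""
    if explicit_value is not None:
        return explicit_value          # caller-supplied value always wins
    for name in env_names:
        value = os.environ.get(name)
        if value is not None:
            return value               # value injected by torchrun / the orchestrator
    return default                     # single-node fallback

# Example: resolve nnodes and node_rank when nothing is passed explicitly.
nnodes = resolve_param(None, ["PET_NNODES"], "1")
node_rank = resolve_param(None, ["PET_NODE_RANK"], "0")
```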
Sequence Diagram(s)

sequenceDiagram
autonumber
actor User
participant Env as Environment
participant API as SFT/OSFT API
participant Utils as get_torchrun_params()
participant Torchrun as TorchrunArgs
participant Runner as backend / run_training
User->>API: call train(...) (may include nproc_per_node, rdzv_id, master_addr/port)
API->>Utils: get_torchrun_params({torchrun_keys from args/env})
Utils->>Env: read WORLD_SIZE/LOCAL_WORLD_SIZE/RANK, PET_* variants, master and rdzv envs
Env-->>Utils: env values
Utils-->>API: resolved torchrun params (source: args|env), or error on conflicts
API->>API: merge/validate additional user torchrun_params (if provided)
API->>Torchrun: construct TorchrunArgs(final params)
API->>Runner: run_training(torch_args=TorchrunArgs, train_args=...)
Runner-->>User: training job launched / execution started
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
Signed-off-by: Saad Zaher <[email protected]>
RobotSail
left a comment
Thank you for the PR @szaher ! Could you please review the few comments on this PR? Otherwise, this looks good to merge.
Apply suggestion by rabbitcode to use the new variables while keeping the old variables, which are the only ones supported by Kubeflow Trainer. Signed-off-by: Saad Zaher <[email protected]>
src/training_hub/algorithms/sft.py
Outdated

```python
# Set up torchrun arguments with single-node defaults (except nproc_per_node)
torchrun_defaults = {
    'nproc_per_node': os.getenv("LOCAL_WORLD_SIZE", os.getenv("PET_NPROC_PER_NODE", "1")),
```
How come this is being used as a string?
@RobotSail PyTorch accepts values for nproc_per_node like auto, gpu, cpu, etc.
Reference: https://github.com/pytorch/pytorch/blob/main/torch/distributed/run.py#L77-L88
This is also what Kubeflow Trainer would inject by default (auto).
The list of values Kubeflow Trainer supports is https://github.com/kubeflow/trainer/blob/master/pkg/runtime/framework/plugins/torch/torch.go#L68-L72, plus any int value provided by the user.
This is dependent on PR instructlab/training#661
I see, thank you for explaining that! Seems like we should also update osft.py and its equivalent interface in api_train.py as well.
If auto is set, check that the accelerator is available and is CUDA, then set it to gpu for the backend
Signed-off-by: Saad Zaher <[email protected]>
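A rough sketch of the "auto" handling this commit describes; the helper name and the non-CUDA fallback of "1" are assumptions for illustration, not the PR's exact code:

```python
import torch

def normalize_nproc_per_node(value):
    """Map torchrun's 'auto' to a value the backend understands."""
    if isinstance(value, str) and value.lower() == "auto":
        # Per the commit message: if the accelerator is available and is CUDA, use 'gpu'.
        # The "1" fallback below is an assumption, not taken from the PR.
        return "gpu" if torch.cuda.is_available() else "1"
    return value
```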
src/training_hub/algorithms/osft.py
Outdated

```diff
  # Torchrun parameters for multi-node support
- nproc_per_node: int | None = None,
+ nproc_per_node: str | None = None,
```
Since you already know the valid string that this could be, I recommend being explicit rather than allowing any string type:
```diff
- nproc_per_node: str | None = None,
+ nproc_per_node: Literal["gpu"] | None = None,
```
nproc_per_node can be a string or an int, and the string values can be "gpu", "auto", "xpu", "cpu", etc., or integer values. If we use Literal it will always expect a fixed list of values (in this case only gpu).
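For illustration, the two annotation styles being weighed here, plus a runtime check covering the open-ended value set a single Literal cannot express (the accepted strings below are only the ones named in this thread):

```python
from typing import Literal, Optional, Union

NprocLiteral = Optional[Union[Literal["gpu"], int]]  # the suggestion: only "gpu" or an int
NprocFlexible = Optional[Union[str, int]]            # the PR: any string, validated at runtime

def validate_nproc(value: NprocFlexible) -> None:
    """Reject strings outside the values discussed in this thread."""
    if value is None or isinstance(value, int):
        return
    if value.lower() in {"auto", "gpu", "cpu", "xpu"} or value.isdigit():
        return
    raise ValueError(f"unsupported nproc_per_node value: {value!r}")
```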
Signed-off-by: Saad Zaher <[email protected]>
Actionable comments posted: 3
🧹 Nitpick comments (1)
src/training_hub/algorithms/osft.py (1)
124-124: Clarify docstring to mention valid string value.

The docstring doesn't explain that nproc_per_node can accept the string "gpu" for automatic GPU detection. This could confuse users about the valid string values. Update the docstring to clarify valid values:

```diff
- nproc_per_node (str): Number of processes (GPUs) per node for distributed training.
+ nproc_per_node (Literal["gpu"] | int | None): Number of processes per node for distributed training.
+     Use "gpu" for automatic detection or an integer to specify explicitly.
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
src/training_hub/algorithms/osft.py (4 hunks)
src/training_hub/algorithms/sft.py (5 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- src/training_hub/algorithms/sft.py
@RobotSail @Maxusmusti do we want to only support

@szaher It's backend-dependent. The current backends we're using here for SFT and OSFT do not support either of those, so they cannot be used. If other backends are added in the future which do support those options, then they can have a different implementation.

I'll also add that it is somewhat algorithm-dependent as well. For example, when adding LoRA/QLoRA, CPU options make sense. But for full fine-tuning and some reinforcement learning methods, it may be infeasible to run on CPU with any decently sized models.
src/training_hub/algorithms/sft.py
Outdated

```python
# Set up torchrun arguments with single-node defaults (except nproc_per_node)
torchrun_defaults = {
    'nproc_per_node': os.getenv("LOCAL_WORLD_SIZE", os.getenv("PET_NPROC_PER_NODE", "1")),
    'nnodes': int(os.getenv("WORLD_SIZE", os.getenv("PET_NNODES", "1"))),
```
use PET_NNODES only
src/training_hub/algorithms/sft.py
Outdated

```python
torchrun_defaults = {
    'nproc_per_node': os.getenv("LOCAL_WORLD_SIZE", os.getenv("PET_NPROC_PER_NODE", "1")),
    'nnodes': int(os.getenv("WORLD_SIZE", os.getenv("PET_NNODES", "1"))),
    'node_rank': int(os.getenv("PET_NODE_RANK", os.getenv("RANK", "0"))),
```
drop RANK and keep only PET_NODE_RANK
Signed-off-by: Saad Zaher <[email protected]>
Actionable comments posted: 3
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
src/training_hub/algorithms/osft.py (7 hunks)
src/training_hub/algorithms/sft.py (5 hunks)
src/training_hub/utils.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/training_hub/algorithms/osft.py (1)
src/training_hub/utils.py (1)
get_torchrun_params(34-101)
src/training_hub/algorithms/sft.py (2)
src/training_hub/algorithms/__init__.py (3)
Algorithm (6-22), Backend (24-30), AlgorithmRegistry (33-79)
src/training_hub/utils.py (1)
get_torchrun_params (34-101)
🪛 Ruff (0.14.0)
src/training_hub/utils.py
58-58: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling
(B904)
58-58: Avoid specifying long messages outside the exception class
(TRY003)
62-62: Prefer TypeError exception for invalid type
(TRY004)
62-62: Avoid specifying long messages outside the exception class
(TRY003)
99-99: Avoid specifying long messages outside the exception class
(TRY003)
adds hierarchical priority, handles edge cases, surface warnings and …
Actionable comments posted: 3
🧹 Nitpick comments (5)
src/training_hub/utils.py (3)
94-115: Tighten nproc_per_node validation (type and value).

- Raise TypeError for wrong types.
- Reject non-positive integers.

Apply this diff:

```diff
 def validate_nproc_per_node(value):
     """Validate and normalize nproc_per_node."""
-    if not isinstance(value, (int, str)):
-        raise ValueError(f"nproc_per_node must be 'auto', 'gpu', or an integer, got type {type(value).__name__}")
+    if not isinstance(value, (int, str)):
+        raise TypeError("nproc_per_node must be 'auto', 'gpu', or an integer")
     if isinstance(value, int):
-        return value
+        if value < 1:
+            raise ValueError("nproc_per_node must be >= 1")
+        return value
@@
-    if value_lower not in ['auto', 'gpu'] and not value_lower.isdigit():
-        raise ValueError(f"nproc_per_node must be 'auto', 'gpu', or an integer, got: {value!r}")
+    if value_lower not in ['auto', 'gpu'] and not value_lower.isdigit():
+        raise ValueError("nproc_per_node must be 'auto', 'gpu', or a positive integer")
```
58-64: Optional: add fallbacks to standard torchrun envs (LOCAL_WORLD_SIZE, WORLD_SIZE, NODE_RANK).

Today only PET_* is read (except master vars). If the goal is compatibility with common launchers (kubeflow/torchrun), consider:

- nproc_per_node ← PET_NPROC_PER_NODE or LOCAL_WORLD_SIZE
- node_rank ← PET_NODE_RANK or NODE_RANK
- nnodes ← PET_NNODES, else compute WORLD_SIZE/LOCAL_WORLD_SIZE

Would you like to support these fallbacks? Example:

```diff
 def get_env_value(param_name):
     """Get environment variable value with fallback logic."""
-    if param_name in ['master_addr', 'master_port']:
+    if param_name in ['master_addr', 'master_port']:  # try both PET_ and non-PET_ versions
         return os.getenv(f'PET_{param_name.upper()}') or os.getenv(param_name.upper())
-    return os.getenv(f'PET_{param_name.upper()}')
+    if param_name == 'nproc_per_node':
+        return os.getenv('PET_NPROC_PER_NODE') or os.getenv('LOCAL_WORLD_SIZE')
+    if param_name == 'node_rank':
+        return os.getenv('PET_NODE_RANK') or os.getenv('NODE_RANK')
+    if param_name == 'nnodes':
+        pet = os.getenv('PET_NNODES')
+        if pet:
+            return pet
+        ws, lws = os.getenv('WORLD_SIZE'), os.getenv('LOCAL_WORLD_SIZE')
+        if ws and lws:
+            try:
+                return str(int(int(ws) // int(lws)))
+            except ValueError:
+                return None
+        return None
+    return os.getenv(f'PET_{param_name.upper()}')
```

Note: Keep PET_* precedence if that's your policy.
171-175: Set stacklevel on warnings for correct caller context.

Add stacklevel=2 to warnings.warn(...). Apply this diff:

```diff
-warnings.warn(
+warnings.warn(
     f"Both {master_addr_ref}={master_addr_val!r} and {rdzv_endpoint_ref}={rdzv_endpoint_val!r} are set. "
     f"Using {master_addr_ref} due to higher precedence. Ignoring {rdzv_endpoint_ref}.",
-    UserWarning
+    UserWarning,
+    stacklevel=2,
 )
@@
-warnings.warn(
+warnings.warn(
     f"Both {rdzv_endpoint_ref}={rdzv_endpoint_val!r} and {master_addr_ref}={master_addr_val!r} are set. "
     f"Using {rdzv_endpoint_ref} due to higher precedence. Ignoring {master_addr_ref}.",
-    UserWarning
+    UserWarning,
+    stacklevel=2,
 )
```

Based on static analysis hints.
Also applies to: 179-183
src/training_hub/algorithms/sft.py (1)
1-1: Remove unused import.

os is not used in this module. Apply this diff:

```diff
-import os
```

src/training_hub/algorithms/osft.py (1)
241-262: Literal type checking isn’t enforced.
_check_type treats unknown typing constructs as valid, so nproc_per_node: Literal['auto','gpu'] | int won't reject bad strings. Handle Literal explicitly. Apply this diff:

```diff
 from typing import Literal, get_origin, get_args, Union
+import typing
@@
     # Handle Union types (including X | None syntax)
     if origin is Union:
         return any(self._check_type(value, arg) for arg in args)
+
+    # Handle Literal types
+    if origin is typing.Literal:
+        return any(value == lit for lit in args)
```

This keeps invalid strings from slipping through before backend validation.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
src/training_hub/algorithms/osft.py (9 hunks)
src/training_hub/algorithms/sft.py (9 hunks)
src/training_hub/utils.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/training_hub/algorithms/sft.py (1)
src/training_hub/utils.py (1)
get_torchrun_params(32-207)
src/training_hub/algorithms/osft.py (1)
src/training_hub/utils.py (1)
get_torchrun_params(32-207)
🪛 Ruff (0.14.0)
src/training_hub/utils.py
75-78: Avoid specifying long messages outside the exception class
(TRY003)
86-86: Unnecessary key check before dictionary access
Replace with dict.get
(RUF019)
97-97: Prefer TypeError exception for invalid type
(TRY004)
97-97: Avoid specifying long messages outside the exception class
(TRY003)
103-103: Avoid specifying long messages outside the exception class
(TRY003)
114-114: Avoid specifying long messages outside the exception class
(TRY003)
164-168: Avoid specifying long messages outside the exception class
(TRY003)
171-171: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
179-179: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
🔇 Additional comments (1)
src/training_hub/algorithms/osft.py (1)
126-134: Keep docstring consistent with type change.

Ensure the master_port description stays as int. No code change needed if you apply the signature fix above; just confirm docs align.
… for strings; integers evaluate as-is
Actionable comments posted: 0
🧹 Nitpick comments (1)
src/training_hub/utils.py (1)
187-187: Add stacklevel=2 to warnings.

The warning calls lack an explicit stacklevel argument, so the warning will point to this helper function rather than the caller's code, making it harder for users to identify where the conflict originated. Apply this diff:

```diff
-warnings.warn(
+warnings.warn(
     f"Both {master_addr_ref}={master_addr_val!r} and {rdzv_endpoint_ref}={rdzv_endpoint_val!r} are set. "
     f"Using {master_addr_ref} due to higher precedence. Ignoring {rdzv_endpoint_ref}.",
-    UserWarning
+    UserWarning,
+    stacklevel=2
 )
```

Also applies to: 195-195
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
src/training_hub/algorithms/osft.py (9 hunks)
src/training_hub/algorithms/sft.py (10 hunks)
src/training_hub/utils.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/training_hub/algorithms/osft.py (2)
src/training_hub/algorithms/__init__.py (1)
Algorithm (6-22)
src/training_hub/utils.py (1)
get_torchrun_params(34-228)
src/training_hub/algorithms/sft.py (1)
src/training_hub/utils.py (1)
get_torchrun_params(34-228)
🪛 Ruff (0.14.0)
src/training_hub/utils.py
74-77: Avoid specifying long messages outside the exception class
(TRY003)
104-104: Prefer TypeError exception for invalid type
(TRY004)
104-104: Avoid specifying long messages outside the exception class
(TRY003)
110-110: Avoid specifying long messages outside the exception class
(TRY003)
118-118: Avoid specifying long messages outside the exception class
(TRY003)
155-155: Avoid specifying long messages outside the exception class
(TRY003)
180-184: Avoid specifying long messages outside the exception class
(TRY003)
187-187: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
195-195: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
221-221: Avoid specifying long messages outside the exception class
(TRY003)
Problem:
The previous implementation assigned default values directly to torchrun arguments. According to PyTorch's behavior, this prevents torchrun from reading its corresponding environment variables. Consequently, configurations provided by orchestration systems like Kubeflow (which inject variables like RANK and WORLD_SIZE) were being ignored.
Solution:
This commit refactors the argument handling logic to prioritize environment variables. The code now checks for the presence of the most common torchrun environment variables. If an environment variable is set, its value is used for the argument; otherwise, a default is applied.
This change aligns Training Hub's behavior with torchrun's intended design, making it compatible with standard distributed training environments.
Reference: https://docs.pytorch.org/docs/stable/elastic/run.html#environment-variables
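To make the effect concrete, a small illustrative scenario (the environment variable names come from the diffs in this PR; the values and the orchestrator behavior shown are assumptions for illustration):

```python
import os

# Environment a Kubeflow Trainer pod might inject (values illustrative).
os.environ.update({
    "PET_NNODES": "2",
    "PET_NODE_RANK": "0",
    "PET_NPROC_PER_NODE": "auto",
})

# Previously, hard-coded defaults shadowed these variables; with this change,
# leaving the torchrun arguments unset lets the environment take effect,
# while passing them explicitly still overrides it.
print(os.getenv("PET_NNODES", "1"))      # -> "2", picked up from the orchestrator
print(os.getenv("PET_NODE_RANK", "0"))   # -> "0"
```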
Summary by CodeRabbit
New Features
Refactor
Breaking Changes