Skip to content

Commit e309cf7

Browse files
committed
validate torchrunargs for backends
Signed-off-by: Saad Zaher <[email protected]>
1 parent 69c6ec3 commit e309cf7

File tree

3 files changed

+87
-13
lines changed

3 files changed

+87
-13
lines changed

src/training_hub/algorithms/osft.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import datasets
66
from training_hub.algorithms import Algorithm, Backend, AlgorithmRegistry
7-
from training_hub.utils import format_type_name
7+
from training_hub.utils import format_type_name, get_torchrun_params
88

99

1010
class OSFTAlgorithm(Algorithm):
@@ -63,6 +63,8 @@ def train(
6363
node_rank: int | None = None,
6464
rdzv_id: str | None = None,
6565
rdzv_endpoint: str | None = None,
66+
master_addr: str | None = None,
67+
master_port: str | None = None,
6668
**kwargs,
6769
) -> any:
6870
"""
@@ -126,6 +128,9 @@ def train(
126128
node_rank (int): Rank of this node (0 to nnodes-1) for distributed training.
127129
rdzv_id (str): Unique job ID for rendezvous in distributed training.
128130
rdzv_endpoint (str): Master node endpoint for multi-node training.
131+
master_addr (str): Master node address for distributed training (only used with
132+
static rdzv_backend).
133+
master_port (str): Master node port for distributed training.
129134
**kwargs: Additional parameters passed to the backend.
130135
131136
Returns:
@@ -176,6 +181,8 @@ def train(
176181
'node_rank': node_rank,
177182
'rdzv_id': rdzv_id,
178183
'rdzv_endpoint': rdzv_endpoint,
184+
'master_addr': master_addr,
185+
'master_port': master_port,
179186
}
180187

181188
# now do validation now that we've set everything up
@@ -227,6 +234,8 @@ def get_optional_params(self) -> dict[str, type]:
227234
'node_rank': int,
228235
'rdzv_id': str,
229236
'rdzv_endpoint': str,
237+
'master_addr': str,
238+
'master_port': int,
230239
}
231240

232241
def _validate_param_types(self, params: dict[str, any]):
@@ -373,8 +382,7 @@ def execute_training(self, algorithm_params: dict[str, any]) -> any:
373382
training_args_pre['osft'] = training_args_pre.get('osft', True)
374383

375384
torchrun_args_pre = {k: v for k, v in algorithm_params.items() if k in torchrun_args_fields and v is not None}
376-
# TODO: update this default in mini-trainer
377-
torchrun_args_pre['rdzv_endpoint'] = torchrun_args_pre.get('rdzv_endpoint', 'localhost:1738')
385+
torchrun_args_pre = get_torchrun_params(torchrun_params=torchrun_args_pre)
378386

379387

380388
# now we run training

src/training_hub/algorithms/sft.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from instructlab.training import run_training, TorchrunArgs, TrainingArgs
44

55
from . import Algorithm, Backend, AlgorithmRegistry
6+
from training_hub import utils
67

78

89
class InstructLabTrainingSFTBackend(Backend):
@@ -27,21 +28,13 @@ def execute_training(self, algorithm_params: Dict[str, Any]) -> Any:
2728
training_args = TrainingArgs(**training_params)
2829

2930
# Set up torchrun arguments with single-node defaults (except nproc_per_node)
30-
torchrun_defaults = {
31-
'nproc_per_node': os.getenv("LOCAL_WORLD_SIZE", os.getenv("PET_NPROC_PER_NODE", "1")),
32-
'nnodes': int(os.getenv("WORLD_SIZE", os.getenv("PET_NNODES", "1"))),
33-
'node_rank': int(os.getenv("PET_NODE_RANK", os.getenv("RANK", "0"))),
34-
'rdzv_id': 0,
35-
'rdzv_endpoint': ""
36-
}
31+
final_torchrun_params = utils.get_torchrun_params(training_args.dict())
3732

3833
if torchrun_params:
39-
# Merge provided params with defaults
40-
final_torchrun_params = {**torchrun_defaults, **torchrun_params}
4134
torchrun_args = TorchrunArgs(**final_torchrun_params)
4235
else:
4336
# Use single-node defaults including nproc_per_node
44-
torchrun_args = TorchrunArgs(**torchrun_defaults)
37+
torchrun_args = TorchrunArgs(**final_torchrun_params)
4538

4639
# Execute training
4740
return run_training(

src/training_hub/utils.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import os
2+
from curses.ascii import isdigit
3+
from importlib.metadata import pass_none
14
from typing import get_origin, get_args
25

36
def format_type_name(tp):
@@ -26,3 +29,73 @@ def format_type_name(tp):
2629
return type_str[8:-2]
2730

2831
return type_str
32+
33+
34+
def get_torchrun_params(args: dict):
35+
"""
36+
Parse and load PyTorch variables from dict with fallback to environment variables.
37+
38+
Args:
39+
args (dict): Dictionary containing PyTorch configuration parameters
40+
41+
Returns:
42+
dict: Dictionary with PyTorch parameters loaded from args or environment
43+
"""
44+
pytorch_vars = ['nproc_per_node', 'nnodes', 'node_rank', 'rdzv_id', 'rdzv_endpoint', 'master_addr', 'master_port']
45+
torchrun_args = {}
46+
47+
def validate_nproc_per_node(value):
48+
"""Validate and convert nproc_per_node value."""
49+
if isinstance(value, str):
50+
if value.lower() == 'auto':
51+
return 'gpu'
52+
elif value.lower() == 'gpu':
53+
return 'gpu'
54+
else:
55+
try:
56+
return int(value)
57+
except ValueError:
58+
raise ValueError(f"nproc_per_node must be 'auto', 'gpu', or an integer, got: {value}")
59+
elif isinstance(value, int):
60+
return value
61+
else:
62+
raise ValueError(f"nproc_per_node must be 'auto', 'gpu', or an integer, got: {value}")
63+
64+
def get_env_var_name(var_name):
65+
"""Get environment variable name based on PyTorch convention."""
66+
return var_name.upper() if var_name in ['master_addr', 'master_port'] else f"PET_{var_name.upper()}"
67+
68+
for var_name in pytorch_vars:
69+
# Try args dict first
70+
if var_name in args and args[var_name] is not None and args[var_name] != "":
71+
value = args[var_name]
72+
if var_name == 'nproc_per_node':
73+
torchrun_args[var_name] = validate_nproc_per_node(value)
74+
elif var_name in ['nnodes', 'node_rank', 'rdzv_id', 'master_port']:
75+
torchrun_args[var_name] = int(value) if isinstance(value, (str, int)) else value
76+
else:
77+
torchrun_args[var_name] = value
78+
else:
79+
# Fallback to environment variable
80+
env_value = os.getenv(get_env_var_name(var_name))
81+
if env_value is not None:
82+
if var_name == 'nproc_per_node':
83+
torchrun_args[var_name] = validate_nproc_per_node(env_value)
84+
elif var_name in ['nnodes', 'node_rank', 'rdzv_id', 'master_port']:
85+
try:
86+
torchrun_args[var_name] = int(env_value)
87+
except ValueError:
88+
torchrun_args[var_name] = env_value
89+
else:
90+
torchrun_args[var_name] = env_value
91+
else:
92+
# Set defaults
93+
defaults = {'nnodes': 1, 'rdzv_id': 0}
94+
torchrun_args[var_name] = defaults.get(var_name, "")
95+
96+
# Validate mutually exclusive parameters
97+
if (torchrun_args.get('rdzv_endpoint', '') != "" and
98+
(torchrun_args.get('master_addr', '') != "" or torchrun_args.get('master_port', '') != "")):
99+
raise ValueError("Cannot specify both rdzv_endpoint and master_addr/master_port. These are mutually exclusive parameters.")
100+
101+
return torchrun_args

0 commit comments

Comments
 (0)