
Commit 1dddebd

Commit message: enhancements
1 parent 4a2a4f2 · commit 1dddebd

2 files changed: +21 -11 lines changed


src/training_hub/algorithms/sft.py

Lines changed: 3 additions & 1 deletion
@@ -1,4 +1,3 @@
-import os
 from typing import Any, Dict, Type, Optional
 from instructlab.training import run_training, TorchrunArgs, TrainingArgs
 
@@ -213,6 +212,9 @@ def sft(model_path: str,
         node_rank: Rank of this node (0 to nnodes-1) for distributed training
         rdzv_id: Unique job ID for rendezvous in distributed training
         rdzv_endpoint: Master node endpoint for multi-node training
+        master_addr: Master node address for distributed training
+        master_port: Master node port for distributed training
+
         **kwargs: Additional parameters passed to the backend
 
     Returns:
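
The new docstring entries above document master_addr and master_port as first-class arguments to sft(). A rough, hypothetical call is sketched below: the parameter names come from the docstring in this diff, while the import path, the example values, and the omission of other required training arguments (data paths, output directories, etc.) are assumptions for illustration only.

    from training_hub import sft  # assumed import path

    result = sft(
        model_path="/models/base-model",   # hypothetical checkpoint location
        # ... data/output arguments required by the backend are elided here ...
        nnodes=2,                          # two-node job
        node_rank=0,                       # rank of this node (0 to nnodes-1)
        master_addr="10.0.0.1",            # newly documented: master node address
        master_port=29500,                 # newly documented: master node port
    )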

src/training_hub/utils.py

Lines changed: 18 additions & 10 deletions
@@ -31,7 +31,7 @@ def format_type_name(tp):
     return type_str
 
 
-def get_torchrun_params(args: dict):
+def get_torchrun_params(args: dict) -> dict[str, str | int]:
     """
     Parse and load PyTorch distributed training parameters with hierarchical precedence.
 
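
The added return annotation makes explicit that get_torchrun_params hands back a flat mapping from torchrun parameter names to either strings or integers. A plausible shape is sketched below; the exact keys depend on the rest of the function, which this hunk does not show, so treat the literal as illustrative only.

    # Illustrative return value, consistent with dict[str, str | int]
    torchrun_params = {
        "nnodes": 2,
        "node_rank": 0,
        "nproc_per_node": "gpu",
        "rdzv_id": 1234,
        "master_port": 29500,
    }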
@@ -104,21 +104,18 @@ def validate_nproc_per_node(value: int | str) -> int | str:
             raise ValueError(f"nproc_per_node must be 'auto', 'gpu', or an integer, got type {type(value).__name__}")
         if isinstance(value, int):
             return value
-
+
         value_lower = value.lower().strip()
         if value_lower not in ['auto', 'gpu'] and not value_lower.isdigit():
             raise ValueError(f"nproc_per_node must be 'auto', 'gpu', or an integer, got: {value!r}")
         if value_lower.isdigit():
             return int(value_lower)
-        elif value_lower == 'gpu':
-            return 'gpu'
 
-        # otherwise just handle auto logic
-        # convert 'auto' to 'gpu' if CUDA is available
-        if torch.cuda.is_available():
+        # handle 'auto' and 'gpu' - both require CUDA
+        if value_lower in ['auto', 'gpu'] and torch.cuda.is_available():
             return 'gpu'
         else:
-            raise ValueError("nproc_per_node='auto' requires CUDA GPUs, but none are available")
+            raise ValueError(f"nproc_per_node='{value_lower}' requires CUDA GPUs, but none are available")
 
     def get_param_reference(param_name: str, source: str) -> str:
         """Format parameter reference based on source (args vs env)."""
@@ -151,7 +148,13 @@ def get_param_reference(param_name: str, source: str) -> str:
         # we know the final values in this case must be integers, so any non-None value here
         # should be castable to `int`.
         value, _ = get_param_value(param)
-        torchrun_args[param] = int(value) if value is not None else default
+        if value is not None:
+            try:
+                torchrun_args[param] = int(value)
+            except (ValueError, TypeError) as e:
+                raise ValueError(f"Invalid value for {param}: {value!r}. Must be an integer.") from e
+        else:
+            torchrun_args[param] = default
 
 
     # rdzv_id will be either a str or int; we just perform some cleanup before
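
The replacement above keeps the old default-fallback behaviour but wraps the int() cast so a malformed value produces a readable error instead of a bare ValueError or TypeError. The same pattern as a tiny hypothetical helper (the names cast_or_default, param, value, and default are illustrative, not from the library):

    def cast_or_default(param: str, value, default: int) -> int:
        # None means "not provided": fall back to the default
        if value is None:
            return default
        try:
            return int(value)
        except (ValueError, TypeError) as e:
            raise ValueError(f"Invalid value for {param}: {value!r}. Must be an integer.") from e

    print(cast_or_default("nnodes", "2", 1))      # -> 2
    print(cast_or_default("node_rank", None, 0))  # -> 0
    # cast_or_default("nnodes", "two", 1) raises:
    #   ValueError: Invalid value for nnodes: 'two'. Must be an integer.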
@@ -212,8 +215,13 @@ def get_param_reference(param_name: str, source: str) -> str:
         # validate env conflicts only when we're actually using master_port
         if master_port_source == 'env':
             validate_env_conflict('master_port')
-        torchrun_args['master_port'] = int(master_port_val)
+        try:
+            torchrun_args['master_port'] = int(master_port_val)
+        except (ValueError, TypeError) as e:
+            raise ValueError(f"Invalid value for master_port: {master_port_val!r}. Must be an integer.") from e
 
+    # Note: If neither master_addr nor rdzv_endpoint is set, torchrun will use
+    # its default behavior (typically localhost or other configured defaults)
     elif rdzv_endpoint_val:
         torchrun_args['rdzv_endpoint'] = rdzv_endpoint_val
 
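The last hunk applies the same guarded int() cast to master_port and documents the fallback order: an explicit master address wins, otherwise rdzv_endpoint is used, and with neither torchrun falls back to its own defaults. The sketch below illustrates the two resulting parameter shapes; the dict layouts are assumptions for illustration, not output captured from the library.

    # Style 1: explicit master address/port (the branch handling an explicit master address)
    explicit_master = {
        "nnodes": 2, "node_rank": 0, "nproc_per_node": "gpu",
        "master_addr": "10.0.0.1", "master_port": 29500,
    }

    # Style 2: rendezvous endpoint (the elif branch)
    rendezvous = {
        "nnodes": 2, "node_rank": 0, "nproc_per_node": "gpu",
        "rdzv_id": 1234, "rdzv_endpoint": "10.0.0.1:29500",
    }

    # With neither master_addr nor rdzv_endpoint set, torchrun's own defaults
    # (typically localhost) apply, per the note added in this commit.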