
Commit c7a7bcd

Merge pull request #1 from RobotSail/advanced-torchrun-parser

adds hierarchical priority, handles edge cases, surface warnings and …

2 parents e309cf7 + c0815f9 commit c7a7bcd

File tree

3 files changed (+201 additions, -83 deletions)
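The PR title points at a hierarchical priority for resolving torchrun parameters; the resolver itself (get_torchrun_params) lives in the third changed file, which is not shown in the diff below. As a rough, non-authoritative sketch of what such a priority scheme usually looks like, a resolver of this kind prefers an explicitly passed value, then an environment variable, then a single-node default. The environment-variable names here are hypothetical, and this is not the actual get_torchrun_params() implementation:

import os

# Sketch only: illustrates "explicit argument > environment variable > single-node default".
# Not the repo's get_torchrun_params(); the env-var names are hypothetical.
SINGLE_NODE_DEFAULTS = {
    'nnodes': 1,
    'node_rank': 0,
    'rdzv_id': 0,
    'rdzv_endpoint': '',
}
HYPOTHETICAL_ENV_VARS = {
    'nnodes': 'NNODES',
    'node_rank': 'NODE_RANK',
    'rdzv_id': 'RDZV_ID',
    'rdzv_endpoint': 'RDZV_ENDPOINT',
}

def resolve_torchrun_params(explicit: dict) -> dict:
    resolved = {}
    for key, default in SINGLE_NODE_DEFAULTS.items():
        if explicit.get(key) is not None:                  # 1. explicit argument wins
            resolved[key] = explicit[key]
        elif os.environ.get(HYPOTHETICAL_ENV_VARS[key]):   # 2. then the environment
            resolved[key] = os.environ[HYPOTHETICAL_ENV_VARS[key]]
        else:                                              # 3. then the single-node default
            resolved[key] = default
    return resolved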

src/training_hub/algorithms/osft.py

Lines changed: 25 additions & 20 deletions
@@ -1,5 +1,5 @@
 import os
-from typing import get_origin, get_args, Union
+from typing import Literal, get_origin, get_args, Union
 from dataclasses import fields
 
 import datasets
@@ -58,10 +58,10 @@ def train(
         data_output_dir: str | None = None,
 
         # Torchrun parameters for multi-node support
-        nproc_per_node: str | int | None = None,
+        nproc_per_node: Literal['auto', 'gpu'] | int | None = None,
         nnodes: int | None = None,
         node_rank: int | None = None,
-        rdzv_id: str | None = None,
+        rdzv_id: str | int | None = None,
         rdzv_endpoint: str | None = None,
         master_addr: str | None = None,
         master_port: str | None = None,
@@ -123,14 +123,14 @@ def train(
                 Directory where outputs from data processing will be saved such as intermediate
                 files. When not provided, it defaults to `_internal_data_processing` under the
                 `ckpt_output_dir`.
-            nproc_per_node (str): Number of processes (GPUs) per node for distributed training.
+            nproc_per_node (Literal['auto', 'gpu'] | int): Number of processes (GPUs) per node for distributed training.
             nnodes (int): Total number of nodes for distributed training.
             node_rank (int): Rank of this node (0 to nnodes-1) for distributed training.
-            rdzv_id (str): Unique job ID for rendezvous in distributed training.
+            rdzv_id (str | int): Unique job ID for rendezvous in distributed training.
             rdzv_endpoint (str): Master node endpoint for multi-node training.
             master_addr (str): Master node address for distributed training (only used with
                 static rdzv_backend).
-            master_port (str): Master node port for distributed training.
+            master_port (int): Master node port for distributed training.
             **kwargs: Additional parameters passed to the backend.
 
         Returns:
@@ -229,10 +229,10 @@ def get_optional_params(self) -> dict[str, type]:
             'use_processed_dataset': bool,
             'unmask_messages': bool,
             'data_output_dir': str,
-            'nproc_per_node': str | int,
+            'nproc_per_node': Literal['auto', 'gpu'] | int,
             'nnodes': int,
             'node_rank': int,
-            'rdzv_id': str,
+            'rdzv_id': str | int,
             'rdzv_endpoint': str,
             'master_addr': str,
             'master_port': int,
@@ -342,6 +342,16 @@ def execute_training(self, algorithm_params: dict[str, any]) -> any:
         # Rename parameters before sending to backend
         algorithm_params = {renames.get(k, k): v for k, v in algorithm_params.items()}
 
+        # Separate parameters into their respective dataclass fields
+        torchrun_args_fields = {f.name for f in fields(TorchrunArgs)}
+        training_args_fields = {f.name for f in fields(TrainingArgs)}
+
+
+        # process this up here so we can exit early
+        torchrun_args_pre = {k: v for k, v in algorithm_params.items() if k in torchrun_args_fields and v is not None}
+        torchrun_args_pre = get_torchrun_params(torchrun_args_pre)
+        torch_args = TorchrunArgs(**torchrun_args_pre)
+
         # We separate this from `ckpt_output_dir` so that we can use `/dev/shm` for low-latency data
         # processing. But we do not want to make assumptions about the size of training data or the
         # amount of memory on the host. So by default we write to storage, but expose this as a separate
@@ -362,11 +372,6 @@ def execute_training(self, algorithm_params: dict[str, any]) -> any:
             unmask_messages=algorithm_params.get('unmask_messages', False),
         )
 
-
-        # Separate parameters into their respective dataclass fields
-        torchrun_args_fields = {f.name for f in fields(TorchrunArgs)}
-        training_args_fields = {f.name for f in fields(TrainingArgs)}
-
         # adjust arguments to align with the API definition
         training_args_pre = {k: v for k, v in algorithm_params.items() if k in training_args_fields and v is not None}
         training_args_pre['data_path'] = training_ready_data_path  # replaces raw data path with processed
@@ -381,13 +386,9 @@ def execute_training(self, algorithm_params: dict[str, any]) -> any:
         # but default it to True
         training_args_pre['osft'] = training_args_pre.get('osft', True)
 
-        torchrun_args_pre = {k: v for k, v in algorithm_params.items() if k in torchrun_args_fields and v is not None}
-        torchrun_args_pre = get_torchrun_params(torchrun_params=torchrun_args_pre)
-
-
         # now we run training
         return run_training(
-            torch_args=TorchrunArgs(**torchrun_args_pre),
+            torch_args=torch_args,
             train_args=TrainingArgs(**training_args_pre),
         )
 
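The execute_training() hunks above move TorchrunArgs construction ahead of data processing (per the "# process this up here so we can exit early" comment), by filtering algorithm_params against the dataclass fields of TorchrunArgs and TrainingArgs. A minimal self-contained sketch of that filtering pattern, using stand-in dataclasses rather than the real instructlab-training ones:

from dataclasses import dataclass, fields

# Stand-ins for instructlab-training's TorchrunArgs/TrainingArgs, purely for illustration.
@dataclass
class TorchrunArgs:
    nproc_per_node: int | str = 1
    nnodes: int = 1
    node_rank: int = 0

@dataclass
class TrainingArgs:
    model_path: str = ''
    learning_rate: float = 1e-5

def split_params(algorithm_params: dict) -> tuple[dict, dict]:
    """Partition a flat parameter dict by the dataclass each key belongs to."""
    torchrun_fields = {f.name for f in fields(TorchrunArgs)}
    training_fields = {f.name for f in fields(TrainingArgs)}
    torchrun = {k: v for k, v in algorithm_params.items() if k in torchrun_fields and v is not None}
    training = {k: v for k, v in algorithm_params.items() if k in training_fields and v is not None}
    return torchrun, training

# Build TorchrunArgs first so an invalid torchrun setup fails before any data processing runs.
params = {'nproc_per_node': 8, 'nnodes': 2, 'model_path': '/models/base', 'learning_rate': None}
torchrun_pre, training_pre = split_params(params)
torch_args = TorchrunArgs(**torchrun_pre)
train_args = TrainingArgs(**training_pre)   # learning_rate was None, so the default is used

Because the field sets come from dataclasses.fields(), adding a new torchrun or training parameter to the corresponding dataclass is enough for it to be routed to the right place.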
@@ -468,11 +469,13 @@ def osft(
     save_final_checkpoint: bool | None = None,
     num_epochs: int | None = None,
     # Torchrun parameters for multi-node support
-    nproc_per_node: str | int | None = None,
+    nproc_per_node: Literal['auto', 'gpu'] | int | None = None,
     nnodes: int | None = None,
     node_rank: int | None = None,
-    rdzv_id: str | None = None,
+    rdzv_id: str | int | None = None,
     rdzv_endpoint: str | None = None,
+    master_port: str | None = None,
+    master_addr: str | None = None,
     **kwargs
 ) -> any:
     from . import create_algorithm
@@ -504,5 +507,7 @@ def osft(
         node_rank=node_rank,
         rdzv_id=rdzv_id,
         rdzv_endpoint=rdzv_endpoint,
+        master_port=master_port,
+        master_addr=master_addr,
         **kwargs
     )
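
With master_addr and master_port now flowing through the osft() convenience wrapper, a two-node launch could look roughly like the following. The torchrun keyword names come from the updated signature above; model_path, data_path, and ckpt_output_dir are assumed parameter names (they are referenced in the docstring and backend code but not shown in these hunks), and the import path is a guess, so treat this as a sketch rather than a verified invocation:

from training_hub import osft  # import path assumed

# Hypothetical two-node OSFT run; model_path, data_path, and ckpt_output_dir
# are assumed parameter names, not confirmed by this diff.
result = osft(
    model_path='/models/base-model',
    data_path='/data/train.jsonl',
    ckpt_output_dir='/checkpoints/osft-run',
    # torchrun parameters from the updated signature
    nproc_per_node='gpu',            # Literal['auto', 'gpu'] | int
    nnodes=2,
    node_rank=0,                     # set to 1 when launching on the second node
    rdzv_id=42,                      # rdzv_id now accepts str | int
    rdzv_endpoint='10.0.0.1:29500',
    master_addr='10.0.0.1',
    master_port=29500,               # documented as an int in the docstring
)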

src/training_hub/algorithms/sft.py

Lines changed: 15 additions & 8 deletions
@@ -12,7 +12,7 @@ class InstructLabTrainingSFTBackend(Backend):
     def execute_training(self, algorithm_params: Dict[str, Any]) -> Any:
         """Execute SFT training using instructlab-training."""
         # Separate torchrun parameters from training parameters
-        torchrun_keys = {'nproc_per_node', 'nnodes', 'node_rank', 'rdzv_id', 'rdzv_endpoint'}
+        torchrun_keys = {'nproc_per_node', 'nnodes', 'node_rank', 'rdzv_id', 'rdzv_endpoint', 'master_addr', 'master_port'}
 
         # Extract torchrun parameters
         torchrun_params = {k: v for k, v in algorithm_params.items() if k in torchrun_keys}
@@ -28,14 +28,9 @@ def execute_training(self, algorithm_params: Dict[str, Any]) -> Any:
         training_args = TrainingArgs(**training_params)
 
         # Set up torchrun arguments with single-node defaults (except nproc_per_node)
-        final_torchrun_params = utils.get_torchrun_params(training_args.dict())
+        final_torchrun_params = utils.get_torchrun_params(torchrun_params)
+        torchrun_args = TorchrunArgs(**final_torchrun_params)
 
-        if torchrun_params:
-            torchrun_args = TorchrunArgs(**final_torchrun_params)
-        else:
-            # Use single-node defaults including nproc_per_node
-            torchrun_args = TorchrunArgs(**final_torchrun_params)
-
         # Execute training
         return run_training(
             torch_args=torchrun_args,
@@ -71,6 +66,8 @@ def train(self,
              node_rank: Optional[int] = None,
              rdzv_id: Optional[str | int] = None,
              rdzv_endpoint: Optional[str] = None,
+             master_addr: Optional[str] = None,
+             master_port: Optional[int] = None,
              **kwargs) -> Any:
         """Execute SFT training.
 
@@ -93,6 +90,8 @@ def train(self,
             node_rank: Rank of this node (0 to nnodes-1)
             rdzv_id: Unique job ID for rendezvous
             rdzv_endpoint: Master node endpoint for multi-node training
+            master_addr: Master node address for distributed training
+            master_port: Master node port for distributed training
             **kwargs: Additional parameters passed to the backend
 
         Returns:
@@ -118,6 +117,8 @@ def train(self,
             'node_rank': node_rank,
             'rdzv_id': rdzv_id,
             'rdzv_endpoint': rdzv_endpoint,
+            'master_addr': master_addr,
+            'master_port': master_port,
         }
 
         # Only add non-None parameters (let TrainingArgs handle defaults)
@@ -156,6 +157,8 @@ def get_optional_params(self) -> Dict[str, Type]:
             'node_rank': int,
             'rdzv_id': str | int,
             'rdzv_endpoint': str,
+            'master_addr': str,
+            'master_port': int,
         }
 
 
@@ -185,6 +188,8 @@ def sft(model_path: str,
         node_rank: Optional[int] = None,
         rdzv_id: Optional[str | int] = None,
         rdzv_endpoint: Optional[str] = None,
+        master_addr: Optional[str] = None,
+        master_port: Optional[int] = None,
         **kwargs) -> Any:
     """Convenience function to run SFT training.
 
@@ -235,6 +240,8 @@ def sft(model_path: str,
         node_rank=node_rank,
         rdzv_id=rdzv_id,
         rdzv_endpoint=rdzv_endpoint,
+        master_port=master_port,
+        master_addr=master_addr,
         **kwargs
     )
 
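The same passthrough now applies to the sft() convenience function. model_path appears in the signature above; data_path and ckpt_output_dir are assumed parameter names and the import path is a guess, so again this is only a sketch:

from training_hub import sft  # import path assumed

# Hypothetical multi-node SFT launch; data_path and ckpt_output_dir are assumed names.
result = sft(
    model_path='/models/base-model',
    data_path='/data/train.jsonl',
    ckpt_output_dir='/checkpoints/sft-run',
    nproc_per_node=8,
    nnodes=2,
    node_rank=0,
    rdzv_id='sft-job-1',             # Optional[str | int]
    rdzv_endpoint='10.0.0.1:29500',
    master_addr='10.0.0.1',          # newly exposed
    master_port=29500,               # newly exposed
)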