Skip to content

Commit e324a75

Browse files
authored
Merge pull request #13 from szaher/torchrun-args-from-env
feat(traininghub): Use torchrun environment variables for default configuration
2 parents e4e1b55 + 1dddebd commit e324a75

File tree

3 files changed

+261
-49
lines changed

3 files changed

+261
-49
lines changed

src/training_hub/algorithms/osft.py

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import os
2-
from typing import get_origin, get_args, Union
2+
from typing import Literal, get_origin, get_args, Union
33
from dataclasses import fields
44

55
import datasets
66
from training_hub.algorithms import Algorithm, Backend, AlgorithmRegistry
7-
from training_hub.utils import format_type_name
7+
from training_hub.utils import format_type_name, get_torchrun_params
88

99

1010
class OSFTAlgorithm(Algorithm):
@@ -58,11 +58,13 @@ def train(
5858
data_output_dir: str | None = None,
5959

6060
# Torchrun parameters for multi-node support
61-
nproc_per_node: int | None = None,
61+
nproc_per_node: Literal['auto', 'gpu'] | int | None = None,
6262
nnodes: int | None = None,
6363
node_rank: int | None = None,
64-
rdzv_id: int | None = None,
64+
rdzv_id: str | int | None = None,
6565
rdzv_endpoint: str | None = None,
66+
master_addr: str | None = None,
67+
master_port: int | None = None,
6668
**kwargs,
6769
) -> any:
6870
"""
@@ -121,11 +123,14 @@ def train(
121123
Directory where outputs from data processing will be saved such as intermediate
122124
files. When not provided, it defaults to `_internal_data_processing` under the
123125
`ckpt_output_dir`.
124-
nproc_per_node (int): Number of processes (GPUs) per node for distributed training.
126+
nproc_per_node (Literal['auto', 'gpu'] | int): Number of processes (GPUs) per node for distributed training.
125127
nnodes (int): Total number of nodes for distributed training.
126128
node_rank (int): Rank of this node (0 to nnodes-1) for distributed training.
127-
rdzv_id (int): Unique job ID for rendezvous in distributed training.
129+
rdzv_id (str | int): Unique job ID for rendezvous in distributed training.
128130
rdzv_endpoint (str): Master node endpoint for multi-node training.
131+
master_addr (str): Master node address for distributed training (only used with
132+
static rdzv_backend).
133+
master_port (int): Master node port for distributed training.
129134
**kwargs: Additional parameters passed to the backend.
130135
131136
Returns:
@@ -176,6 +181,8 @@ def train(
176181
'node_rank': node_rank,
177182
'rdzv_id': rdzv_id,
178183
'rdzv_endpoint': rdzv_endpoint,
184+
'master_addr': master_addr,
185+
'master_port': master_port,
179186
}
180187

181188
# now do validation now that we've set everything up
@@ -222,11 +229,13 @@ def get_optional_params(self) -> dict[str, type]:
222229
'use_processed_dataset': bool,
223230
'unmask_messages': bool,
224231
'data_output_dir': str,
225-
'nproc_per_node': int,
232+
'nproc_per_node': Literal['auto', 'gpu'] | int,
226233
'nnodes': int,
227234
'node_rank': int,
228-
'rdzv_id': int,
235+
'rdzv_id': str | int,
229236
'rdzv_endpoint': str,
237+
'master_addr': str,
238+
'master_port': int,
230239
}
231240

232241
def _validate_param_types(self, params: dict[str, any]):
@@ -333,6 +342,16 @@ def execute_training(self, algorithm_params: dict[str, any]) -> any:
333342
# Rename parameters before sending to backend
334343
algorithm_params = {renames.get(k, k): v for k, v in algorithm_params.items()}
335344

345+
# Separate parameters into their respective dataclass fields
346+
torchrun_args_fields = {f.name for f in fields(TorchrunArgs)}
347+
training_args_fields = {f.name for f in fields(TrainingArgs)}
348+
349+
350+
# process this up here so we can exit early
351+
torchrun_args_pre = {k: v for k, v in algorithm_params.items() if k in torchrun_args_fields and v is not None}
352+
torchrun_args_pre = get_torchrun_params(torchrun_args_pre)
353+
torch_args = TorchrunArgs(**torchrun_args_pre)
354+
336355
# We separate this from `ckpt_output_dir` so that we can use `/dev/shm` for low-latency data
337356
# proceessing. But we do not want to make assumptions about the size of training data or the
338357
# amount of memory on the host. So by default we write to storage, but expose this as a separate
@@ -353,11 +372,6 @@ def execute_training(self, algorithm_params: dict[str, any]) -> any:
353372
unmask_messages=algorithm_params.get('unmask_messages', False),
354373
)
355374

356-
357-
# Separate parameters into their respective dataclass fields
358-
torchrun_args_fields = {f.name for f in fields(TorchrunArgs)}
359-
training_args_fields = {f.name for f in fields(TrainingArgs)}
360-
361375
# adjust arguments to align with the API definition
362376
training_args_pre = {k: v for k, v in algorithm_params.items() if k in training_args_fields and v is not None}
363377
training_args_pre['data_path'] = training_ready_data_path # replaces raw data path with processed
@@ -372,14 +386,9 @@ def execute_training(self, algorithm_params: dict[str, any]) -> any:
372386
# but default it to True
373387
training_args_pre['osft'] = training_args_pre.get('osft', True)
374388

375-
torchrun_args_pre = {k: v for k, v in algorithm_params.items() if k in torchrun_args_fields and v is not None}
376-
# TODO: update this default in mini-trainer
377-
torchrun_args_pre['rdzv_endpoint'] = torchrun_args_pre.get('rdzv_endpoint', 'localhost:1738')
378-
379-
380389
# now we run training
381390
return run_training(
382-
torch_args=TorchrunArgs(**torchrun_args_pre),
391+
torch_args=torch_args,
383392
train_args=TrainingArgs(**training_args_pre),
384393
)
385394

@@ -460,11 +469,13 @@ def osft(
460469
save_final_checkpoint: bool | None = None,
461470
num_epochs: int | None = None,
462471
# Torchrun parameters for multi-node support
463-
nproc_per_node: int | None = None,
472+
nproc_per_node: Literal['auto', 'gpu'] | int | None = None,
464473
nnodes: int | None = None,
465474
node_rank: int | None = None,
466-
rdzv_id: int | None = None,
475+
rdzv_id: str | int | None = None,
467476
rdzv_endpoint: str | None = None,
477+
master_port: int | None = None,
478+
master_addr: str | None = None,
468479
**kwargs
469480
) -> any:
470481
from . import create_algorithm
@@ -496,5 +507,7 @@ def osft(
496507
node_rank=node_rank,
497508
rdzv_id=rdzv_id,
498509
rdzv_endpoint=rdzv_endpoint,
510+
master_port=master_port,
511+
master_addr=master_addr,
499512
**kwargs
500513
)

src/training_hub/algorithms/sft.py

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from instructlab.training import run_training, TorchrunArgs, TrainingArgs
33

44
from . import Algorithm, Backend, AlgorithmRegistry
5+
from training_hub import utils
56

67

78
class InstructLabTrainingSFTBackend(Backend):
@@ -10,7 +11,7 @@ class InstructLabTrainingSFTBackend(Backend):
1011
def execute_training(self, algorithm_params: Dict[str, Any]) -> Any:
1112
"""Execute SFT training using instructlab-training."""
1213
# Separate torchrun parameters from training parameters
13-
torchrun_keys = {'nproc_per_node', 'nnodes', 'node_rank', 'rdzv_id', 'rdzv_endpoint'}
14+
torchrun_keys = {'nproc_per_node', 'nnodes', 'node_rank', 'rdzv_id', 'rdzv_endpoint', 'master_addr', 'master_port'}
1415

1516
# Extract torchrun parameters
1617
torchrun_params = {k: v for k, v in algorithm_params.items() if k in torchrun_keys}
@@ -26,26 +27,9 @@ def execute_training(self, algorithm_params: Dict[str, Any]) -> Any:
2627
training_args = TrainingArgs(**training_params)
2728

2829
# Set up torchrun arguments with single-node defaults (except nproc_per_node)
29-
if torchrun_params:
30-
torchrun_defaults = {
31-
'nnodes': 1,
32-
'node_rank': 0,
33-
'rdzv_id': 0,
34-
'rdzv_endpoint': ""
35-
}
36-
# Merge provided params with defaults
37-
final_torchrun_params = {**torchrun_defaults, **torchrun_params}
38-
torchrun_args = TorchrunArgs(**final_torchrun_params)
39-
else:
40-
# Use single-node defaults including nproc_per_node
41-
torchrun_args = TorchrunArgs(
42-
nproc_per_node=1,
43-
nnodes=1,
44-
node_rank=0,
45-
rdzv_id=0,
46-
rdzv_endpoint=""
47-
)
48-
30+
final_torchrun_params = utils.get_torchrun_params(torchrun_params)
31+
torchrun_args = TorchrunArgs(**final_torchrun_params)
32+
4933
# Execute training
5034
return run_training(
5135
torch_args=torchrun_args,
@@ -76,11 +60,13 @@ def train(self,
7660
accelerate_full_state_at_epoch: Optional[bool] = None,
7761
checkpoint_at_epoch: Optional[bool] = None,
7862
# Torchrun parameters for multi-node support
79-
nproc_per_node: Optional[int] = None,
63+
nproc_per_node: Optional[str | int] = None,
8064
nnodes: Optional[int] = None,
8165
node_rank: Optional[int] = None,
82-
rdzv_id: Optional[int] = None,
66+
rdzv_id: Optional[str | int] = None,
8367
rdzv_endpoint: Optional[str] = None,
68+
master_addr: Optional[str] = None,
69+
master_port: Optional[int] = None,
8470
**kwargs) -> Any:
8571
"""Execute SFT training.
8672
@@ -103,6 +89,8 @@ def train(self,
10389
node_rank: Rank of this node (0 to nnodes-1)
10490
rdzv_id: Unique job ID for rendezvous
10591
rdzv_endpoint: Master node endpoint for multi-node training
92+
master_addr: Master node address for distributed training
93+
master_port: Master node port for distributed training
10694
**kwargs: Additional parameters passed to the backend
10795
10896
Returns:
@@ -128,6 +116,8 @@ def train(self,
128116
'node_rank': node_rank,
129117
'rdzv_id': rdzv_id,
130118
'rdzv_endpoint': rdzv_endpoint,
119+
'master_addr': master_addr,
120+
'master_port': master_port,
131121
}
132122

133123
# Only add non-None parameters (let TrainingArgs handle defaults)
@@ -161,11 +151,13 @@ def get_optional_params(self) -> Dict[str, Type]:
161151
'warmup_steps': int,
162152
'accelerate_full_state_at_epoch': bool,
163153
'checkpoint_at_epoch': bool,
164-
'nproc_per_node': int,
154+
'nproc_per_node': str | int,
165155
'nnodes': int,
166156
'node_rank': int,
167-
'rdzv_id': int,
157+
'rdzv_id': str | int,
168158
'rdzv_endpoint': str,
159+
'master_addr': str,
160+
'master_port': int,
169161
}
170162

171163

@@ -190,11 +182,13 @@ def sft(model_path: str,
190182
accelerate_full_state_at_epoch: Optional[bool] = None,
191183
checkpoint_at_epoch: Optional[bool] = None,
192184
# Torchrun parameters for multi-node support
193-
nproc_per_node: Optional[int] = None,
185+
nproc_per_node: Optional[str | int] = None,
194186
nnodes: Optional[int] = None,
195187
node_rank: Optional[int] = None,
196-
rdzv_id: Optional[int] = None,
188+
rdzv_id: Optional[str | int] = None,
197189
rdzv_endpoint: Optional[str] = None,
190+
master_addr: Optional[str] = None,
191+
master_port: Optional[int] = None,
198192
**kwargs) -> Any:
199193
"""Convenience function to run SFT training.
200194
@@ -218,6 +212,9 @@ def sft(model_path: str,
218212
node_rank: Rank of this node (0 to nnodes-1) for distributed training
219213
rdzv_id: Unique job ID for rendezvous in distributed training
220214
rdzv_endpoint: Master node endpoint for multi-node training
215+
master_addr: Master node address for distributed training
216+
master_port: Master node port for distributed training
217+
221218
**kwargs: Additional parameters passed to the backend
222219
223220
Returns:
@@ -245,6 +242,8 @@ def sft(model_path: str,
245242
node_rank=node_rank,
246243
rdzv_id=rdzv_id,
247244
rdzv_endpoint=rdzv_endpoint,
245+
master_addr=master_addr,
246+
master_port=master_port,
248247
**kwargs
249248
)
250249

0 commit comments

Comments
 (0)