 import os
-from typing import get_origin, get_args, Union
+from typing import Literal, get_origin, get_args, Union
 from dataclasses import fields
 
 import datasets
 from training_hub.algorithms import Algorithm, Backend, AlgorithmRegistry
-from training_hub.utils import format_type_name
+from training_hub.utils import format_type_name, get_torchrun_params
 
 
 class OSFTAlgorithm(Algorithm):
@@ -58,11 +58,13 @@ def train(
         data_output_dir: str | None = None,
 
         # Torchrun parameters for multi-node support
-        nproc_per_node: int | None = None,
+        nproc_per_node: Literal['auto', 'gpu'] | int | None = None,
         nnodes: int | None = None,
         node_rank: int | None = None,
-        rdzv_id: int | None = None,
+        rdzv_id: str | int | None = None,
         rdzv_endpoint: str | None = None,
+        master_addr: str | None = None,
+        master_port: int | None = None,
         **kwargs,
     ) -> any:
         """
@@ -121,11 +123,14 @@ def train(
                 Directory where outputs from data processing will be saved such as intermediate
                 files. When not provided, it defaults to `_internal_data_processing` under the
                 `ckpt_output_dir`.
-            nproc_per_node (int): Number of processes (GPUs) per node for distributed training.
+            nproc_per_node (Literal['auto', 'gpu'] | int): Number of processes (GPUs) per node for distributed training.
             nnodes (int): Total number of nodes for distributed training.
             node_rank (int): Rank of this node (0 to nnodes-1) for distributed training.
-            rdzv_id (int): Unique job ID for rendezvous in distributed training.
+            rdzv_id (str | int): Unique job ID for rendezvous in distributed training.
             rdzv_endpoint (str): Master node endpoint for multi-node training.
+            master_addr (str): Master node address for distributed training (only used with
+                static rdzv_backend).
+            master_port (int): Master node port for distributed training.
             **kwargs: Additional parameters passed to the backend.
 
         Returns:
@@ -176,6 +181,8 @@ def train(
             'node_rank': node_rank,
             'rdzv_id': rdzv_id,
             'rdzv_endpoint': rdzv_endpoint,
+            'master_addr': master_addr,
+            'master_port': master_port,
         }
 
         # now do validation now that we've set everything up
@@ -222,11 +229,13 @@ def get_optional_params(self) -> dict[str, type]:
             'use_processed_dataset': bool,
             'unmask_messages': bool,
             'data_output_dir': str,
-            'nproc_per_node': int,
+            'nproc_per_node': Literal['auto', 'gpu'] | int,
             'nnodes': int,
             'node_rank': int,
-            'rdzv_id': int,
+            'rdzv_id': str | int,
             'rdzv_endpoint': str,
+            'master_addr': str,
+            'master_port': int,
         }
 
     def _validate_param_types(self, params: dict[str, any]):
@@ -333,6 +342,16 @@ def execute_training(self, algorithm_params: dict[str, any]) -> any:
         # Rename parameters before sending to backend
         algorithm_params = {renames.get(k, k): v for k, v in algorithm_params.items()}
 
+        # Separate parameters into their respective dataclass fields
+        torchrun_args_fields = {f.name for f in fields(TorchrunArgs)}
+        training_args_fields = {f.name for f in fields(TrainingArgs)}
+
+
+        # process this up here so we can exit early
+        torchrun_args_pre = {k: v for k, v in algorithm_params.items() if k in torchrun_args_fields and v is not None}
+        torchrun_args_pre = get_torchrun_params(torchrun_args_pre)
+        torch_args = TorchrunArgs(**torchrun_args_pre)
+
         # We separate this from `ckpt_output_dir` so that we can use `/dev/shm` for low-latency data
         # proceessing. But we do not want to make assumptions about the size of training data or the
         # amount of memory on the host. So by default we write to storage, but expose this as a separate
@@ -353,11 +372,6 @@ def execute_training(self, algorithm_params: dict[str, any]) -> any:
             unmask_messages=algorithm_params.get('unmask_messages', False),
         )
 
-
-        # Separate parameters into their respective dataclass fields
-        torchrun_args_fields = {f.name for f in fields(TorchrunArgs)}
-        training_args_fields = {f.name for f in fields(TrainingArgs)}
-
         # adjust arguments to align with the API definition
         training_args_pre = {k: v for k, v in algorithm_params.items() if k in training_args_fields and v is not None}
         training_args_pre['data_path'] = training_ready_data_path  # replaces raw data path with processed
@@ -372,14 +386,9 @@ def execute_training(self, algorithm_params: dict[str, any]) -> any:
         # but default it to True
         training_args_pre['osft'] = training_args_pre.get('osft', True)
 
-        torchrun_args_pre = {k: v for k, v in algorithm_params.items() if k in torchrun_args_fields and v is not None}
-        # TODO: update this default in mini-trainer
-        torchrun_args_pre['rdzv_endpoint'] = torchrun_args_pre.get('rdzv_endpoint', 'localhost:1738')
-
-
         # now we run training
         return run_training(
-            torch_args=TorchrunArgs(**torchrun_args_pre),
+            torch_args=torch_args,
             train_args=TrainingArgs(**training_args_pre),
         )
 
@@ -460,11 +469,13 @@ def osft(
     save_final_checkpoint: bool | None = None,
     num_epochs: int | None = None,
     # Torchrun parameters for multi-node support
-    nproc_per_node: int | None = None,
+    nproc_per_node: Literal['auto', 'gpu'] | int | None = None,
     nnodes: int | None = None,
     node_rank: int | None = None,
-    rdzv_id: int | None = None,
+    rdzv_id: str | int | None = None,
     rdzv_endpoint: str | None = None,
+    master_port: int | None = None,
+    master_addr: str | None = None,
     **kwargs
 ) -> any:
     from . import create_algorithm
@@ -496,5 +507,7 @@ def osft(
         node_rank=node_rank,
         rdzv_id=rdzv_id,
         rdzv_endpoint=rdzv_endpoint,
+        master_port=master_port,
+        master_addr=master_addr,
         **kwargs
     )
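
Below is a minimal usage sketch of the multi-node parameters this change exposes through the top-level `osft()` helper. The torchrun-related arguments (`nproc_per_node`, `nnodes`, `node_rank`, `rdzv_id`, `rdzv_endpoint`, `master_addr`, `master_port`) are taken from the signature in the diff; the `from training_hub import osft` import path and the `model_path`, `data_path`, and `ckpt_output_dir` arguments are assumptions for illustration, not part of this diff.

```python
from training_hub import osft  # assumed import path for the osft() helper

# Hypothetical two-node run exercising the parameters added or loosened above:
# nproc_per_node now accepts 'auto'/'gpu' as well as an int, rdzv_id may be a
# string, and master_addr/master_port are forwarded through to TorchrunArgs.
result = osft(
    model_path="/models/base-model",      # assumed parameter name
    data_path="/data/train.jsonl",        # assumed parameter name
    ckpt_output_dir="/checkpoints/osft",  # assumed parameter name
    num_epochs=1,
    nproc_per_node="auto",                # 'auto', 'gpu', or an explicit int
    nnodes=2,
    node_rank=0,                          # 0 on the head node, 1 on the worker
    rdzv_id="osft-job-1",                 # string job IDs are now accepted
    rdzv_endpoint="head-node:29500",
    master_addr="head-node",              # only used with a static rdzv_backend
    master_port=29500,
)
```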