 import os
-from typing import get_origin, get_args, Union
+from typing import Literal, get_origin, get_args, Union
 from dataclasses import fields
 
 import datasets
@@ -58,10 +58,10 @@ def train(
     data_output_dir: str | None = None,
 
     # Torchrun parameters for multi-node support
-    nproc_per_node: str | int | None = None,
+    nproc_per_node: Literal['auto', 'gpu'] | int | None = None,
     nnodes: int | None = None,
     node_rank: int | None = None,
-    rdzv_id: str | None = None,
+    rdzv_id: str | int | None = None,
     rdzv_endpoint: str | None = None,
     master_addr: str | None = None,
     master_port: str | None = None,
@@ -123,14 +123,14 @@ def train(
             Directory where outputs from data processing will be saved such as intermediate
             files. When not provided, it defaults to `_internal_data_processing` under the
             `ckpt_output_dir`.
-        nproc_per_node (str): Number of processes (GPUs) per node for distributed training.
+        nproc_per_node (Literal['auto', 'gpu'] | int): Number of processes (GPUs) per node for distributed training.
         nnodes (int): Total number of nodes for distributed training.
         node_rank (int): Rank of this node (0 to nnodes-1) for distributed training.
-        rdzv_id (str): Unique job ID for rendezvous in distributed training.
+        rdzv_id (str | int): Unique job ID for rendezvous in distributed training.
         rdzv_endpoint (str): Master node endpoint for multi-node training.
         master_addr (str): Master node address for distributed training (only used with
             static rdzv_backend).
-        master_port (str): Master node port for distributed training.
+        master_port (int): Master node port for distributed training.
         **kwargs: Additional parameters passed to the backend.
 
     Returns:
@@ -229,10 +229,10 @@ def get_optional_params(self) -> dict[str, type]:
             'use_processed_dataset': bool,
             'unmask_messages': bool,
             'data_output_dir': str,
-            'nproc_per_node': str | int,
+            'nproc_per_node': Literal['auto', 'gpu'] | int,
             'nnodes': int,
             'node_rank': int,
-            'rdzv_id': str,
+            'rdzv_id': str | int,
             'rdzv_endpoint': str,
             'master_addr': str,
             'master_port': int,
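The file's existing imports of get_origin and get_args (now joined by Literal) suggest these optional-parameter annotations are inspected at runtime. As a rough illustration only, not the repository's actual validation code, a union such as Literal['auto', 'gpu'] | int can be checked like this:

import types
from typing import Literal, Union, get_args, get_origin

def value_matches(value, annotation) -> bool:
    # Hypothetical helper: does `value` satisfy a union / Literal annotation?
    origin = get_origin(annotation)
    if origin in (Union, types.UnionType):
        # X | Y unions expose their members via get_args()
        return any(value_matches(value, arg) for arg in get_args(annotation))
    if origin is Literal:
        # Literal['auto', 'gpu'] accepts only the listed values
        return value in get_args(annotation)
    return isinstance(value, annotation)

# value_matches('gpu', Literal['auto', 'gpu'] | int) -> True
# value_matches(8, Literal['auto', 'gpu'] | int)     -> True
# value_matches('cpu', Literal['auto', 'gpu'] | int) -> False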
@@ -342,6 +342,16 @@ def execute_training(self, algorithm_params: dict[str, any]) -> any:
         # Rename parameters before sending to backend
         algorithm_params = {renames.get(k, k): v for k, v in algorithm_params.items()}
 
+        # Separate parameters into their respective dataclass fields
+        torchrun_args_fields = {f.name for f in fields(TorchrunArgs)}
+        training_args_fields = {f.name for f in fields(TrainingArgs)}
+
+
+        # process this up here so we can exit early
+        torchrun_args_pre = {k: v for k, v in algorithm_params.items() if k in torchrun_args_fields and v is not None}
+        torchrun_args_pre = get_torchrun_params(torchrun_args_pre)
+        torch_args = TorchrunArgs(**torchrun_args_pre)
+
         # We separate this from `ckpt_output_dir` so that we can use `/dev/shm` for low-latency data
         # processing. But we do not want to make assumptions about the size of training data or the
         # amount of memory on the host. So by default we write to storage, but expose this as a separate
@@ -362,11 +372,6 @@ def execute_training(self, algorithm_params: dict[str, any]) -> any:
             unmask_messages=algorithm_params.get('unmask_messages', False),
         )
 
-
-        # Separate parameters into their respective dataclass fields
-        torchrun_args_fields = {f.name for f in fields(TorchrunArgs)}
-        training_args_fields = {f.name for f in fields(TrainingArgs)}
-
         # adjust arguments to align with the API definition
         training_args_pre = {k: v for k, v in algorithm_params.items() if k in training_args_fields and v is not None}
         training_args_pre['data_path'] = training_ready_data_path  # replaces raw data path with processed
@@ -381,13 +386,9 @@ def execute_training(self, algorithm_params: dict[str, any]) -> any:
         # but default it to True
         training_args_pre['osft'] = training_args_pre.get('osft', True)
 
-        torchrun_args_pre = {k: v for k, v in algorithm_params.items() if k in torchrun_args_fields and v is not None}
-        torchrun_args_pre = get_torchrun_params(torchrun_params=torchrun_args_pre)
-
-
         # now we run training
         return run_training(
-            torch_args=TorchrunArgs(**torchrun_args_pre),
+            torch_args=torch_args,
             train_args=TrainingArgs(**training_args_pre),
         )
 
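The relocated block above builds the torchrun arguments before any data processing runs, so an invalid multi-node configuration fails fast. The partitioning relies on dataclasses.fields to decide which keys belong to which dataclass; a minimal, self-contained sketch of that pattern, using stand-in dataclasses rather than the backend's real TorchrunArgs and TrainingArgs:

from dataclasses import dataclass, fields

@dataclass
class TorchrunArgs:          # stand-in for illustration only
    nnodes: int = 1
    nproc_per_node: int = 1

@dataclass
class TrainingArgs:          # stand-in for illustration only
    data_path: str = ''
    num_epochs: int = 1

def split_params(params: dict) -> tuple[dict, dict]:
    # Partition a flat kwargs dict by the fields each dataclass declares,
    # dropping None values as execute_training does.
    torch_names = {f.name for f in fields(TorchrunArgs)}
    train_names = {f.name for f in fields(TrainingArgs)}
    torch_kwargs = {k: v for k, v in params.items() if k in torch_names and v is not None}
    train_kwargs = {k: v for k, v in params.items() if k in train_names and v is not None}
    return torch_kwargs, train_kwargs

torch_kwargs, train_kwargs = split_params(
    {'nnodes': 2, 'nproc_per_node': 8, 'data_path': '/tmp/train.jsonl', 'num_epochs': None}
)
# -> ({'nnodes': 2, 'nproc_per_node': 8}, {'data_path': '/tmp/train.jsonl'})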
@@ -468,11 +469,13 @@ def osft(
     save_final_checkpoint: bool | None = None,
     num_epochs: int | None = None,
     # Torchrun parameters for multi-node support
-    nproc_per_node: str | int | None = None,
+    nproc_per_node: Literal['auto', 'gpu'] | int | None = None,
     nnodes: int | None = None,
     node_rank: int | None = None,
-    rdzv_id: str | None = None,
+    rdzv_id: str | int | None = None,
     rdzv_endpoint: str | None = None,
+    master_port: str | None = None,
+    master_addr: str | None = None,
     **kwargs
 ) -> any:
     from . import create_algorithm
@@ -504,5 +507,7 @@ def osft(
         node_rank=node_rank,
         rdzv_id=rdzv_id,
         rdzv_endpoint=rdzv_endpoint,
+        master_port=master_port,
+        master_addr=master_addr,
         **kwargs
     )
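With these changes, a multi-node run can pass the rendezvous and master-node settings straight through the osft convenience wrapper. A hedged usage sketch follows; the import path and the model/data/output arguments are assumptions not shown in this diff, and only the torchrun-related keywords come from the signature above:

from training_hub import osft   # assumed import path

osft(
    model_path='/models/base-model',         # assumed required arguments,
    data_path='/data/train.jsonl',           # not visible in this diff
    ckpt_output_dir='/checkpoints/osft-run',
    num_epochs=1,
    # torchrun / rendezvous settings from the updated signature
    nproc_per_node='gpu',            # Literal['auto', 'gpu'] | int
    nnodes=2,
    node_rank=0,                     # 0..nnodes-1, differs per node
    rdzv_id=42,                      # str | int after this change
    rdzv_endpoint='10.0.0.1:29500',
    master_addr='10.0.0.1',          # now forwarded to the backend
    master_port='29500',             # typed str | None in the signature
)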