 import os
-import shutil
-from typing import Literal, get_origin, get_args, Union
-from itertools import chain
+from typing import get_origin, get_args, Union
 from dataclasses import fields

 import datasets
@@ -28,11 +26,11 @@ def train(
         model_path: str,
         data_path: str,
         unfreeze_rank_ratio: float,
-        batch_size: int,
+        effective_batch_size: int,
         max_tokens_per_gpu: int,
         max_seq_len: int,
         learning_rate: float,
-        output_dir: str,
+        ckpt_output_dir: str,

         # patterns that we want to match against when selecting
         # modules for OSFT
@@ -52,11 +50,12 @@ def train(
         save_final_checkpoint: bool | None = None,

         # parameters for the training mode
-        epochs: int | None = None,
+        num_epochs: int | None = None,

         # whether to use the processed dataset
         use_processed_dataset: bool | None = None,
         unmask_messages: bool | None = None,
+        data_output_dir: str | None = None,

         # Torchrun parameters for multi-node support
         nproc_per_node: int | None = None,
@@ -87,16 +86,16 @@ def train(
         unfreeze_rank_ratio (float):
             Controls the amount that each matrix is unfrozen during OSFT.
             Valid values are between 0.0 and 1.0.
-        batch_size (int): Batch size for training.
+        effective_batch_size (int): Effective batch size for training.
         max_tokens_per_gpu (int):
             The maximum number of tokens placed on a single GPU for training.
             When hitting OOMs, consider reducing this value.
         max_seq_len (int):
             Sets the maximum sequence length (in tokens) of samples that will be used for training.
             Any sample exceeding this length will be dropped from the dataset.
         learning_rate (float): Learning rate for model update size.
-        output_dir (str):
-            Directory where outputs from training will be saved, including checkpoints, logs, and
+        ckpt_output_dir (str):
+            Directory where outputs from training will be saved, such as checkpoints, logs, and
             any necessary intermediate files.
         target_patterns (list[str]):
             List of patterns to match against when selecting modules for OSFT,
@@ -109,7 +108,7 @@ def train(
         lr_scheduler_kwargs (dict[str, str]): Additional scheduler parameters.
         checkpoint_at_epoch (bool): Whether to checkpoint at each epoch.
         save_final_checkpoint (bool): Whether to save final checkpoint once training is complete.
-        epochs (int): Number of epochs to train for.
+        num_epochs (int): Number of epochs to train for.
         use_processed_dataset (bool):
             Whether to use the processed dataset. If False, the data is assumed to be in standard
             messages format with a `messages` and optional `unmask` field on each sample.
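To make that layout concrete, here is a rough sketch of a single sample in the standard messages format the docstring describes. Only the `messages` and `unmask` field names come from the docstring; the role/content structure and the message contents are illustrative placeholders.

```python
# Hypothetical record in the standard messages format described above.
# Only the `messages` and `unmask` field names come from the docstring;
# the roles and contents are made-up placeholders.
sample = {
    "messages": [
        {"role": "user", "content": "Summarize this document."},
        {"role": "assistant", "content": "Here is a short summary ..."},
    ],
    "unmask": True,  # optional per-sample flag
}
```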
@@ -118,6 +117,10 @@ def train(
         unmask_messages (bool):
             Whether to unmask messages during data processing. This value is ignored
             when `use_processed_dataset` is True.
+        data_output_dir (str):
+            Directory where outputs from data processing will be saved, such as intermediate
+            files. When not provided, it defaults to `_internal_data_processing` under the
+            `ckpt_output_dir`.
         nproc_per_node (int): Number of processes (GPUs) per node for distributed training.
         nnodes (int): Total number of nodes for distributed training.
         node_rank (int): Rank of this node (0 to nnodes-1) for distributed training.
@@ -137,11 +140,11 @@ def train(
         required_params = {
             'model_path': model_path,
             'data_path': data_path,
-            'batch_size': batch_size,
+            'effective_batch_size': effective_batch_size,
             'max_tokens_per_gpu': max_tokens_per_gpu,
             'max_seq_len': max_seq_len,
             'learning_rate': learning_rate,
-            'output_dir': output_dir,
+            'ckpt_output_dir': ckpt_output_dir,
             'unfreeze_rank_ratio': unfreeze_rank_ratio,
         }

@@ -151,6 +154,7 @@ def train(
             # for data processing
             'use_processed_dataset': use_processed_dataset,
             'unmask_messages': unmask_messages,
+            'data_output_dir': data_output_dir,

             # scheduler params
             'lr_scheduler': lr_scheduler,
@@ -161,7 +165,7 @@ def train(
             'checkpoint_at_epoch': checkpoint_at_epoch,
             'save_final_checkpoint': save_final_checkpoint,

-            'epochs': epochs,
+            'num_epochs': num_epochs,

             'use_liger': use_liger,
             'seed': seed,
@@ -196,11 +200,11 @@ def get_required_params(self) -> dict[str, type]:
             'model_path': str,
             'data_path': str,
             'unfreeze_rank_ratio': float,
-            'batch_size': int,
+            'effective_batch_size': int,
             'max_tokens_per_gpu': int,
             'max_seq_len': int,
             'learning_rate': float,
-            'output_dir': str,
+            'ckpt_output_dir': str,
         }

     def get_optional_params(self) -> dict[str, type]:
@@ -214,9 +218,10 @@ def get_optional_params(self) -> dict[str, type]:
             'lr_scheduler_kwargs': dict[str, str],
             'checkpoint_at_epoch': bool,
             'save_final_checkpoint': bool,
-            'epochs': int,
+            'num_epochs': int,
             'use_processed_dataset': bool,
             'unmask_messages': bool,
+            'data_output_dir': str,
             'nproc_per_node': int,
             'nnodes': int,
             'node_rank': int,
@@ -320,18 +325,28 @@ def execute_training(self, algorithm_params: dict[str, any]) -> any:
             'target_patterns': 'osft_target_patterns',
             'unfreeze_rank_ratio': 'osft_unfreeze_rank_ratio',
             'model_path': 'model_name_or_path',
-            'epochs': 'max_epochs',
+            'num_epochs': 'max_epochs',
+            'effective_batch_size': 'batch_size',
+            'ckpt_output_dir': 'output_dir',
         }

         # Rename parameters before sending to backend
         algorithm_params = {renames.get(k, k): v for k, v in algorithm_params.items()}
+
+        # We separate this from `ckpt_output_dir` so that we can use `/dev/shm` for low-latency data
+        # processing. But we do not want to make assumptions about the size of the training data or the
+        # amount of memory on the host. So by default we write to storage, but expose this as a separate
+        # parameter for performance gains.
+        data_output_dir = algorithm_params.get('data_output_dir', None)
+        if data_output_dir is None:
+            data_output_dir = os.path.join(algorithm_params['ckpt_output_dir'], '_internal_data_processing')

         # since mini trainer itself does not process data, we delegate this to
         # a separate backend, and expect to receive the correct data path
         training_ready_data_path = self._process_data(
             data_path=algorithm_params['data_path'],  # should be there
             model_name_or_path=algorithm_params['model_name_or_path'],  # should be there
-            output_dir=algorithm_params['output_dir'],
+            output_dir=data_output_dir,
             max_seq_len=algorithm_params['max_seq_len'],
             num_cpu_procs=8,  # this is a safe default
             use_processed_dataset=algorithm_params.get('use_processed_dataset', False),
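To illustrate the comment in this hunk, here is a small sketch of the two places the processed data can end up. Only the `_internal_data_processing` fallback and the `/dev/shm` idea come from this diff; the concrete paths are placeholders.

```python
import os

# Illustrative paths only.
ckpt_output_dir = "/mnt/checkpoints/osft-run-1"

# Default: data-processing artifacts land next to the checkpoints on persistent storage.
default_data_dir = os.path.join(ckpt_output_dir, "_internal_data_processing")

# Opt-in: point data_output_dir at tmpfs when the dataset comfortably fits in host memory.
fast_data_dir = "/dev/shm/osft_data"
```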
@@ -346,11 +361,16 @@ def execute_training(self, algorithm_params: dict[str, any]) -> any:
         # adjust arguments to align with the API definition
         training_args_pre = {k: v for k, v in algorithm_params.items() if k in training_args_fields and v is not None}
         training_args_pre['data_path'] = training_ready_data_path  # replaces raw data path with processed
-
+
         # mini trainer can support multiple modes, but we don't expose this feature by default
         # to prevent the current API from becoming overly complicated
-        training_args_pre['training_mode'] = TrainingMode(training_args_pre.get('training_mode', 'epoch'))
-        training_args_pre['osft'] = True
+        if not isinstance(train_mode := training_args_pre.get('training_mode', TrainingMode.EPOCH), TrainingMode):
+            train_mode = TrainingMode(train_mode)
+        training_args_pre['training_mode'] = train_mode
+
+        # user may want to control this API field for debug purposes, so we allow it to be read
+        # but default it to True
+        training_args_pre['osft'] = training_args_pre.get('osft', True)

         torchrun_args_pre = {k: v for k, v in algorithm_params.items() if k in torchrun_args_fields and v is not None}

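A quick sketch of what the new normalization accepts, using a stand-in enum since the real `TrainingMode` import path is not visible in this diff; only the `EPOCH` member and its `'epoch'` string value are taken from the code above.

```python
from enum import Enum

# Stand-in for the backend's TrainingMode enum (the real import is not shown here);
# it only needs a string-valued EPOCH member for this illustration.
class TrainingMode(str, Enum):
    EPOCH = 'epoch'

# Both an enum member and the raw string normalize to the same value,
# mirroring the isinstance check added in this hunk.
for raw in (TrainingMode.EPOCH, 'epoch'):
    mode = raw if isinstance(raw, TrainingMode) else TrainingMode(raw)
    assert mode is TrainingMode.EPOCH
```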
@@ -384,28 +404,27 @@ def _process_data(
             return data_path

         # otherwise we need to process the data
-        data_output_path = os.path.join(output_dir, '_internal_data_processing')
-        os.makedirs(data_output_path, exist_ok=True)
+        os.makedirs(output_dir, exist_ok=True)

         # if we received unmask then we need to add that
         processing_data_path = data_path
         if unmask_messages:
             ds = datasets.load_dataset(data_path, split='train')
             ds = ds.map(lambda _: {"unmask": True})
-            processing_data_path = os.path.join(data_output_path, 'intermediate_data.jsonl')
+            processing_data_path = os.path.join(output_dir, 'intermediate_data.jsonl')
             ds.to_json(processing_data_path)

         # now we process the data
         process_messages_into_input_ids(
             data_path=processing_data_path,
-            data_output_path=data_output_path,
+            data_output_path=output_dir,
             model_path=model_name_or_path,
             max_seq_len=max_seq_len,
             num_cpu_procs=num_cpu_procs,
         )

         # above function will save to this file, so we pass this to the trainer
-        return os.path.join(data_output_path, 'data.jsonl')
+        return os.path.join(output_dir, 'data.jsonl')
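For orientation, a sketch of the files this helper is expected to leave under `output_dir`. The filenames come from the code above; the directory value is a placeholder.

```python
import os

output_dir = "/dev/shm/osft_data"  # placeholder for whatever was passed in above

# intermediate_data.jsonl only exists when unmask_messages was set;
# data.jsonl is the path handed back to the trainer.
expected_files = [
    os.path.join(output_dir, "intermediate_data.jsonl"),
    os.path.join(output_dir, "data.jsonl"),
]
```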
@@ -417,12 +436,13 @@ def _process_data(
 def osft(
     model_path: str,
     data_path: str,
-    output_dir: str,
     unfreeze_rank_ratio: float,
-    batch_size: int,
+    effective_batch_size: int,
     max_tokens_per_gpu: int,
     max_seq_len: int,
     learning_rate: float,
+    ckpt_output_dir: str,
+    data_output_dir: str | None = None,
     backend: str = "mini-trainer",
     # Optional parameters
     target_patterns: list[str] | None = None,
@@ -435,7 +455,7 @@ def osft(
     lr_scheduler_kwargs: dict[str, str] | None = None,
     checkpoint_at_epoch: bool | None = None,
     save_final_checkpoint: bool | None = None,
-    epochs: int | None = None,
+    num_epochs: int | None = None,
     # Torchrun parameters for multi-node support
     nproc_per_node: int | None = None,
     nnodes: int | None = None,
@@ -450,9 +470,10 @@ def osft(
     return algorithm.train(
         model_path=model_path,
         data_path=data_path,
-        output_dir=output_dir,
+        ckpt_output_dir=ckpt_output_dir,
+        data_output_dir=data_output_dir,
         unfreeze_rank_ratio=unfreeze_rank_ratio,
-        batch_size=batch_size,
+        effective_batch_size=effective_batch_size,
         max_tokens_per_gpu=max_tokens_per_gpu,
         max_seq_len=max_seq_len,
         learning_rate=learning_rate,
@@ -466,7 +487,7 @@ def osft(
         lr_scheduler_kwargs=lr_scheduler_kwargs,
         checkpoint_at_epoch=checkpoint_at_epoch,
         save_final_checkpoint=save_final_checkpoint,
-        epochs=epochs,
+        num_epochs=num_epochs,
         nproc_per_node=nproc_per_node,
         nnodes=nnodes,
         node_rank=node_rank,
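Finally, a minimal sketch of calling the public entry point with the renamed keyword arguments. The import path, file paths, and hyperparameter values are placeholders for illustration, not defaults taken from this change.

```python
# The import path is an assumption; only the osft() signature comes from this diff.
from training_hub import osft

osft(
    model_path="/models/base-model",              # placeholder
    data_path="/data/train.jsonl",                # placeholder
    ckpt_output_dir="/mnt/checkpoints/osft-run",  # formerly output_dir
    data_output_dir=None,                         # optional; defaults under ckpt_output_dir
    unfreeze_rank_ratio=0.25,
    effective_batch_size=128,                     # formerly batch_size
    max_tokens_per_gpu=8192,
    max_seq_len=4096,
    learning_rate=2e-5,
    num_epochs=3,                                 # formerly epochs
)
```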