@@ -1,7 +1,5 @@
 import os
-import shutil
-from typing import Literal, get_origin, get_args, Union
-from itertools import chain
+from typing import get_origin, get_args, Union
 from dataclasses import fields
 
 import datasets
@@ -28,7 +26,7 @@ def train(
     model_path: str,
     data_path: str,
     unfreeze_rank_ratio: float,
-    batch_size: int,
+    effective_batch_size: int,
     max_tokens_per_gpu: int,
     max_seq_len: int,
     learning_rate: float,
@@ -52,7 +50,7 @@ def train(
     save_final_checkpoint: bool | None = None,
 
     # parameters for the training mode
-    epochs: int | None = None,
+    num_epochs: int | None = None,
 
     # whether to use the processed dataset
     use_processed_dataset: bool | None = None,
@@ -87,7 +85,7 @@ def train(
         unfreeze_rank_ratio (float):
             Controls the amount that each matrix is unfrozen during OSFT.
             Valid values are between 0.0 and 1.0.
-        batch_size (int): Batch size for training.
+        effective_batch_size (int): Effective batch size for training.
         max_tokens_per_gpu (int):
             The maximum number of tokens placed on a single GPU for training.
             When hitting OOMs, consider reducing this value.
@@ -109,7 +107,7 @@ def train(
         lr_scheduler_kwargs (dict[str, str]): Additional scheduler parameters.
         checkpoint_at_epoch (bool): Whether to checkpoint at each epoch.
         save_final_checkpoint (bool): Whether to save final checkpoint once training is complete.
-        epochs (int): Number of epochs to train for.
+        num_epochs (int): Number of epochs to train for.
         use_processed_dataset (bool):
             Whether to use the processed dataset. If False, the data is assumed to be in standard
             messages format with a `messages` and optional `unmask` field on each sample.
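For reference, a minimal sketch of one sample in the unprocessed messages format described above; only the `messages` and optional `unmask` fields come from the docstring, and the contents are illustrative:

# Illustrative sample in the standard messages format; the text is made up.
sample = {
    "messages": [
        {"role": "user", "content": "Explain OSFT in one sentence."},
        {"role": "assistant", "content": "OSFT fine-tunes a model while freezing part of each weight matrix."},
    ],
    "unmask": False,  # optional per-sample field mentioned in the docstring
}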
@@ -137,7 +135,7 @@ def train(
     required_params = {
         'model_path': model_path,
         'data_path': data_path,
-        'batch_size': batch_size,
+        'effective_batch_size': effective_batch_size,
         'max_tokens_per_gpu': max_tokens_per_gpu,
         'max_seq_len': max_seq_len,
         'learning_rate': learning_rate,
@@ -161,7 +159,7 @@ def train(
         'checkpoint_at_epoch': checkpoint_at_epoch,
         'save_final_checkpoint': save_final_checkpoint,
 
-        'epochs': epochs,
+        'num_epochs': num_epochs,
 
         'use_liger': use_liger,
         'seed': seed,
@@ -196,7 +194,7 @@ def get_required_params(self) -> dict[str, type]:
             'model_path': str,
             'data_path': str,
             'unfreeze_rank_ratio': float,
-            'batch_size': int,
+            'effective_batch_size': int,
             'max_tokens_per_gpu': int,
             'max_seq_len': int,
             'learning_rate': float,
@@ -214,7 +212,7 @@ def get_optional_params(self) -> dict[str, type]:
             'lr_scheduler_kwargs': dict[str, str],
             'checkpoint_at_epoch': bool,
             'save_final_checkpoint': bool,
-            'epochs': int,
+            'num_epochs': int,
             'use_processed_dataset': bool,
             'unmask_messages': bool,
             'nproc_per_node': int,
@@ -320,7 +318,8 @@ def execute_training(self, algorithm_params: dict[str, any]) -> any:
             'target_patterns': 'osft_target_patterns',
             'unfreeze_rank_ratio': 'osft_unfreeze_rank_ratio',
             'model_path': 'model_name_or_path',
-            'epochs': 'max_epochs',
+            'num_epochs': 'max_epochs',
+            'effective_batch_size': 'batch_size',
         }
 
         # Rename parameters before sending to backend
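The code that applies this mapping falls outside the hunk; a plausible sketch of the "rename parameters before sending to backend" step, assuming the dict above is named `rename_map`:

# Sketch only: map public parameter names to the backend's names.
# 'rename_map' is an assumed variable name; 'algorithm_params' appears in the method signature.
backend_params = {rename_map.get(k, k): v for k, v in algorithm_params.items()}
# e.g. {'num_epochs': 3, 'effective_batch_size': 64} -> {'max_epochs': 3, 'batch_size': 64}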
@@ -346,11 +345,16 @@ def execute_training(self, algorithm_params: dict[str, any]) -> any:
         # adjust arguments to align with the API definition
         training_args_pre = {k: v for k, v in algorithm_params.items() if k in training_args_fields and v is not None}
         training_args_pre['data_path'] = training_ready_data_path  # replaces raw data path with processed
-
+
         # mini trainer can support multiple modes, but we don't expose this feature by default
         # to prevent the current API from becoming overly complicated
-        training_args_pre['training_mode'] = TrainingMode(training_args_pre.get('training_mode', 'epoch'))
-        training_args_pre['osft'] = True
+        if not isinstance(train_mode := training_args_pre.get('training_mode', TrainingMode.EPOCH), TrainingMode):
+            train_mode = TrainingMode(train_mode)
+        training_args_pre['training_mode'] = train_mode
+
+        # user may want to control this API field for debug purposes, so we allow for it to be read
+        # but default it to True
+        training_args_pre['osft'] = training_args_pre.get('osft', True)
 
         torchrun_args_pre = {k: v for k, v in algorithm_params.items() if k in torchrun_args_fields and v is not None}
 
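To illustrate what the new coercion accepts, assuming `TrainingMode` is a standard `Enum` whose `EPOCH` member holds the value 'epoch':

# Sketch: both spellings resolve to the same enum member under the logic added above.
for supplied in ('epoch', TrainingMode.EPOCH):
    mode = supplied if isinstance(supplied, TrainingMode) else TrainingMode(supplied)
    assert mode is TrainingMode.EPOCH
# 'osft' may likewise be passed explicitly for debugging; when omitted it defaults to True.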
@@ -419,7 +423,7 @@ def osft(
     data_path: str,
     output_dir: str,
     unfreeze_rank_ratio: float,
-    batch_size: int,
+    effective_batch_size: int,
     max_tokens_per_gpu: int,
     max_seq_len: int,
     learning_rate: float,
@@ -435,7 +439,7 @@ def osft(
     lr_scheduler_kwargs: dict[str, str] | None = None,
     checkpoint_at_epoch: bool | None = None,
     save_final_checkpoint: bool | None = None,
-    epochs: int | None = None,
+    num_epochs: int | None = None,
     # Torchrun parameters for multi-node support
     nproc_per_node: int | None = None,
     nnodes: int | None = None,
@@ -452,7 +456,7 @@ def osft(
         data_path=data_path,
         output_dir=output_dir,
         unfreeze_rank_ratio=unfreeze_rank_ratio,
-        batch_size=batch_size,
+        effective_batch_size=effective_batch_size,
         max_tokens_per_gpu=max_tokens_per_gpu,
         max_seq_len=max_seq_len,
         learning_rate=learning_rate,
@@ -466,7 +470,7 @@ def osft(
         lr_scheduler_kwargs=lr_scheduler_kwargs,
         checkpoint_at_epoch=checkpoint_at_epoch,
         save_final_checkpoint=save_final_checkpoint,
-        epochs=epochs,
+        num_epochs=num_epochs,
         nproc_per_node=nproc_per_node,
         nnodes=nnodes,
         node_rank=node_rank,
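A usage sketch of the public entry point after the rename; the import path and all values are assumptions, not taken from this diff:

# Hypothetical call site showing the renamed keyword arguments.
from training_hub import osft  # assumed module path

osft(
    model_path="/models/base",
    data_path="/data/train.jsonl",
    output_dir="/outputs/osft-run",
    unfreeze_rank_ratio=0.3,
    effective_batch_size=64,   # formerly batch_size
    max_tokens_per_gpu=8192,
    max_seq_len=4096,
    learning_rate=2e-5,
    num_epochs=3,              # formerly epochs
    # remaining optional parameters omitted
)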