Skip to content

Commit a3dbb85

Browse files
committed
allow OSFT to be configurable
1 parent 87bd02a commit a3dbb85

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

src/training_hub/algorithms/osft.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -346,11 +346,16 @@ def execute_training(self, algorithm_params: dict[str, any]) -> any:
346346
# adjust arguments to align with the API definition
347347
training_args_pre = {k: v for k, v in algorithm_params.items() if k in training_args_fields and v is not None}
348348
training_args_pre['data_path'] = training_ready_data_path # replaces raw data path with processed
349-
349+
350350
# mini trainer can support multiple modes, but we don't expose this feature by default
351351
# to prevent the current API from becoming overly complicated
352-
training_args_pre['training_mode'] = TrainingMode(training_args_pre.get('training_mode', 'epoch'))
353-
training_args_pre['osft'] = True
352+
if not isinstance(train_mode := training_args_pre.get('training_mode', TrainingMode.EPOCH), TrainingMode):
353+
train_mode = TrainingMode(train_mode)
354+
training_args_pre['training_mode'] = train_mode
355+
356+
# user may want to control this API field for debug purposes, so we allow for it to be read
357+
# but default it to True
358+
training_args_pre['osft'] = training_args_pre.get('osft', True)
354359

355360
torchrun_args_pre = {k: v for k, v in algorithm_params.items() if k in torchrun_args_fields and v is not None}
356361

0 commit comments

Comments
 (0)