
Commit 2c7f53b

allow OSFT to be configurable
1 parent 87bd02a commit 2c7f53b

4 files changed: +33 additions, -30 deletions
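At the call-site level, this commit renames two keyword arguments: `batch_size` becomes `effective_batch_size` and `epochs` becomes `num_epochs`. Below is a minimal before/after sketch of the user-facing call, assuming the `osft` entry point is importable from `training_hub`; the import path and file paths are illustrative, with values mirroring the README example further down.

from training_hub import osft  # import path assumed for illustration

# before this commit: osft(..., batch_size=8, epochs=2, ...)
# after this commit:
result = osft(
    model_path="/path/to/model",   # placeholder path
    data_path="/path/to/data.jsonl",
    output_dir="/path/to/outputs",
    unfreeze_rank_ratio=0.3,
    effective_batch_size=8,        # formerly batch_size
    max_tokens_per_gpu=2048,
    max_seq_len=2048,
    learning_rate=2e-5,
    num_epochs=2,                  # formerly epochs
)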

examples/README.md

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ result = osft(
     data_path="/path/to/data.jsonl",
     output_dir="/path/to/outputs",
     unfreeze_rank_ratio=0.3,
-    batch_size=8,
+    effective_batch_size=8,
     max_tokens_per_gpu=2048,
     max_seq_len=2048,
     learning_rate=2e-5

examples/docs/osft_usage.md

Lines changed: 9 additions & 9 deletions
@@ -71,7 +71,7 @@ result = osft(
     data_path="/path/to/your/training/data.jsonl",
     output_dir="/path/to/save/outputs",
     unfreeze_rank_ratio=0.3,
-    batch_size=8,
+    effective_batch_size=16,
     max_tokens_per_gpu=2048,
     max_seq_len=2048,
     learning_rate=2e-5
@@ -83,11 +83,11 @@ result = osft(
     data_path="/path/to/your/training/data.jsonl",
     output_dir="/path/to/save/outputs",
     unfreeze_rank_ratio=0.2,
-    batch_size=4,
+    effective_batch_size=16,
     max_tokens_per_gpu=4096,
     max_seq_len=4096,
     learning_rate=1e-5,
-    epochs=3,
+    num_epochs=3,
     warmup_steps=100,
     use_liger=True,
     seed=42
@@ -114,7 +114,7 @@ result = osft_algo.train(
     max_tokens_per_gpu=3072,
     max_seq_len=2048,
     learning_rate=1.5e-5,
-    epochs=2
+    num_epochs=2
 )
 
 # Check required parameters
@@ -149,7 +149,7 @@ OSFTAlgorithm = AlgorithmRegistry.get_algorithm('osft')
 - `data_path` (str): Path to the training data (processed or unprocessed)
 - `output_dir` (str): Directory where outputs from training will be saved
 - `unfreeze_rank_ratio` (float): Controls the amount that each matrix is unfrozen during OSFT (0.0-1.0)
-- `batch_size` (int): Batch size for training
+- `effective_batch_size` (int): Batch size for training
 - `max_tokens_per_gpu` (int): Maximum number of tokens placed on a single GPU
 - `max_seq_len` (int): Maximum sequence length (in tokens) for training samples
 - `learning_rate` (float): Learning rate for model update size
@@ -165,7 +165,7 @@ OSFTAlgorithm = AlgorithmRegistry.get_algorithm('osft')
 - `unmask_messages` (bool): Whether to unmask messages during data processing
 
 **Core Training Parameters:**
-- `epochs` (int): Number of epochs to train for
+- `num_epochs` (int): Number of epochs to train for
 - `seed` (int): Random seed for training
 - `use_liger` (bool): Whether to use Liger kernels for training
 
@@ -200,7 +200,7 @@ try:
     data_path="/valid/data/path",
     output_dir="/valid/output/path",
     unfreeze_rank_ratio=0.3,
-    batch_size=8,
+    effective_batch_size=8,
     max_tokens_per_gpu=2048,
     max_seq_len=2048,
     learning_rate=2e-5
@@ -232,7 +232,7 @@ result = osft(
     data_path="/path/to/data.jsonl",
     output_dir="/path/to/outputs",
     unfreeze_rank_ratio=0.3,
-    batch_size=4,
+    effective_batch_size=4,
     max_tokens_per_gpu=2048,
     max_seq_len=2048,
     learning_rate=2e-5,
@@ -250,7 +250,7 @@ result = osft(
     data_path="/path/to/data.jsonl",
     output_dir="/path/to/outputs",
     unfreeze_rank_ratio=0.25,
-    batch_size=2,
+    effective_batch_size=2,
     max_tokens_per_gpu=1024,
     max_seq_len=2048,
     learning_rate=1e-5,
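The lower-level `OSFTAlgorithm` path documented in osft_usage.md changes the same way. Here is a sketch of the renamed `train()` call, assuming `osft_algo` has already been constructed as shown earlier in that doc (via `AlgorithmRegistry.get_algorithm('osft')`); paths and sizes are placeholders.

# assumes `osft_algo` was built as in osft_usage.md; values below are placeholders
result = osft_algo.train(
    model_path="/path/to/model",
    data_path="/path/to/data.jsonl",
    output_dir="/path/to/outputs",
    unfreeze_rank_ratio=0.3,
    effective_batch_size=8,    # renamed from batch_size
    max_tokens_per_gpu=3072,
    max_seq_len=2048,
    learning_rate=1.5e-5,
    num_epochs=2,              # renamed from epochs
)

# the rename also shows up in the parameter listings
print(osft_algo.get_required_params())   # now lists 'effective_batch_size' instead of 'batch_size'
print(osft_algo.get_optional_params())   # now lists 'num_epochs' instead of 'epochs'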

src/training_hub/algorithms/osft.py

Lines changed: 23 additions & 19 deletions
@@ -1,7 +1,5 @@
 import os
-import shutil
-from typing import Literal, get_origin, get_args, Union
-from itertools import chain
+from typing import get_origin, get_args, Union
 from dataclasses import fields
 
 import datasets
@@ -28,7 +26,7 @@ def train(
     model_path: str,
     data_path: str,
     unfreeze_rank_ratio: float,
-    batch_size: int,
+    effective_batch_size: int,
     max_tokens_per_gpu: int,
     max_seq_len: int,
     learning_rate: float,
@@ -52,7 +50,7 @@ def train(
     save_final_checkpoint: bool | None = None,
 
     # parameters for the training mode
-    epochs: int | None = None,
+    num_epochs: int | None = None,
 
     # whether to use the processed dataset
     use_processed_dataset: bool | None = None,
@@ -87,7 +85,7 @@ def train(
         unfreeze_rank_ratio (float):
             Controls the amount that each matrix is unfrozen during OSFT.
             Valid values are between 0.0 and 1.0.
-        batch_size (int): Batch size for training.
+        effective_batch_size (int): Effective batch size for training.
         max_tokens_per_gpu (int):
            The maximum number of tokens placed on a single GPU for training.
            When hitting OOMs, consider reducing this value.
@@ -109,7 +107,7 @@ def train(
         lr_scheduler_kwargs (dict[str, str]): Additional scheduler parameters.
         checkpoint_at_epoch (bool): Whether to checkpoint at each epoch.
         save_final_checkpoint (bool): Whether to save final checkpoint once training is complete.
-        epochs (int): Number of epochs to train for.
+        num_epochs (int): Number of epochs to train for.
         use_processed_dataset (bool):
            Whether to use the processed dataset. If False, the data is assumed to be in standard
            messages format with a `messages` and optional `unmask` field on each sample.
@@ -137,7 +135,7 @@ def train(
     required_params = {
         'model_path': model_path,
         'data_path': data_path,
-        'batch_size': batch_size,
+        'effective_batch_size': effective_batch_size,
         'max_tokens_per_gpu': max_tokens_per_gpu,
         'max_seq_len': max_seq_len,
         'learning_rate': learning_rate,
@@ -161,7 +159,7 @@ def train(
         'checkpoint_at_epoch': checkpoint_at_epoch,
         'save_final_checkpoint': save_final_checkpoint,
 
-        'epochs': epochs,
+        'num_epochs': num_epochs,
 
         'use_liger': use_liger,
         'seed': seed,
@@ -196,7 +194,7 @@ def get_required_params(self) -> dict[str, type]:
            'model_path': str,
            'data_path': str,
            'unfreeze_rank_ratio': float,
-           'batch_size': int,
+           'effective_batch_size': int,
            'max_tokens_per_gpu': int,
            'max_seq_len': int,
            'learning_rate': float,
@@ -214,7 +212,7 @@ def get_optional_params(self) -> dict[str, type]:
            'lr_scheduler_kwargs': dict[str, str],
            'checkpoint_at_epoch': bool,
            'save_final_checkpoint': bool,
-           'epochs': int,
+           'num_epochs': int,
            'use_processed_dataset': bool,
            'unmask_messages': bool,
            'nproc_per_node': int,
@@ -320,7 +318,8 @@ def execute_training(self, algorithm_params: dict[str, any]) -> any:
            'target_patterns': 'osft_target_patterns',
            'unfreeze_rank_ratio': 'osft_unfreeze_rank_ratio',
            'model_path': 'model_name_or_path',
-           'epochs': 'max_epochs',
+           'num_epochs': 'max_epochs',
+           'effective_batch_size': 'batch_size',
        }
 
        # Rename parameters before sending to backend
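The new entries in the rename map above ('num_epochs' -> 'max_epochs', 'effective_batch_size' -> 'batch_size') keep the backend's parameter names unchanged while the public API uses the new ones. A standalone sketch of how such a map can be applied follows; `apply_renames` is an illustrative helper, not the function used in this file.

# Illustrative helper: translate public parameter names into the backend
# trainer's names before dispatch, leaving unmapped keys untouched.
rename_map = {
    'num_epochs': 'max_epochs',
    'effective_batch_size': 'batch_size',
}

def apply_renames(params: dict, renames: dict) -> dict:
    """Return a copy of params with user-facing keys swapped for backend keys."""
    return {renames.get(key, key): value for key, value in params.items()}

print(apply_renames({'num_epochs': 3, 'effective_batch_size': 16, 'seed': 42}, rename_map))
# -> {'max_epochs': 3, 'batch_size': 16, 'seed': 42}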
@@ -346,11 +345,16 @@ def execute_training(self, algorithm_params: dict[str, any]) -> any:
        # adjust arguments to align with the API definition
        training_args_pre = {k: v for k, v in algorithm_params.items() if k in training_args_fields and v is not None}
        training_args_pre['data_path'] = training_ready_data_path # replaces raw data path with processed
-
+
        # mini trainer can support multiple modes, but we don't expose this feature by default
        # to prevent the current API from becoming overly complicated
-       training_args_pre['training_mode'] = TrainingMode(training_args_pre.get('training_mode', 'epoch'))
-       training_args_pre['osft'] = True
+       if not isinstance(train_mode := training_args_pre.get('training_mode', TrainingMode.EPOCH), TrainingMode):
+           train_mode = TrainingMode(train_mode)
+       training_args_pre['training_mode'] = train_mode
+
+       # user may want to control this API field for debug purposes, so we allow for it to be read
+       # but default it to True
+       training_args_pre['osft'] = training_args_pre.get('osft', True)
 
        torchrun_args_pre = {k: v for k, v in algorithm_params.items() if k in torchrun_args_fields and v is not None}
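With this hunk, `training_mode` may arrive either as a string or as a `TrainingMode` member and is normalized before use, while the `osft` flag is now read from the incoming parameters (defaulting to True) instead of being hard-coded. Below is a self-contained sketch of the same normalization pattern; the `TrainingMode` enum here is a stand-in for illustration, not the backend's actual class.

from enum import Enum

class TrainingMode(Enum):
    # stand-in; the real enum lives in the training backend and may have more members
    EPOCH = 'epoch'

def normalize_mode(value):
    """Accept a TrainingMode member or its string value and return a TrainingMode."""
    if not isinstance(mode := value, TrainingMode):
        mode = TrainingMode(mode)   # raises ValueError for unknown strings
    return mode

print(normalize_mode('epoch'))             # TrainingMode.EPOCH
print(normalize_mode(TrainingMode.EPOCH))  # TrainingMode.EPOCH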

@@ -419,7 +423,7 @@ def osft(
     data_path: str,
     output_dir: str,
     unfreeze_rank_ratio: float,
-    batch_size: int,
+    effective_batch_size: int,
     max_tokens_per_gpu: int,
     max_seq_len: int,
     learning_rate: float,
@@ -435,7 +439,7 @@ def osft(
     lr_scheduler_kwargs: dict[str, str] | None = None,
     checkpoint_at_epoch: bool | None = None,
     save_final_checkpoint: bool | None = None,
-    epochs: int | None = None,
+    num_epochs: int | None = None,
     # Torchrun parameters for multi-node support
     nproc_per_node: int | None = None,
     nnodes: int | None = None,
@@ -452,7 +456,7 @@ def osft(
        data_path=data_path,
        output_dir=output_dir,
        unfreeze_rank_ratio=unfreeze_rank_ratio,
-       batch_size=batch_size,
+       effective_batch_size=effective_batch_size,
        max_tokens_per_gpu=max_tokens_per_gpu,
        max_seq_len=max_seq_len,
        learning_rate=learning_rate,
@@ -466,7 +470,7 @@ def osft(
        lr_scheduler_kwargs=lr_scheduler_kwargs,
        checkpoint_at_epoch=checkpoint_at_epoch,
        save_final_checkpoint=save_final_checkpoint,
-       epochs=epochs,
+       num_epochs=num_epochs,
        nproc_per_node=nproc_per_node,
        nnodes=nnodes,
        node_rank=node_rank,
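Because `osft()` forwards its keyword arguments straight to `train()`, multi-node runs pick up the renamed parameters as well. Here is a sketch combining them with the torchrun parameters visible in the signature above (`nproc_per_node`, `nnodes`, `node_rank`); the import path, file paths, and sizes are assumptions for illustration.

from training_hub import osft  # import path assumed

result = osft(
    model_path="/path/to/model",
    data_path="/path/to/data.jsonl",
    output_dir="/path/to/outputs",
    unfreeze_rank_ratio=0.3,
    effective_batch_size=8,   # renamed from batch_size
    max_tokens_per_gpu=2048,
    max_seq_len=2048,
    learning_rate=2e-5,
    num_epochs=2,             # renamed from epochs
    nproc_per_node=8,         # GPUs per node
    nnodes=2,                 # total number of nodes
    node_rank=0,              # rank of this node
)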

src/training_hub/utils.py

Lines changed: 0 additions & 1 deletion
@@ -1,5 +1,4 @@
 from typing import get_origin, get_args
-import sys
 
 def format_type_name(tp):
     # Handle None
