Commit b27f917

allow OSFT to be configurable
1 parent 3959cbb commit b27f917

File tree (4 files changed: +71 −51 lines changed)

- examples/README.md
- examples/docs/osft_usage.md
- src/training_hub/algorithms/osft.py
- src/training_hub/utils.py

examples/README.md

Lines changed: 2 additions & 2 deletions
```diff
@@ -55,9 +55,9 @@ from training_hub import osft
 result = osft(
     model_path="/path/to/model",
     data_path="/path/to/data.jsonl",
-    output_dir="/path/to/outputs",
+    ckpt_output_dir="/path/to/outputs",
     unfreeze_rank_ratio=0.3,
-    batch_size=8,
+    effective_batch_size=8,
     max_tokens_per_gpu=2048,
     max_seq_len=2048,
     learning_rate=2e-5
```
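Taken together, the README change amounts to two keyword renames. A minimal before/after sketch for callers updating their scripts (paths and values are the README's placeholders, not library defaults):

```python
from training_hub import osft

# After this commit: output_dir -> ckpt_output_dir, batch_size -> effective_batch_size.
result = osft(
    model_path="/path/to/model",
    data_path="/path/to/data.jsonl",
    ckpt_output_dir="/path/to/outputs",   # previously: output_dir=...
    unfreeze_rank_ratio=0.3,
    effective_batch_size=8,               # previously: batch_size=8
    max_tokens_per_gpu=2048,
    max_seq_len=2048,
    learning_rate=2e-5,
)
```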

examples/docs/osft_usage.md

Lines changed: 16 additions & 16 deletions
```diff
@@ -69,9 +69,9 @@ from training_hub import osft
 result = osft(
     model_path="/path/to/your/model",
     data_path="/path/to/your/training/data.jsonl",
-    output_dir="/path/to/save/outputs",
+    ckpt_output_dir="/path/to/save/outputs",
     unfreeze_rank_ratio=0.3,
-    batch_size=8,
+    effective_batch_size=16,
     max_tokens_per_gpu=2048,
     max_seq_len=2048,
     learning_rate=2e-5
@@ -81,13 +81,13 @@ result = osft(
 result = osft(
     model_path="/path/to/your/model",
     data_path="/path/to/your/training/data.jsonl",
-    output_dir="/path/to/save/outputs",
+    ckpt_output_dir="/path/to/save/outputs",
     unfreeze_rank_ratio=0.2,
-    batch_size=4,
+    effective_batch_size=16,
     max_tokens_per_gpu=4096,
     max_seq_len=4096,
     learning_rate=1e-5,
-    epochs=3,
+    num_epochs=3,
     warmup_steps=100,
     use_liger=True,
     seed=42
@@ -108,13 +108,13 @@ osft_algo = create_algorithm('osft', 'mini-trainer')
 result = osft_algo.train(
     model_path="/path/to/your/model",
     data_path="/path/to/your/training/data.jsonl",
-    output_dir="/path/to/save/outputs",
+    ckpt_output_dir="/path/to/save/outputs",
     unfreeze_rank_ratio=0.25,
     batch_size=6,
     max_tokens_per_gpu=3072,
     max_seq_len=2048,
     learning_rate=1.5e-5,
-    epochs=2
+    num_epochs=2
 )
 
 # Check required parameters
@@ -147,9 +147,9 @@ OSFTAlgorithm = AlgorithmRegistry.get_algorithm('osft')
 
 - `model_path` (str): Local path or HuggingFace model ID to be used for fine-tuning
 - `data_path` (str): Path to the training data (processed or unprocessed)
-- `output_dir` (str): Directory where outputs from training will be saved
+- `ckpt_output_dir` (str): Directory where outputs from training will be saved
 - `unfreeze_rank_ratio` (float): Controls the amount that each matrix is unfrozen during OSFT (0.0-1.0)
-- `batch_size` (int): Batch size for training
+- `effective_batch_size` (int): Batch size for training
 - `max_tokens_per_gpu` (int): Maximum number of tokens placed on a single GPU
 - `max_seq_len` (int): Maximum sequence length (in tokens) for training samples
 - `learning_rate` (float): Learning rate for model update size
```
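A note on the `effective_batch_size` rename (this reading is inferred, not stated in the diff): the mapping to the backend's `batch_size` later in this commit suggests the value is the global batch size seen by the optimizer per step, which in the usual distributed-training accounting factors as below.

```python
# Hypothetical accounting for an "effective" batch size; the exact split between
# micro-batch, accumulation, and world size is chosen by the backend, not shown here.
per_gpu_micro_batch = 2   # samples per GPU per forward pass (assumed)
grad_accum_steps = 2      # gradient-accumulation steps (assumed)
num_gpus = 4              # world size (assumed)

effective_batch_size = per_gpu_micro_batch * grad_accum_steps * num_gpus
print(effective_batch_size)  # 16, matching the example configs above
```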
```diff
@@ -165,7 +165,7 @@ OSFTAlgorithm = AlgorithmRegistry.get_algorithm('osft')
 - `unmask_messages` (bool): Whether to unmask messages during data processing
 
 **Core Training Parameters:**
-- `epochs` (int): Number of epochs to train for
+- `num_epochs` (int): Number of epochs to train for
 - `seed` (int): Random seed for training
 - `use_liger` (bool): Whether to use Liger kernels for training
 
@@ -198,9 +198,9 @@ try:
     result = osft(
         model_path="/valid/model/path",
         data_path="/valid/data/path",
-        output_dir="/valid/output/path",
+        ckpt_output_dir="/valid/output/path",
         unfreeze_rank_ratio=0.3,
-        batch_size=8,
+        effective_batch_size=8,
         max_tokens_per_gpu=2048,
         max_seq_len=2048,
         learning_rate=2e-5
@@ -230,9 +230,9 @@ from training_hub import osft
 result = osft(
     model_path="/path/to/model",
     data_path="/path/to/data.jsonl",
-    output_dir="/path/to/outputs",
+    ckpt_output_dir="/path/to/outputs",
     unfreeze_rank_ratio=0.3,
-    batch_size=4,
+    effective_batch_size=4,
     max_tokens_per_gpu=2048,
     max_seq_len=2048,
     learning_rate=2e-5,
@@ -248,9 +248,9 @@ result = osft(
 result = osft(
     model_path="/path/to/model",
     data_path="/path/to/data.jsonl",
-    output_dir="/path/to/outputs",
+    ckpt_output_dir="/path/to/outputs",
     unfreeze_rank_ratio=0.25,
-    batch_size=2,
+    effective_batch_size=2,
     max_tokens_per_gpu=1024,
     max_seq_len=2048,
     learning_rate=1e-5,
```

src/training_hub/algorithms/osft.py

Lines changed: 53 additions & 32 deletions
```diff
@@ -1,7 +1,5 @@
 import os
-import shutil
-from typing import Literal, get_origin, get_args, Union
-from itertools import chain
+from typing import get_origin, get_args, Union
 from dataclasses import fields
 
 import datasets
@@ -28,11 +26,11 @@ def train(
     model_path: str,
     data_path: str,
     unfreeze_rank_ratio: float,
-    batch_size: int,
+    effective_batch_size: int,
     max_tokens_per_gpu: int,
     max_seq_len: int,
     learning_rate: float,
-    output_dir: str,
+    ckpt_output_dir: str,
 
     # patterns that we want to match against when selecting
     # modules for OSFT
@@ -52,11 +50,12 @@ def train(
     save_final_checkpoint: bool | None = None,
 
     # parameters for the training mode
-    epochs: int | None = None,
+    num_epochs: int | None = None,
 
     # whether to use the processed dataset
     use_processed_dataset: bool | None = None,
     unmask_messages: bool | None = None,
+    data_output_dir: str | None = None,
 
     # Torchrun parameters for multi-node support
     nproc_per_node: int | None = None,
```
```diff
@@ -87,16 +86,16 @@ def train(
         unfreeze_rank_ratio (float):
             Controls the amount that each matrix is unfrozen during OSFT.
             Valid values are between 0.0 and 1.0.
-        batch_size (int): Batch size for training.
+        effective_batch_size (int): Effective batch size for training.
         max_tokens_per_gpu (int):
             The maximum number of tokens placed on a single GPU for training.
             When hitting OOMs, consider reducing this value.
         max_seq_len (int):
             Sets the maximum sequence length (in tokens) of samples that will be used for training.
             Any sample exceeding this length will be dropped from the dataset.
         learning_rate (float): Learning rate for model update size.
-        output_dir (str):
-            Directory where outputs from training will be saved, including checkpoints, logs, and
+        ckpt_output_dir (str):
+            Directory where outputs from training will be saved, such as checkpoints and logs.
             any necessary intermediate files.
         target_patterns (list[str]):
             List of patterns to match against when selecting modules for OSFT,
@@ -109,7 +108,7 @@ def train(
         lr_scheduler_kwargs (dict[str, str]): Additional scheduler parameters.
         checkpoint_at_epoch (bool): Whether to checkpoint at each epoch.
         save_final_checkpoint (bool): Whether to save final checkpoint once training is complete.
-        epochs (int): Number of epochs to train for.
+        num_epochs (int): Number of epochs to train for.
         use_processed_dataset (bool):
             Whether to use the processed dataset. If False, the data is assumed to be in standard
             messages format with a `messages` and optional `unmask` field on each sample.
@@ -118,6 +117,10 @@ def train(
         unmask_messages (bool):
             Whether to unmask messages during data processing. This value is ignored
             when `use_processed_dataset` is True.
+        data_output_dir (str):
+            Directory where outputs from data processing will be saved, such as intermediate
+            files. When not provided, it defaults to `_internal_data_processing` under the
+            `ckpt_output_dir`.
         nproc_per_node (int): Number of processes (GPUs) per node for distributed training.
         nnodes (int): Total number of nodes for distributed training.
         node_rank (int): Rank of this node (0 to nnodes-1) for distributed training.
```
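To make the documented default concrete, here is a small sketch of how `data_output_dir` resolves when the caller omits it (mirroring the fallback added to `execute_training` below; the checkpoint path is a placeholder):

```python
import os

ckpt_output_dir = "/path/to/outputs"  # placeholder
data_output_dir = None                # caller did not provide one

if data_output_dir is None:
    data_output_dir = os.path.join(ckpt_output_dir, "_internal_data_processing")

print(data_output_dir)  # /path/to/outputs/_internal_data_processing
```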
```diff
@@ -137,11 +140,11 @@ def train(
     required_params = {
         'model_path': model_path,
         'data_path': data_path,
-        'batch_size': batch_size,
+        'effective_batch_size': effective_batch_size,
         'max_tokens_per_gpu': max_tokens_per_gpu,
         'max_seq_len': max_seq_len,
         'learning_rate': learning_rate,
-        'output_dir': output_dir,
+        'ckpt_output_dir': ckpt_output_dir,
         'unfreeze_rank_ratio': unfreeze_rank_ratio,
     }
 
@@ -151,6 +154,7 @@ def train(
         # for data processing
         'use_processed_dataset': use_processed_dataset,
         'unmask_messages': unmask_messages,
+        'data_output_dir': data_output_dir,
 
         # scheduler params
         'lr_scheduler': lr_scheduler,
@@ -161,7 +165,7 @@ def train(
         'checkpoint_at_epoch': checkpoint_at_epoch,
         'save_final_checkpoint': save_final_checkpoint,
 
-        'epochs': epochs,
+        'num_epochs': num_epochs,
 
         'use_liger': use_liger,
         'seed': seed,
@@ -196,11 +200,11 @@ def get_required_params(self) -> dict[str, type]:
             'model_path': str,
             'data_path': str,
             'unfreeze_rank_ratio': float,
-            'batch_size': int,
+            'effective_batch_size': int,
             'max_tokens_per_gpu': int,
             'max_seq_len': int,
             'learning_rate': float,
-            'output_dir': str,
+            'ckpt_output_dir': str,
         }
 
     def get_optional_params(self) -> dict[str, type]:
@@ -214,9 +218,10 @@ def get_optional_params(self) -> dict[str, type]:
             'lr_scheduler_kwargs': dict[str, str],
             'checkpoint_at_epoch': bool,
             'save_final_checkpoint': bool,
-            'epochs': int,
+            'num_epochs': int,
             'use_processed_dataset': bool,
             'unmask_messages': bool,
+            'data_output_dir': str,
             'nproc_per_node': int,
             'nnodes': int,
             'node_rank': int,
```
```diff
@@ -320,18 +325,28 @@ def execute_training(self, algorithm_params: dict[str, any]) -> any:
             'target_patterns': 'osft_target_patterns',
             'unfreeze_rank_ratio': 'osft_unfreeze_rank_ratio',
             'model_path': 'model_name_or_path',
-            'epochs': 'max_epochs',
+            'num_epochs': 'max_epochs',
+            'effective_batch_size': 'batch_size',
+            'ckpt_output_dir': 'output_dir',
         }
 
         # Rename parameters before sending to backend
         algorithm_params = {renames.get(k, k): v for k, v in algorithm_params.items()}
+
+        # We separate this from `ckpt_output_dir` so that we can use `/dev/shm` for low-latency data
+        # processing. But we do not want to make assumptions about the size of training data or the
+        # amount of memory on the host. So by default we write to storage, but expose this as a
+        # separate parameter for performance gains.
+        data_output_dir = algorithm_params.get('data_output_dir', None)
+        if data_output_dir is None:
+            # note: 'ckpt_output_dir' was renamed to 'output_dir' above
+            data_output_dir = os.path.join(algorithm_params['output_dir'], '_internal_data_processing')
 
         # since mini trainer itself does not process data, we delegate this to
         # a separate backend, and expect to receive the correct data path
         training_ready_data_path = self._process_data(
             data_path=algorithm_params['data_path'], # should be there
             model_name_or_path=algorithm_params['model_name_or_path'], # should be there
-            output_dir=algorithm_params['output_dir'],
+            output_dir=data_output_dir,
             max_seq_len=algorithm_params['max_seq_len'],
             num_cpu_procs=8, # this is a safe default
             use_processed_dataset=algorithm_params.get('use_processed_dataset', False),
```
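The comment above motivates the split: data processing can now be pointed at fast storage independently of where checkpoints go. A hedged sketch of that use (the `/dev/shm` idea comes from the comment itself; the processed data must fit in host memory, and all paths are placeholders):

```python
from training_hub import osft

result = osft(
    model_path="/path/to/model",
    data_path="/path/to/data.jsonl",
    ckpt_output_dir="/path/to/outputs",    # checkpoints and logs on durable storage
    data_output_dir="/dev/shm/osft_data",  # intermediate data files on tmpfs (assumed to fit in RAM)
    unfreeze_rank_ratio=0.3,
    effective_batch_size=8,
    max_tokens_per_gpu=2048,
    max_seq_len=2048,
    learning_rate=2e-5,
)
```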
```diff
@@ -346,11 +361,16 @@ def execute_training(self, algorithm_params: dict[str, any]) -> any:
         # adjust arguments to align with the API definition
         training_args_pre = {k: v for k, v in algorithm_params.items() if k in training_args_fields and v is not None}
         training_args_pre['data_path'] = training_ready_data_path # replaces raw data path with processed
-
+
         # mini trainer can support multiple modes, but we don't expose this feature by default
         # to prevent the current API from becoming overly complicated
-        training_args_pre['training_mode'] = TrainingMode(training_args_pre.get('training_mode', 'epoch'))
-        training_args_pre['osft'] = True
+        if not isinstance(train_mode := training_args_pre.get('training_mode', TrainingMode.EPOCH), TrainingMode):
+            train_mode = TrainingMode(train_mode)
+        training_args_pre['training_mode'] = train_mode
+
+        # user may want to control this API field for debug purposes, so we allow for it to be
+        # read but default it to True
+        training_args_pre['osft'] = training_args_pre.get('osft', True)
 
         torchrun_args_pre = {k: v for k, v in algorithm_params.items() if k in torchrun_args_fields and v is not None}
```
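The rewritten coercion accepts either an enum member or its string value for `training_mode`, and `osft` becomes an overridable default rather than a hard-coded `True`. A standalone illustration of the coercion pattern (the enum here is a stand-in; mini-trainer's `TrainingMode` has at least an `EPOCH` member per the code above):

```python
from enum import Enum

class TrainingMode(Enum):  # stand-in for mini-trainer's enum
    EPOCH = "epoch"

def coerce(mode):
    # same idea as the isinstance check added in this commit
    if not isinstance(mode, TrainingMode):
        mode = TrainingMode(mode)
    return mode

print(coerce("epoch"))             # TrainingMode.EPOCH
print(coerce(TrainingMode.EPOCH))  # TrainingMode.EPOCH, passed through unchanged
```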
```diff
@@ -384,28 +404,27 @@ def _process_data(
             return data_path
 
         # otherwise we need to process the data
-        data_output_path = os.path.join(output_dir, '_internal_data_processing')
-        os.makedirs(data_output_path, exist_ok=True)
+        os.makedirs(output_dir, exist_ok=True)
 
         # if we received unmask then we need to add that
         processing_data_path = data_path
         if unmask_messages:
             ds = datasets.load_dataset(data_path, split='train')
             ds = ds.map(lambda _: { "unmask": True })
-            processing_data_path = os.path.join(data_output_path, 'intermediate_data.jsonl')
+            processing_data_path = os.path.join(output_dir, 'intermediate_data.jsonl')
             ds.to_json(processing_data_path)
 
         # now we process the data
         process_messages_into_input_ids(
             data_path=processing_data_path,
-            data_output_path=data_output_path,
+            data_output_path=output_dir,
             model_path=model_name_or_path,
             max_seq_len=max_seq_len,
             num_cpu_procs=num_cpu_procs,
         )
 
         # above function will save to this file, so we pass this to the trainer
-        return os.path.join(data_output_path, 'data.jsonl')
+        return os.path.join(output_dir, 'data.jsonl')
```
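After this change, `_process_data` writes directly into the caller-supplied directory; the `_internal_data_processing` nesting now happens one level up, in `execute_training`. Per the code above, the directory ends up containing (path shown is the default, as a placeholder):

```python
import os

output_dir = "/path/to/outputs/_internal_data_processing"

intermediate = os.path.join(output_dir, "intermediate_data.jsonl")  # written only when unmask_messages=True
final = os.path.join(output_dir, "data.jsonl")                      # written by process_messages_into_input_ids,
                                                                    # and returned to the trainer
```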
```diff
@@ -417,12 +436,13 @@ def _process_data(
 def osft(
     model_path: str,
     data_path: str,
-    output_dir: str,
     unfreeze_rank_ratio: float,
-    batch_size: int,
+    effective_batch_size: int,
     max_tokens_per_gpu: int,
     max_seq_len: int,
     learning_rate: float,
+    ckpt_output_dir: str,
+    data_output_dir: str | None = None,
     backend: str = "mini-trainer",
     # Optional parameters
     target_patterns: list[str] | None = None,
@@ -435,7 +455,7 @@ def osft(
     lr_scheduler_kwargs: dict[str, str] | None = None,
     checkpoint_at_epoch: bool | None = None,
     save_final_checkpoint: bool | None = None,
-    epochs: int | None = None,
+    num_epochs: int | None = None,
     # Torchrun parameters for multi-node support
     nproc_per_node: int | None = None,
     nnodes: int | None = None,
@@ -450,9 +470,10 @@ def osft(
     return algorithm.train(
         model_path=model_path,
         data_path=data_path,
-        output_dir=output_dir,
+        ckpt_output_dir=ckpt_output_dir,
+        data_output_dir=data_output_dir,
         unfreeze_rank_ratio=unfreeze_rank_ratio,
-        batch_size=batch_size,
+        effective_batch_size=effective_batch_size,
         max_tokens_per_gpu=max_tokens_per_gpu,
         max_seq_len=max_seq_len,
         learning_rate=learning_rate,
@@ -466,7 +487,7 @@ def osft(
         lr_scheduler_kwargs=lr_scheduler_kwargs,
         checkpoint_at_epoch=checkpoint_at_epoch,
         save_final_checkpoint=save_final_checkpoint,
-        epochs=epochs,
+        num_epochs=num_epochs,
         nproc_per_node=nproc_per_node,
         nnodes=nnodes,
         node_rank=node_rank,
```
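One practical consequence of the signature change in `osft()`: `ckpt_output_dir` no longer sits where `output_dir` used to, so purely positional callers would mis-bind arguments. A hedged end-to-end sketch using keywords only (all paths and values are placeholders):

```python
from training_hub import osft

# Keyword arguments make the reshuffled signature in this commit a non-issue.
result = osft(
    model_path="/path/to/model",
    data_path="/path/to/data.jsonl",
    ckpt_output_dir="/path/to/outputs",
    # data_output_dir omitted: defaults to <ckpt_output_dir>/_internal_data_processing
    unfreeze_rank_ratio=0.3,
    effective_batch_size=8,
    max_tokens_per_gpu=2048,
    max_seq_len=2048,
    learning_rate=2e-5,
    num_epochs=2,
)
```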

src/training_hub/utils.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -1,5 +1,4 @@
 from typing import get_origin, get_args
-import sys
 
 def format_type_name(tp):
     # Handle None
```
