
Commit 9b8505b

Add example scripts for GPT-OSS
Signed-off-by: Mustafa Eyceoz <[email protected]>
1 parent 455a61f commit 9b8505b

3 files changed: +294 additions, -1 deletion

examples/README.md: 3 additions & 1 deletion

@@ -24,8 +24,9 @@ The SFT algorithm supports training language models on supervised datasets with
 **Scripts:**
 - [LAB Multi-Phase Training Script](scripts/lab_multiphase_training.py) - Example script for LAB multi-phase training with full command-line interface
 - [SFT with Qwen 2.5 7B](scripts/sft_qwen_example.py) - Single-node multi-GPU training example with Qwen 2.5 7B Instruct
-- [SFT with Llama 3.1 8B](scripts/sft_llama_example.py) - Single-node multi-GPU training example with Llama 3.1 8B Instruct
+- [SFT with Llama 3.1 8B](scripts/sft_llama_example.py) - Single-node multi-GPU training example with Llama 3.1 8B Instruct
 - [SFT with Phi 4 Mini](scripts/sft_phi_example.py) - Single-node multi-GPU training example with Phi 4 Mini Instruct
+- [SFT with GPT-OSS 20B](scripts/sft_gpt_oss_example.py) - Single-node multi-GPU training example with GPT-OSS 20B
 
 **Quick Example:**
 ```python

@@ -58,6 +59,7 @@ The OSFT algorithm supports continual training of pre-trained or instruction-tuned models
 - [OSFT with Qwen 2.5 7B](scripts/osft_qwen_example.py) - Single-node multi-GPU training example with Qwen 2.5 7B Instruct
 - [OSFT with Llama 3.1 8B](scripts/osft_llama_example.py) - Single-node multi-GPU training example with Llama 3.1 8B Instruct
 - [OSFT with Phi 4 Mini](scripts/osft_phi_example.py) - Single-node multi-GPU training example with Phi 4 Mini Instruct
+- [OSFT with GPT-OSS 20B](scripts/osft_gpt_oss_example.py) - Single-node multi-GPU training example with GPT-OSS 20B
 - [OSFT Continual Learning Example](scripts/osft_continual_learning_example.py) - Example script demonstrating continual learning without catastrophic forgetting
 
 **Quick Example:**
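
(The README's quick-example code blocks are truncated by the diff context above. For orientation only, here is a minimal sketch of the kind of call the new scripts/sft_gpt_oss_example.py wraps; every parameter below is taken from that script, and the paths are placeholders.)

```python
from training_hub import sft

sft(
    model_path="openai/gpt-oss-20b",         # HuggingFace name or a local path
    data_path="/path/to/data.jsonl",         # training data in JSONL format
    ckpt_output_dir="/path/to/checkpoints",  # checkpoints are written under hf_format/
    num_epochs=3,
    effective_batch_size=32,                 # smaller batch size for a 20B model
    learning_rate=2e-5,
    max_seq_len=8192,
    max_tokens_per_gpu=12000,
    nproc_per_node=8,                        # single node, 8 GPUs
    nnodes=1,
)
```
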
examples/scripts/osft_gpt_oss_example.py (new file): 174 additions & 0 deletions

#!/usr/bin/env python3
"""
OSFT Training Example: GPT-OSS 20B

This script demonstrates OSFT (Orthogonal Subspace Fine-Tuning) training with the GPT-OSS 20B model
using a single-node, multi-GPU setup with training_hub.

OSFT allows continual training without catastrophic forgetting, making it ideal for:
- Adapting GPT-OSS 20B to specialized domains (medical, legal, technical)
- Adding new knowledge without degrading general capabilities
- Fine-tuning without complex replay mechanisms

Example usage:
    python osft_gpt_oss_example.py \
        --data-path /path/to/data.jsonl \
        --ckpt-output-dir /path/to/checkpoints
"""

import os
import sys
import time
from datetime import datetime
import argparse
import glob

from training_hub import osft


def find_most_recent_checkpoint(output_dir):
    """
    Find the most recent checkpoint in the training output directory.

    Args:
        output_dir (str): Training output directory containing hf_format/ subdirectory

    Returns:
        str: Path to the most recent checkpoint

    Raises:
        ValueError: If no checkpoints are found
    """
    # Get all checkpoint directories under hf_format
    checkpoint_pattern = os.path.join(output_dir, "hf_format", "samples_*.0")
    checkpoint_dirs = glob.glob(checkpoint_pattern)

    if not checkpoint_dirs:
        raise ValueError(f"No checkpoints found in {os.path.join(output_dir, 'hf_format')}")

    # Find the most recently created checkpoint
    most_recent_checkpoint = max(checkpoint_dirs, key=os.path.getctime)

    return most_recent_checkpoint


def main():
    parser = argparse.ArgumentParser(description='OSFT Training Example: GPT-OSS 20B')

    # Required parameters
    parser.add_argument('--data-path', required=True,
                        help='Path to training data (JSONL format)')
    parser.add_argument('--ckpt-output-dir', required=True,
                        help='Directory to save checkpoints')

    # Optional overrides
    parser.add_argument('--model-path', default='openai/gpt-oss-20b',
                        help='Model path or HuggingFace name (default: openai/gpt-oss-20b)')
    parser.add_argument('--num-epochs', type=int, default=3,
                        help='Number of epochs (default: 3)')
    parser.add_argument('--unfreeze-rank-ratio', type=float, default=0.25,
                        help='Unfreeze rank ratio for OSFT (0.0-1.0, default: 0.25)')
    parser.add_argument('--max-tokens-per-gpu', type=int, default=8192,
                        help='Max tokens per GPU (default: 8192 for GPT-OSS 20B)')
    parser.add_argument('--nproc-per-node', type=int, default=8,
                        help='Number of GPUs (default: 8)')
    parser.add_argument('--unmask-messages', action='store_true', default=False,
                        help='Unmask messages during training (default: False)')
    parser.add_argument('--learning-rate', type=float, default=3e-6,
                        help='Learning rate for training (default: 3e-6)')

    args = parser.parse_args()

    # GPT-OSS 20B OSFT configuration
    print("🚀 OSFT Training: GPT-OSS 20B")
    print("=" * 50)
    print(f"Model: {args.model_path}")
    print(f"Data: {args.data_path}")
    print(f"Output: {args.ckpt_output_dir}")
    print(f"GPUs: {args.nproc_per_node}")
    print(f"Unfreeze Rank Ratio: {args.unfreeze_rank_ratio}")
    print(f"Max tokens per GPU: {args.max_tokens_per_gpu:,}")
    print()
    print("📝 OSFT Benefits for GPT-OSS 20B:")
    print("   • Preserve GPT-OSS's strong general capabilities")
    print("   • Add domain-specific knowledge efficiently")
    print("   • No need for complex data mixing or replay buffers")
    print("   • Leverage the high-quality 20B parameter base")
    print()

    # Training configuration optimized for GPT-OSS 20B with OSFT
    start_time = time.time()

    try:
        osft_params = {
            # Model and data
            'model_path': args.model_path,
            'data_path': args.data_path,
            'ckpt_output_dir': args.ckpt_output_dir,

            # OSFT-specific parameters
            'unfreeze_rank_ratio': args.unfreeze_rank_ratio,  # Conservative for 20B model

            # Training parameters optimized for GPT-OSS 20B
            'num_epochs': args.num_epochs,
            'effective_batch_size': 32,  # Smaller batch size for 20B model
            'learning_rate': args.learning_rate,  # Lower LR for larger model
            'max_seq_len': 4096,  # Conservative context length for memory
            'max_tokens_per_gpu': args.max_tokens_per_gpu,

            # Data processing
            'data_output_dir': "/dev/shm",  # Use RAM disk for speed
            'warmup_steps': 0,
            'unmask_messages': args.unmask_messages,

            # Optimization
            'use_liger': True,  # Enable Liger kernels for efficiency
            'seed': 42,
            'lr_scheduler': 'cosine',  # Cosine scheduler works well with OSFT

            # Checkpointing
            'checkpoint_at_epoch': True,
            'save_final_checkpoint': True,

            # Single-node multi-GPU setup
            'nproc_per_node': args.nproc_per_node,
            'nnodes': 1,
            'node_rank': 0,
            'rdzv_id': 105,
            'rdzv_endpoint': "127.0.0.1:29500",
        }

        osft(**osft_params)

        end_time = time.time()
        duration = end_time - start_time

        most_recent_checkpoint = find_most_recent_checkpoint(args.ckpt_output_dir)

        print("=" * 50)
        print("✅ OSFT Training completed successfully!")
        print(f"⏱️ Duration: {duration/3600:.2f} hours")
        print(f"📁 Checkpoints: {args.ckpt_output_dir}/hf_format")
        print(f"   Most recent checkpoint: {most_recent_checkpoint}")
        print()
        print("🎯 Your GPT-OSS 20B model has been successfully adapted!")
        print("   The model now incorporates your domain-specific knowledge")
        print("   while maintaining its original high-quality capabilities.")

    except Exception as e:
        end_time = time.time()
        duration = end_time - start_time

        print("=" * 50)
        print(f"❌ Training failed after {duration/60:.1f} minutes")
        print(f"Error: {e}")
        print()
        print("💡 Troubleshooting tips:")
        print("   - Reduce --max-tokens-per-gpu if you see OOM errors")
        print("   - For domain adaptation, try --unfreeze-rank-ratio between 0.2-0.3")
        print("   - Consider reducing batch size further for memory constraints")
        sys.exit(1)


if __name__ == "__main__":
    main()
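
(Not part of the commit: a hypothetical follow-on sketch showing how the output of the OSFT run above could seed a second continual-learning pass. It assumes osft() accepts a local checkpoint directory as model_path, which the script's --model-path help text suggests; the dataset path and phase-2 output directory are placeholders.)

```python
import glob
import os

from training_hub import osft

# Most recent checkpoint from the run above (same selection logic as the
# find_most_recent_checkpoint() helper in osft_gpt_oss_example.py).
previous_ckpt = max(glob.glob("/path/to/checkpoints/hf_format/samples_*.0"),
                    key=os.path.getctime)

osft(
    model_path=previous_ckpt,                      # continue from the already-adapted model
    data_path="/path/to/second_domain.jsonl",      # placeholder: data for the new domain
    ckpt_output_dir="/path/to/checkpoints_phase2",
    unfreeze_rank_ratio=0.25,                      # conservative ratio, as in the script default
    num_epochs=3,
    effective_batch_size=32,
    learning_rate=3e-6,
    max_seq_len=4096,
    max_tokens_per_gpu=8192,
    nproc_per_node=8,
    nnodes=1,
)
```
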
examples/scripts/sft_gpt_oss_example.py (new file): 117 additions & 0 deletions

#!/usr/bin/env python3
"""
SFT Training Example: GPT-OSS 20B

This script demonstrates SFT training with the GPT-OSS 20B model from OpenAI
using a single-node, multi-GPU setup with training_hub.

GPT-OSS 20B is a high-quality open source model that provides excellent
performance for supervised fine-tuning tasks.

Example usage:
    python sft_gpt_oss_example.py \
        --data-path /path/to/data.jsonl \
        --ckpt-output-dir /path/to/checkpoints
"""

import os
import sys
import time
from datetime import datetime
import argparse

from training_hub import sft


def main():
    parser = argparse.ArgumentParser(description='SFT Training Example: GPT-OSS 20B')

    # Required parameters
    parser.add_argument('--data-path', required=True,
                        help='Path to training data (JSONL format)')
    parser.add_argument('--ckpt-output-dir', required=True,
                        help='Directory to save checkpoints')

    # Optional overrides
    parser.add_argument('--model-path', default='openai/gpt-oss-20b',
                        help='Model path or HuggingFace name (default: openai/gpt-oss-20b)')
    parser.add_argument('--num-epochs', type=int, default=3,
                        help='Number of epochs (default: 3)')
    parser.add_argument('--max-tokens-per-gpu', type=int, default=12000,
                        help='Max tokens per GPU (default: 12000 for 20B model)')
    parser.add_argument('--nproc-per-node', type=int, default=8,
                        help='Number of GPUs (default: 8)')

    args = parser.parse_args()

    # GPT-OSS 20B configuration
    print("🚀 SFT Training: GPT-OSS 20B")
    print("=" * 50)
    print(f"Model: {args.model_path}")
    print(f"Data: {args.data_path}")
    print(f"Output: {args.ckpt_output_dir}")
    print(f"GPUs: {args.nproc_per_node}")
    print(f"Max tokens per GPU: {args.max_tokens_per_gpu:,}")
    print()

    # Training configuration optimized for GPT-OSS 20B
    start_time = time.time()

    try:
        result = sft(
            # Model and data
            model_path=args.model_path,
            data_path=args.data_path,
            ckpt_output_dir=args.ckpt_output_dir,

            # Training parameters optimized for GPT-OSS 20B
            num_epochs=args.num_epochs,
            effective_batch_size=32,  # Smaller batch size for 20B model
            learning_rate=2e-5,  # Conservative LR for larger model
            max_seq_len=8192,  # Standard context length
            max_tokens_per_gpu=args.max_tokens_per_gpu,

            # Data processing
            data_output_dir="/dev/shm",  # Use RAM disk for speed
            warmup_steps=100,
            save_samples=0,  # 0 disables sample-based checkpointing, use epoch-based only

            # Checkpointing
            checkpoint_at_epoch=True,
            accelerate_full_state_at_epoch=False,  # Disable for smaller checkpoints (no auto-resumption)

            # Single-node multi-GPU setup
            nproc_per_node=args.nproc_per_node,
            nnodes=1,
            node_rank=0,
            rdzv_id=104,
            rdzv_endpoint="127.0.0.1:29500",
        )

        end_time = time.time()
        duration = end_time - start_time

        print("=" * 50)
        print("✅ Training completed successfully!")
        print(f"⏱️ Duration: {duration/3600:.2f} hours")
        print(f"📁 Checkpoints: {args.ckpt_output_dir}/hf_format/")
        print()
        print("🎯 Your GPT-OSS 20B model has been fine-tuned!")
        print("   The model is now specialized for your specific task")
        print("   while maintaining the high quality of the base model.")

    except Exception as e:
        end_time = time.time()
        duration = end_time - start_time

        print("=" * 50)
        print(f"❌ Training failed after {duration/60:.1f} minutes")
        print(f"Error: {e}")
        print()
        print("💡 Troubleshooting tips:")
        print("   - Reduce --max-tokens-per-gpu if you see OOM errors")
        sys.exit(1)


if __name__ == "__main__":
    main()
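
(Not part of the commit: a small smoke-test sketch for a finished run, assuming the directories under hf_format/ are standard Hugging Face checkpoints that transformers can load. Paths are placeholders, and loading a 20B model this way requires sufficient GPU or CPU memory.)

```python
import glob
import os

from transformers import AutoModelForCausalLM, AutoTokenizer

# Pick the newest checkpoint, mirroring the selection logic used in the OSFT example.
ckpt_dir = max(glob.glob("/path/to/checkpoints/hf_format/samples_*.0"),
               key=os.path.getctime)

tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
model = AutoModelForCausalLM.from_pretrained(ckpt_dir, device_map="auto")  # needs accelerate installed

prompt = "Briefly describe what you were fine-tuned to do."
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
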
