#!/usr/bin/env python3
"""
OSFT Training Example: GPT-OSS 20B

This script demonstrates OSFT (Orthogonal Subspace Fine-Tuning) training with the GPT-OSS 20B
model using a single-node, multi-GPU setup with training_hub.

OSFT allows continual training without catastrophic forgetting, making it ideal for:
- Adapting GPT-OSS 20B to specialized domains (medical, legal, technical)
- Adding new knowledge without degrading general capabilities
- Fine-tuning without complex replay mechanisms

Example usage:
    python osft_gpt_oss_example.py \
        --data-path /path/to/data.jsonl \
        --ckpt-output-dir /path/to/checkpoints
"""

import argparse
import glob
import os
import sys
import time

from training_hub import osft


def find_most_recent_checkpoint(output_dir):
    """
    Find the most recent checkpoint in the training output directory.

    Args:
        output_dir (str): Training output directory containing an hf_format/ subdirectory

    Returns:
        str: Path to the most recent checkpoint

    Raises:
        ValueError: If no checkpoints are found
    """
    # Get all checkpoint directories under hf_format
    checkpoint_pattern = os.path.join(output_dir, "hf_format", "samples_*")
    checkpoint_dirs = glob.glob(checkpoint_pattern)

    if not checkpoint_dirs:
        raise ValueError(f"No checkpoints found in {os.path.join(output_dir, 'hf_format')}")

    # Find the most recently created checkpoint
    most_recent_checkpoint = max(checkpoint_dirs, key=os.path.getctime)

    return most_recent_checkpoint
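
# Usage sketch (hypothetical layout): if ckpt_output_dir contains
# hf_format/samples_1024/ and hf_format/samples_2048/, the glob above matches
# both and os.path.getctime selects whichever directory was created last:
#
#   latest = find_most_recent_checkpoint("/path/to/checkpoints")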


def main():
    parser = argparse.ArgumentParser(description='OSFT Training Example: GPT-OSS 20B')

    # Required parameters
    parser.add_argument('--data-path', required=True,
                        help='Path to training data (JSONL format)')
    parser.add_argument('--ckpt-output-dir', required=True,
                        help='Directory to save checkpoints')

    # Optional overrides
    parser.add_argument('--model-path', default='openai/gpt-oss-20b',
                        help='Model path or HuggingFace name (default: openai/gpt-oss-20b)')
    parser.add_argument('--num-epochs', type=int, default=3,
                        help='Number of epochs (default: 3)')
    parser.add_argument('--unfreeze-rank-ratio', type=float, default=0.25,
                        help='Unfreeze rank ratio for OSFT (0.0-1.0, default: 0.25)')
    parser.add_argument('--max-tokens-per-gpu', type=int, default=8192,
                        help='Max tokens per GPU (default: 8192 for GPT-OSS 20B)')
    parser.add_argument('--nproc-per-node', type=int, default=8,
                        help='Number of GPUs (default: 8)')
    parser.add_argument('--unmask-messages', action='store_true', default=False,
                        help='Unmask messages during training (default: False)')
    parser.add_argument('--learning-rate', type=float, default=3e-6,
                        help='Learning rate for training (default: 3e-6)')

    args = parser.parse_args()

    # GPT-OSS 20B OSFT configuration
    print("🚀 OSFT Training: GPT-OSS 20B")
    print("=" * 50)
    print(f"Model: {args.model_path}")
    print(f"Data: {args.data_path}")
    print(f"Output: {args.ckpt_output_dir}")
    print(f"GPUs: {args.nproc_per_node}")
    print(f"Unfreeze Rank Ratio: {args.unfreeze_rank_ratio}")
    print(f"Max tokens per GPU: {args.max_tokens_per_gpu:,}")
    print()
    print("📝 OSFT Benefits for GPT-OSS 20B:")
    print("  • Preserve GPT-OSS's strong general capabilities")
    print("  • Add domain-specific knowledge efficiently")
    print("  • No need for complex data mixing or replay buffers")
    print("  • Leverage the high-quality 20B parameter base")
    print()

    start_time = time.time()

    try:
        # Training configuration optimized for GPT-OSS 20B with OSFT
        osft_params = {
            # Model and data
            'model_path': args.model_path,
            'data_path': args.data_path,
            'ckpt_output_dir': args.ckpt_output_dir,

            # OSFT-specific parameters
            'unfreeze_rank_ratio': args.unfreeze_rank_ratio,  # Conservative default for a 20B model

            # Training parameters optimized for GPT-OSS 20B
            'num_epochs': args.num_epochs,
            'effective_batch_size': 32,  # Smaller batch size for a 20B model
            'learning_rate': args.learning_rate,  # Lower LR for a larger model
            'max_seq_len': 4096,  # Conservative context length for memory
            'max_tokens_per_gpu': args.max_tokens_per_gpu,

            # Data processing
            'data_output_dir': "/dev/shm",  # Use RAM disk for speed
            'warmup_steps': 0,
            'unmask_messages': args.unmask_messages,

            # Optimization
            'use_liger': True,  # Enable Liger kernels for efficiency
            'osft_memory_efficient_init': True,  # Recommended if you hit OOMs at model load time
            'seed': 42,
            'lr_scheduler': 'cosine',  # Cosine scheduler works well with OSFT

            # Checkpointing
            'checkpoint_at_epoch': True,
            'save_final_checkpoint': True,

            # Single-node multi-GPU setup (torchrun-style rendezvous settings)
            'nproc_per_node': args.nproc_per_node,
            'nnodes': 1,
            'node_rank': 0,
            'rdzv_id': 105,
            'rdzv_endpoint': "127.0.0.1:29500",
        }

        osft(**osft_params)

        end_time = time.time()
        duration = end_time - start_time

        most_recent_checkpoint = find_most_recent_checkpoint(args.ckpt_output_dir)

        print("=" * 50)
        print("✅ OSFT Training completed successfully!")
        print(f"⏱️ Duration: {duration / 3600:.2f} hours")
        print(f"📁 Checkpoints: {args.ckpt_output_dir}/hf_format")
        print(f"   Most recent checkpoint: {most_recent_checkpoint}")
        print()
        print("🎯 Your GPT-OSS 20B model has been successfully adapted!")
        print("   The model now incorporates your domain-specific knowledge")
        print("   while maintaining its original high-quality capabilities.")

    except Exception as e:
        end_time = time.time()
        duration = end_time - start_time

        print("=" * 50)
        print(f"❌ Training failed after {duration / 60:.1f} minutes")
        print(f"Error: {e}")
        print()
        print("💡 Troubleshooting tips:")
        print("   - Reduce --max-tokens-per-gpu if you see OOM errors")
        print("   - For domain adaptation, try --unfreeze-rank-ratio between 0.2 and 0.3")
        print("   - Lower effective_batch_size in osft_params if memory is still tight")
        sys.exit(1)


if __name__ == "__main__":
    main()
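
# Loading the adapted model afterwards: the directory name suggests checkpoints
# under hf_format/ are saved in Hugging Face format, so they should load with
# the standard transformers APIs (the path below is a placeholder for whatever
# checkpoint path this script prints at the end):
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   ckpt = "/path/to/checkpoints/hf_format/samples_..."
#   model = AutoModelForCausalLM.from_pretrained(ckpt)
#   tokenizer = AutoTokenizer.from_pretrained(ckpt)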