#!/usr/bin/env python3
"""
OSFT Training Example: GPT-OSS 20B

This script demonstrates OSFT (Orthogonal Subspace Fine-Tuning) training with the GPT-OSS 20B
model using a single-node, multi-GPU setup with training_hub.

OSFT allows continual training without catastrophic forgetting, making it ideal for:
- Adapting GPT-OSS 20B to specialized domains (medical, legal, technical)
- Adding new knowledge without degrading general capabilities
- Fine-tuning without complex replay mechanisms

Example usage:
    python osft_gpt_oss_example.py \
        --data-path /path/to/data.jsonl \
        --ckpt-output-dir /path/to/checkpoints
"""

import os
import sys
import time
import argparse
import glob

from training_hub import osft

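# The training backend is assumed to write HF-format checkpoints as
# <ckpt_output_dir>/hf_format/samples_*.0 directories (one per checkpoint);
# the helper below simply returns the newest one by creation time. For example
# (illustrative names), given samples_2048.0 and samples_4096.0, whichever
# directory was written last is returned.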
def find_most_recent_checkpoint(output_dir):
    """
    Find the most recent checkpoint in the training output directory.

    Args:
        output_dir (str): Training output directory containing an hf_format/ subdirectory

    Returns:
        str: Path to the most recent checkpoint

    Raises:
        ValueError: If no checkpoints are found
    """
    # Get all checkpoint directories under hf_format
    checkpoint_pattern = os.path.join(output_dir, "hf_format", "samples_*.0")
    checkpoint_dirs = glob.glob(checkpoint_pattern)

    if not checkpoint_dirs:
        raise ValueError(f"No checkpoints found in {os.path.join(output_dir, 'hf_format')}")

    # Find the most recently created checkpoint
    most_recent_checkpoint = max(checkpoint_dirs, key=os.path.getctime)

    return most_recent_checkpoint


def main():
    parser = argparse.ArgumentParser(description='OSFT Training Example: GPT-OSS 20B')

    # Required parameters
    parser.add_argument('--data-path', required=True,
                        help='Path to training data (JSONL format)')
    parser.add_argument('--ckpt-output-dir', required=True,
                        help='Directory to save checkpoints')

    # Optional overrides
    parser.add_argument('--model-path', default='openai/gpt-oss-20b',
                        help='Model path or HuggingFace name (default: openai/gpt-oss-20b)')
    parser.add_argument('--num-epochs', type=int, default=3,
                        help='Number of epochs (default: 3)')
    parser.add_argument('--unfreeze-rank-ratio', type=float, default=0.25,
                        help='Unfreeze rank ratio for OSFT (0.0-1.0, default: 0.25)')
    parser.add_argument('--max-tokens-per-gpu', type=int, default=8192,
                        help='Max tokens per GPU (default: 8192 for GPT-OSS 20B)')
    parser.add_argument('--nproc-per-node', type=int, default=8,
                        help='Number of GPUs (default: 8)')
    parser.add_argument('--unmask-messages', action='store_true',
                        help='Unmask messages during training (default: False)')
    parser.add_argument('--learning-rate', type=float, default=3e-6,
                        help='Learning rate for training (default: 3e-6)')

    args = parser.parse_args()

    # GPT-OSS 20B OSFT configuration
    print("🚀 OSFT Training: GPT-OSS 20B")
    print("=" * 50)
    print(f"Model: {args.model_path}")
    print(f"Data: {args.data_path}")
    print(f"Output: {args.ckpt_output_dir}")
    print(f"GPUs: {args.nproc_per_node}")
    print(f"Unfreeze Rank Ratio: {args.unfreeze_rank_ratio}")
    print(f"Max tokens per GPU: {args.max_tokens_per_gpu:,}")
    print()
    print("📝 OSFT Benefits for GPT-OSS 20B:")
    print("   • Preserves GPT-OSS's strong general capabilities")
    print("   • Adds domain-specific knowledge efficiently")
    print("   • No complex data mixing or replay buffers needed")
    print("   • Builds on the high-quality 20B-parameter base")
    print()

    # Training configuration optimized for GPT-OSS 20B with OSFT
    start_time = time.time()

    try:
        osft_params = {
            # Model and data
            'model_path': args.model_path,
            'data_path': args.data_path,
            'ckpt_output_dir': args.ckpt_output_dir,

            # OSFT-specific parameters
            'unfreeze_rank_ratio': args.unfreeze_rank_ratio,  # Conservative for a 20B model

            # Training parameters optimized for GPT-OSS 20B
            'num_epochs': args.num_epochs,
            'effective_batch_size': 32,      # Smaller global batch size for a 20B model
            'learning_rate': args.learning_rate,  # Lower LR for a larger model
            'max_seq_len': 4096,             # Conservative context length for memory
            'max_tokens_per_gpu': args.max_tokens_per_gpu,

            # Data processing
            'data_output_dir': '/dev/shm',   # Use RAM disk for speed
            'warmup_steps': 0,
            'unmask_messages': args.unmask_messages,

            # Optimization
            'use_liger': True,               # Enable Liger kernels for efficiency
            'seed': 42,
            'lr_scheduler': 'cosine',        # Cosine schedule works well with OSFT

            # Checkpointing
            'checkpoint_at_epoch': True,
            'save_final_checkpoint': True,

            # Single-node multi-GPU setup
            'nproc_per_node': args.nproc_per_node,
            'nnodes': 1,
            'node_rank': 0,
            'rdzv_id': 105,
            'rdzv_endpoint': '127.0.0.1:29500',
        }
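        # Note (an assumption about the backend's batching convention):
        # effective_batch_size is treated as the global batch size, with
        # per-GPU micro-batches bounded by max_tokens_per_gpu and gradient
        # accumulation making up the difference.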

        osft(**osft_params)

        end_time = time.time()
        duration = end_time - start_time

        most_recent_checkpoint = find_most_recent_checkpoint(args.ckpt_output_dir)

        print("=" * 50)
        print("✅ OSFT Training completed successfully!")
        print(f"⏱️ Duration: {duration / 3600:.2f} hours")
        print(f"📁 Checkpoints: {args.ckpt_output_dir}/hf_format")
        print(f"   Most recent checkpoint: {most_recent_checkpoint}")
        print()
        print("🎯 Your GPT-OSS 20B model has been successfully adapted!")
        print("   The model now incorporates your domain-specific knowledge")
        print("   while maintaining its original high-quality capabilities.")
    except Exception as e:
        end_time = time.time()
        duration = end_time - start_time

        print("=" * 50)
        print(f"❌ Training failed after {duration / 60:.1f} minutes")
        print(f"Error: {e}")
        print()
        print("💡 Troubleshooting tips:")
        print("   - Reduce --max-tokens-per-gpu if you see OOM errors")
        print("   - For domain adaptation, try --unfreeze-rank-ratio between 0.2 and 0.3")
        print("   - Consider reducing the effective batch size further under memory constraints")
        sys.exit(1)


if __name__ == "__main__":
    main()
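
# After a successful run, the newest HF-format checkpoint can be loaded for a
# quick smoke test. A minimal sketch using the standard transformers API
# (paths are illustrative):
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   ckpt = find_most_recent_checkpoint("/path/to/checkpoints")
#   tokenizer = AutoTokenizer.from_pretrained(ckpt)
#   model = AutoModelForCausalLM.from_pretrained(ckpt, device_map="auto")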