Commit 33ed975

Add HPU OSFT example

Signed-off-by: Sergey Plotnikov <[email protected]>
1 parent a477b3b

File tree

1 file changed: +188 −0
@@ -0,0 +1,188 @@
#!/usr/bin/env python3
"""
This script demonstrates how to run OSFT (Orthogonal Subspace Fine-Tuning) training on HPU
using a single-node, multi-device setup with training_hub.

OSFT allows continual training without catastrophic forgetting, making it ideal for:
- Adapting models to specialized domains (medical, legal, technical)
- Adding new knowledge without degrading general capabilities
- Fine-tuning without complex replay mechanisms

Example usage:
    python osft_hpu_example.py \\
        --data-path /path/to/data.jsonl \\
        --ckpt-output-dir /path/to/checkpoints \\
        --model-path meta-llama/Meta-Llama-3.1-8B-Instruct
"""

sample_name = "OSFT Training Example for HPU"

import os
import sys
import time
import argparse
import glob

from training_hub import osft


def find_most_recent_checkpoint(output_dir):
    """
    Find the most recent checkpoint in the training output directory.

    Args:
        output_dir (str): Training output directory containing an hf_format/ subdirectory

    Returns:
        str: Path to the most recent checkpoint

    Raises:
        ValueError: If no checkpoints are found
    """
    # Collect all checkpoint directories under hf_format
    checkpoint_pattern = os.path.join(output_dir, "hf_format", "samples_*.0")
    checkpoint_dirs = glob.glob(checkpoint_pattern)

    if not checkpoint_dirs:
        raise ValueError(f"No checkpoints found in {os.path.join(output_dir, 'hf_format')}")

    # Pick the most recently created checkpoint
    most_recent_checkpoint = max(checkpoint_dirs, key=os.path.getctime)

    return most_recent_checkpoint

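
# Example of the checkpoint layout find_most_recent_checkpoint() expects
# (the "samples_<N>.0" names come from the training backend; the counts
# below are illustrative only):
#   <ckpt_output_dir>/hf_format/samples_4096.0/
#   <ckpt_output_dir>/hf_format/samples_8192.0/
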
def main():
    parser = argparse.ArgumentParser(description=sample_name)

    # Required parameters
    parser.add_argument('--data-path', required=True,
                        help='Path to training data (JSONL format)')
    parser.add_argument('--ckpt-output-dir', required=True,
                        help='Directory to save checkpoints')
    parser.add_argument('--model-path', required=True,
                        help='Model path or Hugging Face model name')

    # Optional overrides
    parser.add_argument('--num-epochs', type=int, default=3,
                        help='Number of epochs (default: 3)')
    parser.add_argument('--unfreeze-rank-ratio', type=float, default=0.3,
                        help='Unfreeze rank ratio for OSFT (0.0-1.0, default: 0.3)')
    parser.add_argument('--effective-batch-size', type=int, default=128,
                        help='Effective batch size (default: 128)')
    parser.add_argument('--max-seq-len', type=int, default=4096,
                        help='Maximum sequence length (default: 4096)')
    parser.add_argument('--checkpoint-at-epoch', action='store_true', default=False,
                        help='Save a checkpoint after each epoch')
    parser.add_argument('--max-tokens-per-gpu', type=int, default=32768,
                        help='Max tokens per GPU (default: 32768)')
    parser.add_argument('--nproc-per-node', type=int, default=8,
                        help='Number of HPU devices (default: 8)')
    parser.add_argument('--unmask-messages', action='store_true', default=False,
                        help='Unmask messages during training (default: False)')
    parser.add_argument('--learning-rate', type=float, default=5e-6,
                        help='Learning rate for training (default: 5e-6)')
    parser.add_argument('--torch-compile', action='store_true', default=False,
                        help='Enable torch.compile (HPU only)')
    parser.add_argument('--num-chunks', type=int, default=1,
                        help='Number of chunks to split the dataset into for sequential training (default: 1)')

    args = parser.parse_args()

    # Print the sample configuration
    print(f"🚀 {sample_name}")
    print("=" * 50)
    print(f"Model: {args.model_path}")
    print(f"Data: {args.data_path}")
    print(f"Output: {args.ckpt_output_dir}")
    print(f"Devices: {args.nproc_per_node}")
    print(f"Unfreeze Rank Ratio: {args.unfreeze_rank_ratio}")
    print(f"Max tokens per GPU: {args.max_tokens_per_gpu:,}")
    print()
    print("📝 OSFT Benefits:")
    print("   • Preserves the model's strong general capabilities")
    print("   • Adds domain-specific knowledge efficiently")
    print("   • Avoids complex data mixing and replay buffers")
    print()

    # Training configuration
    start_time = time.time()

    try:
        osft_params = {
            # Model and data
            'model_path': args.model_path,
            'data_path': args.data_path,
            'ckpt_output_dir': args.ckpt_output_dir,

            # OSFT-specific parameters
            'unfreeze_rank_ratio': args.unfreeze_rank_ratio,

            # Training parameters
            'num_epochs': args.num_epochs,
            'effective_batch_size': args.effective_batch_size,
            'learning_rate': args.learning_rate,
            'max_seq_len': args.max_seq_len,
            'max_tokens_per_gpu': args.max_tokens_per_gpu,

            # Data processing
            'data_output_dir': "/dev/shm",  # Use RAM disk for speed
            'warmup_steps': 0,
            'unmask_messages': args.unmask_messages,

            # Optimization
            'use_liger': False,
            'seed': 42,
            'lr_scheduler': 'cosine',  # Cosine scheduler works well with OSFT

            # Checkpointing
            'checkpoint_at_epoch': args.checkpoint_at_epoch,
            'save_final_checkpoint': True,

            # Single-node multi-device setup
            'nproc_per_node': args.nproc_per_node,
            'nnodes': 1,
            'node_rank': 0,
            'rdzv_id': 103,
            'rdzv_endpoint': "127.0.0.1:29500",

            # HPU-specific arguments
            'device': 'hpu',
            'torch_compile': args.torch_compile,
            'num_chunks': args.num_chunks,
        }

        osft(**osft_params)
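
        # Note: osft() blocks until the distributed run (configured by the
        # nproc/nnodes/rdzv_* parameters above) has finished.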

        end_time = time.time()
        duration = end_time - start_time

        most_recent_checkpoint = find_most_recent_checkpoint(args.ckpt_output_dir)

        print("=" * 50)
        print("✅ OSFT Training completed successfully!")
        print(f"⏱️  Duration: {duration/3600:.2f} hours")
        print(f"📁 Checkpoints: {args.ckpt_output_dir}/hf_format")
        print(f"   Most recent checkpoint: {most_recent_checkpoint}")
        print()
        print("🎯 Your model has been successfully adapted!")
        print("   The model now incorporates your domain-specific knowledge")
        print("   while maintaining its original instruction-following abilities.")

    except Exception as e:
        end_time = time.time()
        duration = end_time - start_time

        print("=" * 50)
        print(f"❌ Training failed after {duration/60:.1f} minutes")
        print(f"Error: {e}")
        print()
        print("💡 Troubleshooting tips:")
        print("   - Reduce --max-tokens-per-gpu if you see OOM errors")
        print("   - For domain adaptation, try --unfreeze-rank-ratio between 0.2 and 0.4")
        sys.exit(1)


if __name__ == "__main__":
    main()
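
To sanity-check the adapted model afterwards, the hf_format checkpoint can be loaded like any Hugging Face model directory (a minimal sketch, assuming a standard transformers-compatible checkpoint; the path below is illustrative, not taken from this commit):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    ckpt = "/path/to/checkpoints/hf_format/samples_8192.0"  # pick the most recent checkpoint
    tokenizer = AutoTokenizer.from_pretrained(ckpt)
    model = AutoModelForCausalLM.from_pretrained(ckpt)
    inputs = tokenizer("Summarize OSFT in one sentence.", return_tensors="pt")
    print(tokenizer.decode(model.generate(**inputs, max_new_tokens=64)[0]))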