#!/usr/bin/env python3
"""
This script demonstrates how to run OSFT (Orthogonal Subspace Fine-Tuning) training on HPU
using a single-node, multi-HPU setup with training_hub.

OSFT allows continual training without catastrophic forgetting, making it ideal for:
- Adapting models to specialized domains (medical, legal, technical)
- Adding new knowledge without degrading general capabilities
- Fine-tuning without complex replay mechanisms

Example usage:
    python osft_hpu_example.py \\
        --data-path /path/to/data.jsonl \\
        --ckpt-output-dir /path/to/checkpoints \\
        --model-path meta-llama/Meta-Llama-3.1-8B-Instruct
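
Training data is read from a JSONL file. A minimal record might look like the
following (an illustrative sketch; the exact schema depends on your
training_hub version, commonly chat-style "messages" records):

    {"messages": [{"role": "user", "content": "..."},
                  {"role": "assistant", "content": "..."}]}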
| 17 | +""" |
| 18 | + |
| 19 | +import os |
| 20 | +import sys |
| 21 | +import time |
| 22 | +from datetime import datetime |
| 23 | +import argparse |
| 24 | +import glob |
| 25 | + |
| 26 | +from training_hub import osft |
| 27 | + |
| 28 | +def find_most_recent_checkpoint(output_dir): |
| 29 | + """ |
| 30 | + Find the most recent checkpoint in the training output directory. |
| 31 | + |
| 32 | + Args: |
| 33 | + output_dir (str): Training output directory containing hf_format/ subdirectory |
| 34 | + |
| 35 | + Returns: |
| 36 | + str: Path to the most recent checkpoint |
| 37 | + |
| 38 | + Raises: |
| 39 | + ValueError: If no checkpoints are found |
| 40 | + """ |
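    # Illustrative layout (the sample counts here are an assumption; actual
    # directory names vary per run but match the samples_*.0 pattern below):
    #   <output_dir>/hf_format/samples_4096.0/
    #   <output_dir>/hf_format/samples_8192.0/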
    # Get all checkpoint directories under hf_format
    checkpoint_pattern = os.path.join(output_dir, "hf_format", "samples_*.0")
    checkpoint_dirs = glob.glob(checkpoint_pattern)

    if not checkpoint_dirs:
        raise ValueError(f"No checkpoints found in {os.path.join(output_dir, 'hf_format')}")

    # Find the most recently created checkpoint
    most_recent_checkpoint = max(checkpoint_dirs, key=os.path.getctime)

    return most_recent_checkpoint


def main():
    parser = argparse.ArgumentParser(description=sample_name)

    # Required parameters
    parser.add_argument('--data-path', required=True,
                        help='Path to training data (JSONL format)')
    parser.add_argument('--ckpt-output-dir', required=True,
                        help='Directory to save checkpoints')
    parser.add_argument('--model-path', required=True,
                        help='Model path or Hugging Face model name')

    # Optional overrides
    parser.add_argument('--num-epochs', type=int, default=3,
                        help='Number of epochs (default: 3)')
    parser.add_argument('--unfreeze-rank-ratio', type=float, default=0.3,
                        help='Unfreeze rank ratio for OSFT (0.0-1.0, default: 0.3)')
    parser.add_argument('--effective-batch-size', type=int, default=128,
                        help='Effective batch size (default: 128)')
    parser.add_argument('--max-seq-len', type=int, default=4096,
                        help='Maximum sequence length (default: 4096)')
    parser.add_argument('--checkpoint-at-epoch', action='store_true', default=False,
                        help='Save a checkpoint after each epoch')
    parser.add_argument('--max-tokens-per-gpu', type=int, default=32768,
                        help='Max tokens per device (default: 32768)')
    parser.add_argument('--nproc-per-node', type=int, default=8,
                        help='Number of HPU devices per node (default: 8)')
    parser.add_argument('--unmask-messages', action='store_true', default=False,
                        help='Unmask messages during training (default: False)')
    parser.add_argument('--learning-rate', type=float, default=5e-6,
                        help='Learning rate for training (default: 5e-6)')
    parser.add_argument('--torch-compile', action='store_true', default=False,
                        help='Enable torch.compile (HPU only)')
    parser.add_argument('--num-chunks', type=int, default=1,
                        help='Number of chunks to split the dataset into for sequential training (default: 1)')

    args = parser.parse_args()

    # Print the sample configuration
    print(f"🚀 {sample_name}")
    print("=" * 50)
    print(f"Model: {args.model_path}")
    print(f"Data: {args.data_path}")
    print(f"Output: {args.ckpt_output_dir}")
    print(f"HPUs: {args.nproc_per_node}")
    print(f"Unfreeze Rank Ratio: {args.unfreeze_rank_ratio}")
    print(f"Max tokens per HPU: {args.max_tokens_per_gpu:,}")
    print()
    print("📝 OSFT Benefits:")
    print("   • Preserves the model's strong general capabilities")
    print("   • Adds domain-specific knowledge efficiently")
    print("   • No need for complex data mixing or replay buffers")
    print()

    # Training configuration
    start_time = time.time()

    try:
        osft_params = {
            # Model and data
            'model_path': args.model_path,
            'data_path': args.data_path,
            'ckpt_output_dir': args.ckpt_output_dir,

            # OSFT-specific parameters
            'unfreeze_rank_ratio': args.unfreeze_rank_ratio,

            # Training parameters
            'num_epochs': args.num_epochs,
            'effective_batch_size': args.effective_batch_size,
            'learning_rate': args.learning_rate,
            'max_seq_len': args.max_seq_len,
            'max_tokens_per_gpu': args.max_tokens_per_gpu,

            # Data processing
            'data_output_dir': "/dev/shm",  # Use RAM disk for speed
            'warmup_steps': 0,
            'unmask_messages': args.unmask_messages,

            # Optimization
            'use_liger': False,
            'seed': 42,
            'lr_scheduler': 'cosine',  # Cosine scheduler works well with OSFT

            # Checkpointing
            'checkpoint_at_epoch': args.checkpoint_at_epoch,
            'save_final_checkpoint': True,

            # Single-node multi-HPU setup
            'nproc_per_node': args.nproc_per_node,
            'nnodes': 1,
            'node_rank': 0,
            'rdzv_id': 103,
            'rdzv_endpoint': "127.0.0.1:29500",

            # HPU-specific arguments
            'device': 'hpu',
            'torch_compile': args.torch_compile,
            'num_chunks': args.num_chunks,
        }

        osft(**osft_params)

        end_time = time.time()
        duration = end_time - start_time

        most_recent_checkpoint = find_most_recent_checkpoint(args.ckpt_output_dir)

        print("=" * 50)
        print("✅ OSFT training completed successfully!")
        print(f"⏱️  Duration: {duration/3600:.2f} hours")
        print(f"📁 Checkpoints: {args.ckpt_output_dir}/hf_format")
        print(f"   Most recent checkpoint: {most_recent_checkpoint}")
        print()
        print("🎯 Your model has been successfully adapted!")
        print("   The model now incorporates your domain-specific knowledge")
        print("   while maintaining its original instruction-following abilities.")

    except Exception as e:
        end_time = time.time()
        duration = end_time - start_time

        print("=" * 50)
        print(f"❌ Training failed after {duration/60:.1f} minutes")
        print(f"Error: {e}")
        print()
        print("💡 Troubleshooting tips:")
        print("   - Reduce --max-tokens-per-gpu if you see OOM errors")
        print("   - For domain adaptation, try --unfreeze-rank-ratio between 0.2 and 0.4")
        sys.exit(1)


if __name__ == "__main__":
    main()
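

# Note: checkpoints under <ckpt_output_dir>/hf_format are saved in Hugging Face
# format, so the adapted model should be loadable with the standard transformers
# API. A minimal sketch (the directory name below is illustrative; the actual
# name depends on your run):
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#     ckpt = "/path/to/checkpoints/hf_format/samples_4096.0"
#     model = AutoModelForCausalLM.from_pretrained(ckpt)
#     tokenizer = AutoTokenizer.from_pretrained(ckpt)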