#!/usr/bin/env python3
"""
OSFT Training Example: Granite 3.3 8B Instruct

This script demonstrates OSFT (Orthogonal Subspace Fine-Tuning) training with the Granite 3.3 8B Instruct model
using a single-node, multi-GPU setup with training_hub.

OSFT allows continual training without catastrophic forgetting, making it ideal for:
- Adapting instruction-tuned models to new domains
- Adding new knowledge without losing existing capabilities
- Fine-tuning without replay buffers or supplementary datasets

After training, the script also creates a merged model by linearly interpolating between the base model and the trained checkpoint.

Example usage:
    python osft_granite_example.py \\
        --data-path /path/to/data.jsonl \\
        --ckpt-output-dir /path/to/checkpoints
"""
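
# Note on data format: the training data file is expected to be JSONL. Each record
# is assumed to follow the chat "messages" schema consumed by training_hub, e.g.
# (illustrative only -- adjust to what your training_hub version expects):
#   {"messages": [{"role": "user", "content": "..."},
#                 {"role": "assistant", "content": "..."}]}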

import os
import sys
import time
from datetime import datetime
import argparse
import glob
import torch

from training_hub import osft


# =============================================================================
# MODEL CONFIGURATION EXAMPLE FOR OSFT
# =============================================================================

# Derived from generic_7b_example in examples/notebooks/osft_comprehensive_tutorial.ipynb
granite_example = {
    "model_name": "Granite 3.3 8B Instruct",
    "model_path": "ibm-granite/granite-3.3-8b-instruct",  # HuggingFace model name or local path
    "example_unfreeze_rank_ratio": 0.3,  # Balanced preservation vs adaptation
    "example_max_tokens_per_gpu": 10000,
    "example_max_seq_len": 4096,
    "example_batch_size": 128,
    "example_learning_rate": 5e-6,
    "notes": "Good baseline for most 7B-8B instruction-tuned models",
}

selected_example = granite_example  # Change this to your preferred example

model_name = selected_example['model_name']
default_model_path = selected_example['model_path']
default_unfreeze_rank_ratio = selected_example['example_unfreeze_rank_ratio']
default_max_tokens_per_gpu = selected_example['example_max_tokens_per_gpu']
default_max_seq_len = selected_example['example_max_seq_len']
default_batch_size = selected_example['example_batch_size']
default_learning_rate = selected_example['example_learning_rate']
default_num_epochs = 3
default_nproc_per_node = torch.cuda.device_count() if torch.cuda.is_available() else 0
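# Interpolation weight applied to the trained model when merging with the base
# model after training (assumed: 0.0 keeps only the base model, 1.0 keeps only
# the trained checkpoint).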
default_model_weight = 0.5

# =============================================================================
# COMPLETE OSFT PARAMETER CONFIGURATION
# =============================================================================

# Experiment identification
experiment_name = "osft_granite_example"
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
full_experiment_name = f"{experiment_name}_{timestamp}"

# data_output_dir = f"data/{full_experiment_name}"  # Directory for processed data
data_output_dir = f"/dev/shm/data/{full_experiment_name}"  # Directory for processed data (RAM disk for speed)


def find_most_recent_checkpoint(output_dir):
    """
    Find the most recent checkpoint in the training output directory.

    Args:
        output_dir (str): Training output directory containing hf_format/ subdirectory

    Returns:
        str: Path to the most recent checkpoint

    Raises:
        ValueError: If no checkpoints are found
    """
    # Get all checkpoint directories under hf_format
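    # training_hub is assumed to write epoch/final checkpoints as
    # hf_format/samples_<N>.0 directories, which is what this glob matches.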
    checkpoint_pattern = os.path.join(output_dir, "hf_format", "samples_*.0")
    checkpoint_dirs = glob.glob(checkpoint_pattern)

    if not checkpoint_dirs:
        raise ValueError(f"No checkpoints found in {os.path.join(output_dir, 'hf_format')}")

    # Find the most recently created checkpoint
    most_recent_checkpoint = max(checkpoint_dirs, key=os.path.getctime)

    return most_recent_checkpoint


def main():
    parser = argparse.ArgumentParser(description=f'OSFT Training Example: {model_name}')

    # Required parameters
    parser.add_argument('--data-path', required=True,
                        help='Path to training data (JSONL format)')
    parser.add_argument('--ckpt-output-dir', required=True,
                        help='Directory to save checkpoints')

    # Optional overrides
    parser.add_argument('--model-path', default=default_model_path,
                        help=f'Model path or HuggingFace name (default: {default_model_path})')
    parser.add_argument('--num-epochs', type=int, default=default_num_epochs,
                        help=f'Number of training epochs (default: {default_num_epochs})')
    parser.add_argument('--unfreeze-rank-ratio', type=float, default=default_unfreeze_rank_ratio,
                        help=f'Unfreeze rank ratio for OSFT (0.0-1.0, default: {default_unfreeze_rank_ratio})')
    parser.add_argument('--max-tokens-per-gpu', type=int, default=default_max_tokens_per_gpu,
                        help=f'Max tokens per GPU (default: {default_max_tokens_per_gpu})')
    parser.add_argument('--nproc-per-node', type=int, default=default_nproc_per_node,
                        help=f'Number of GPUs (default: {default_nproc_per_node})')
    parser.add_argument('--learning-rate', type=float, default=default_learning_rate,
                        help=f'Learning rate for training (default: {default_learning_rate})')
    parser.add_argument('--unmask-messages', action='store_true', default=False,
                        help='Unmask messages during training (default: False)')
    parser.add_argument('--batch-size', type=int, default=default_batch_size,
                        help=f'Effective batch size for training (default: {default_batch_size})')
    parser.add_argument('--max-seq-len', type=int, default=default_max_seq_len,
                        help=f'Max sequence length (default: {default_max_seq_len})')
    parser.add_argument('--model-weight', type=float, default=default_model_weight,
                        help=f'Weight for the trained model during interpolation (0.0-1.0, default: {default_model_weight})')

    args = parser.parse_args()

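    # This example assumes at least 4 GPUs so that an 8B model fits comfortably
    # with the configured max_tokens_per_gpu; adjust or remove for your setup.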
    if args.nproc_per_node < 4:
        raise ValueError("--nproc-per-node must be at least 4")

    # Granite 3.3 8B Instruct OSFT configuration
    print(f"🚀 OSFT Training: {model_name}")
    print("=" * 50)
    print(f"Model: {args.model_path}")
    print(f"Data: {args.data_path}")
    print(f"Output: {args.ckpt_output_dir}")
    print(f"GPUs: {args.nproc_per_node}")
    print(f"Unfreeze Rank Ratio: {args.unfreeze_rank_ratio}")
    print(f"Max tokens per GPU: {args.max_tokens_per_gpu:,}")
    print(f"Epochs: {args.num_epochs}")
    print(f"Batch size: {args.batch_size}")
    print(f"Learning rate: {args.learning_rate}")
    print(f"Max sequence length: {args.max_seq_len:,}")
    print(f"Model weight: {args.model_weight}")
    print()
    print("📝 Note: OSFT enables continual learning without replay buffers")
    print("   The model will adapt to new data while preserving existing capabilities")
    print()

    # Training configuration optimized for Granite 3.3 8B Instruct with OSFT
    start_time = time.time()

    try:
        result = osft(
            # Model and data
            model_path=args.model_path,
            data_path=args.data_path,
            ckpt_output_dir=args.ckpt_output_dir,

            # OSFT-specific parameters
            unfreeze_rank_ratio=args.unfreeze_rank_ratio,  # Controls preservation vs adaptation
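            # (Roughly: lower ratios keep more of each weight matrix's original
            # subspace frozen, preserving prior behavior; higher ratios free up
            # more directions for adapting to the new data.)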

            # Training parameters optimized for Granite 3.3 8B Instruct
            num_epochs=args.num_epochs,
            effective_batch_size=args.batch_size,  # Effective batch size across all GPUs
            learning_rate=args.learning_rate,  # Conservative LR to limit drift from the instruction-tuned base
            max_seq_len=args.max_seq_len,
            max_tokens_per_gpu=args.max_tokens_per_gpu,

            # Data processing
            data_output_dir=data_output_dir,
            warmup_steps=0,
            unmask_messages=args.unmask_messages,

            # Optimization
            use_liger=True,  # Enable Liger kernels for efficiency
            seed=42,
            lr_scheduler='cosine',  # Cosine scheduler works well with OSFT

            # Checkpointing
            checkpoint_at_epoch=True,
            save_final_checkpoint=True,

            # Single-node multi-GPU setup
            nproc_per_node=args.nproc_per_node,
            nnodes=1,
            node_rank=0,
            rdzv_id=102,
            rdzv_endpoint="127.0.0.1:29500",
        )

        end_time = time.time()
        duration = end_time - start_time

        most_recent_checkpoint = find_most_recent_checkpoint(args.ckpt_output_dir)

        print("=" * 50)
        print("✅ OSFT Training completed successfully!")
        print(f"⏱️ Duration: {duration/3600:.2f} hours")
        print(f"📁 Checkpoints: {args.ckpt_output_dir}/hf_format")
        print(f"   Most recent checkpoint: {most_recent_checkpoint}")
        print()
        print("💡 Your model has been adapted to the new domain while preserving")
        print("   its original instruction-following capabilities!")

        trained_model_weight = args.model_weight
        if 0.0 < trained_model_weight < 1.0:
            from interpolator import interpolate_models

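            # interpolate_models is a local helper (interpolator.py alongside this
            # script); it is assumed to produce a parameter-wise weighted average of
            # the base model and the trained checkpoint, i.e.
            # merged = w * trained + (1 - w) * base, and to return the path of the
            # saved merged model.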
            interp_model_path = interpolate_models(args.model_path, most_recent_checkpoint, trained_model_weight=trained_model_weight)

            print("=" * 50)
            print("✅ Interpolation completed successfully!")
            print(f"   Interpolated model checkpoint: {interp_model_path}")

    except Exception as e:
        end_time = time.time()
        duration = end_time - start_time

        print("=" * 50)
        print(f"❌ Training failed after {duration/60:.1f} minutes")
        print(f"Error: {e}")
        print()
        print("💡 Troubleshooting tips:")
        print("   - Reduce --max-tokens-per-gpu if you see OOM errors")
        print("   - For domain adaptation, try --unfreeze-rank-ratio between 0.2 and 0.4")
        sys.exit(1)


if __name__ == "__main__":
    main()