
Commit a8d9030

Add granite training example
1 parent fc2175d commit a8d9030

3 files changed: 563 additions, 0 deletions

examples/scripts/interpolator.py

Lines changed: 117 additions & 0 deletions
"""
A simple model interpolation utility.

This takes two checkpoints of the same model and outputs a merged checkpoint produced by linear interpolation.

Example usage:
python interpolator.py \\
    --model-path ibm-granite/granite-3.3-8b-instruct \\
    --trained-model-path /path/to/checkpoint
"""
# Standard
import argparse
from typing import Any

# Third Party
from transformers import AutoModelForCausalLM, AutoTokenizer


def interpolate_models(
    model_path: str,
    trained_model_path: str,
    trained_model_weight: float = 0.5,
    output_model_path: str | None = None,
    torch_dtype: str | None = "bfloat16",
) -> str:
    if output_model_path is None:
        output_model_path = f"{trained_model_path}_interp"

    model_kwargs: dict[str, Any] = {}
    if torch_dtype is not None and torch_dtype != "auto":
        model_kwargs["torch_dtype"] = torch_dtype

    # load the original model and scale its weights by (1 - trained_model_weight)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        **model_kwargs,
    )
    state_dict = model.state_dict()
    original_model_weight = 1 - trained_model_weight
    for key in state_dict.keys():
        state_dict[key] = state_dict[key] * original_model_weight

    # load the trained model and add its weights, scaled by trained_model_weight
    trained_model = AutoModelForCausalLM.from_pretrained(
        trained_model_path,
        **model_kwargs,
    )
    trained_state_dict = trained_model.state_dict()
    for key in state_dict.keys():
        state_dict[key] += trained_state_dict[key] * trained_model_weight

    # save the interpolated model
    model.save_pretrained(output_model_path, state_dict=state_dict)

    # copy the tokenizer from the original model
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.save_pretrained(output_model_path)

    return output_model_path


def parse_arguments():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--model-path",
        type=str,
        required=True,
        help="path to the original model",
    )
    parser.add_argument(
        "--trained-model-path",
        type=str,
        required=True,
        help="path to the trained model",
    )
    parser.add_argument(
        "--trained-model-weight",
        type=float,
        default=0.5,
        help="weight for the trained model",
    )
    parser.add_argument(
        "--output-model-path",
        type=str,
        default=None,
        help="path to the output model",
    )
    parser.add_argument(
        "--torch-dtype",
        type=str,
        default="bfloat16",
        help="torch dtype",
    )
    args = parser.parse_args()
    return args


def main():
    args = parse_arguments()
    model_path: str = args.model_path
    trained_model_path: str = args.trained_model_path
    trained_model_weight: float = args.trained_model_weight
    output_model_path: str | None = args.output_model_path
    torch_dtype: str | None = args.torch_dtype

    interpolate_models(
        model_path,
        trained_model_path,
        trained_model_weight=trained_model_weight,
        output_model_path=output_model_path,
        torch_dtype=torch_dtype,
    )


if __name__ == "__main__":
    main()
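
For reference, the merge above is just a per-tensor weighted average of the two checkpoints' state dicts. A minimal sketch of the same arithmetic on toy tensors (the tensor name and values are made up for illustration, not taken from real checkpoints):

# Illustration only: the weighted-average arithmetic used by interpolate_models,
# applied to toy tensors. "layer.weight" and the values are placeholders.
import torch

w = 0.5  # trained_model_weight
original = {"layer.weight": torch.tensor([1.0, 2.0])}
trained = {"layer.weight": torch.tensor([3.0, 4.0])}

merged = {key: (1 - w) * original[key] + w * trained[key] for key in original}
print(merged["layer.weight"])  # tensor([2., 3.]) for w = 0.5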
Lines changed: 237 additions & 0 deletions
#!/usr/bin/env python3
"""
OSFT Training Example: Granite 3.3 8B Instruct

This script demonstrates OSFT (Orthogonal Subspace Fine-Tuning) training with the Granite 3.3 8B Instruct model
using a single-node, multi-GPU setup with training_hub.

OSFT allows continual training without catastrophic forgetting, making it ideal for:
- Adapting instruction-tuned models to new domains
- Adding new knowledge without losing existing capabilities
- Fine-tuning without replay buffers or supplementary datasets

After training, the script also creates a merged model with linear interpolation.

Example usage:
python osft_granite_example.py \\
    --data-path /path/to/data.jsonl \\
    --ckpt-output-dir /path/to/checkpoints
"""

import os
import sys
import time
from datetime import datetime
import argparse
import glob
import torch

from training_hub import osft


# =============================================================================
# MODEL CONFIGURATION EXAMPLE FOR OSFT
# =============================================================================

# Derived from generic_7b_example in examples/notebooks/osft_comprehensive_tutorial.ipynb
granite_example = {
    "model_name": "Granite 3.3 8B Instruct",
    "model_path": "ibm-granite/granite-3.3-8b-instruct",  # HuggingFace model name or local path
    "example_unfreeze_rank_ratio": 0.3,  # Balanced preservation vs adaptation
    "example_max_tokens_per_gpu": 10000,
    "example_max_seq_len": 4096,
    "example_batch_size": 128,
    "example_learning_rate": 5e-6,
    "notes": "Good baseline for most 7B instruction-tuned models",
}

selected_example = granite_example  # Change this to your preferred example

model_name = selected_example["model_name"]
default_model_path = selected_example["model_path"]
default_unfreeze_rank_ratio = selected_example["example_unfreeze_rank_ratio"]
default_max_tokens_per_gpu = selected_example["example_max_tokens_per_gpu"]
default_max_seq_len = selected_example["example_max_seq_len"]
default_batch_size = selected_example["example_batch_size"]
default_learning_rate = selected_example["example_learning_rate"]
default_num_epochs = 3
default_nproc_per_node = torch.cuda.device_count() if torch.cuda.is_available() else 0
default_model_weight = 0.5

# =============================================================================
# COMPLETE OSFT PARAMETER CONFIGURATION
# =============================================================================

# Experiment identification
experiment_name = "osft_granite_example"
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
full_experiment_name = f"{experiment_name}_{timestamp}"

# data_output_dir = f"data/{full_experiment_name}"  # Directory for processed data
data_output_dir = f"/dev/shm/data/{full_experiment_name}"  # Directory for processed data (RAM disk for speed)


def find_most_recent_checkpoint(output_dir):
    """
    Find the most recent checkpoint in the training output directory.

    Args:
        output_dir (str): Training output directory containing hf_format/ subdirectory

    Returns:
        str: Path to the most recent checkpoint

    Raises:
        ValueError: If no checkpoints are found
    """
    # Get all checkpoint directories under hf_format
    checkpoint_pattern = os.path.join(output_dir, "hf_format", "samples_*.0")
    checkpoint_dirs = glob.glob(checkpoint_pattern)

    if not checkpoint_dirs:
        raise ValueError(f"No checkpoints found in {os.path.join(output_dir, 'hf_format')}")

    # Find the most recently created checkpoint
    most_recent_checkpoint = max(checkpoint_dirs, key=os.path.getctime)

    return most_recent_checkpoint


def main():
    parser = argparse.ArgumentParser(description=f'OSFT Training Example: {model_name}')

    # Required parameters
    parser.add_argument('--data-path', required=True,
                        help='Path to training data (JSONL format)')
    parser.add_argument('--ckpt-output-dir', required=True,
                        help='Directory to save checkpoints')

    # Optional overrides
    parser.add_argument('--model-path', default=default_model_path,
                        help=f'Model path or HuggingFace name (default: {default_model_path})')
    parser.add_argument('--num-epochs', type=int, default=default_num_epochs,
                        help=f'Number of training epochs (default: {default_num_epochs})')
    parser.add_argument('--unfreeze-rank-ratio', type=float, default=default_unfreeze_rank_ratio,
                        help=f'Unfreeze rank ratio for OSFT (0.0-1.0, default: {default_unfreeze_rank_ratio})')
    parser.add_argument('--max-tokens-per-gpu', type=int, default=default_max_tokens_per_gpu,
                        help=f'Max tokens per GPU (default: {default_max_tokens_per_gpu})')
    parser.add_argument('--nproc-per-node', type=int, default=default_nproc_per_node,
                        help=f'Number of GPUs (default: {default_nproc_per_node})')
    parser.add_argument('--learning-rate', type=float, default=default_learning_rate,
                        help=f'Learning rate for training (default: {default_learning_rate})')
    parser.add_argument('--unmask-messages', action='store_true', default=False,
                        help='Unmask messages during training (default: False)')
    parser.add_argument('--batch-size', type=int, default=default_batch_size,
                        help=f'Effective batch size for training (default: {default_batch_size})')
    parser.add_argument('--max-seq-len', type=int, default=default_max_seq_len,
                        help=f'Max sequence length (default: {default_max_seq_len})')
    parser.add_argument('--model-weight', type=float, default=default_model_weight,
                        help=f'Weight for trained model for interpolation (0.0-1.0, default: {default_model_weight})')

    args = parser.parse_args()

    if args.nproc_per_node < 4:
        raise ValueError("NPROC_PER_NODE must be larger than or equal to 4")

    # Granite 3.3 8B Instruct OSFT configuration
    print(f"🚀 OSFT Training: {model_name}")
    print("=" * 50)
    print(f"Model: {args.model_path}")
    print(f"Data: {args.data_path}")
    print(f"Output: {args.ckpt_output_dir}")
    print(f"GPUs: {args.nproc_per_node}")
    print(f"Unfreeze Rank Ratio: {args.unfreeze_rank_ratio}")
    print(f"Max tokens per GPU: {args.max_tokens_per_gpu:,}")
    print(f"Epochs: {args.num_epochs}")
    print(f"Batch size: {args.batch_size}")
    print(f"Learning rate: {args.learning_rate}")
    print(f"Max sequence length: {args.max_seq_len:,}")
    print(f"Model weight: {args.model_weight}")
    print()
    print("📝 Note: OSFT enables continual learning without replay buffers")
    print("   The model will adapt to new data while preserving existing capabilities")
    print()

    # Training configuration optimized for Granite 3.3 8B Instruct with OSFT
    start_time = time.time()

    try:
        result = osft(
            # Model and data
            model_path=args.model_path,
            data_path=args.data_path,
            ckpt_output_dir=args.ckpt_output_dir,

            # OSFT-specific parameters
            unfreeze_rank_ratio=args.unfreeze_rank_ratio,  # Controls preservation vs adaptation

            # Training parameters optimized for Granite 3.3 8B Instruct
            num_epochs=args.num_epochs,
            effective_batch_size=args.batch_size,  # Effective batch size across all GPUs
            learning_rate=args.learning_rate,  # Conservative LR for an instruction-tuned model
            max_seq_len=args.max_seq_len,
            max_tokens_per_gpu=args.max_tokens_per_gpu,

            # Data processing
            data_output_dir=data_output_dir,
            warmup_steps=0,
            unmask_messages=args.unmask_messages,

            # Optimization
            use_liger=True,  # Enable Liger kernels for efficiency
            seed=42,
            lr_scheduler='cosine',  # Cosine scheduler works well with OSFT

            # Checkpointing
            checkpoint_at_epoch=True,
            save_final_checkpoint=True,

            # Single-node multi-GPU setup
            nproc_per_node=args.nproc_per_node,
            nnodes=1,
            node_rank=0,
            rdzv_id=102,
            rdzv_endpoint="127.0.0.1:29500",
        )

        end_time = time.time()
        duration = end_time - start_time

        most_recent_checkpoint = find_most_recent_checkpoint(args.ckpt_output_dir)

        print("=" * 50)
        print("✅ OSFT Training completed successfully!")
        print(f"⏱️ Duration: {duration/3600:.2f} hours")
        print(f"📁 Checkpoints: {args.ckpt_output_dir}/hf_format")
        print(f"   Most recent checkpoint: {most_recent_checkpoint}")
        print()
        print("💡 Your model has been adapted to the new domain while preserving")
        print("   its original instruction-following capabilities!")

        trained_model_weight = args.model_weight
        if 0.0 < trained_model_weight < 1.0:
            from interpolator import interpolate_models

            interp_model_path = interpolate_models(
                args.model_path,
                most_recent_checkpoint,
                trained_model_weight=trained_model_weight,
            )

            print("=" * 50)
            print("✅ Interpolation completed successfully!")
            print(f"   Interpolated model checkpoint: {interp_model_path}")

    except Exception as e:
        end_time = time.time()
        duration = end_time - start_time

        print("=" * 50)
        print(f"❌ Training failed after {duration/60:.1f} minutes")
        print(f"Error: {e}")
        print()
        print("💡 Troubleshooting tips:")
        print("   - Reduce --max-tokens-per-gpu if you see OOM errors")
        print("   - For domain adaptation, try --unfreeze-rank-ratio between 0.2-0.4")
        sys.exit(1)


if __name__ == "__main__":
    main()
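
After a run completes, one simple sanity check is to load the interpolated checkpoint back with transformers and generate a short completion. A minimal sketch, assuming the interpolated model was written to a placeholder path such as /path/to/checkpoint_interp (the value returned by interpolate_models):

# Quick smoke test of the interpolated checkpoint (the path and prompt are placeholders).
from transformers import AutoModelForCausalLM, AutoTokenizer

interp_path = "/path/to/checkpoint_interp"  # value returned by interpolate_models()
tokenizer = AutoTokenizer.from_pretrained(interp_path)
model = AutoModelForCausalLM.from_pretrained(interp_path, torch_dtype="bfloat16")

inputs = tokenizer("Briefly describe the IBM Granite model family.", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))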
