Commit ce5903a

Add Granite 4 SFT example (#20)
* Add Granite 4 SFT example
* address bot comments
* disable auto-resumption
1 parent: 0c67914 · commit: ce5903a

File tree

4 files changed: +293 -43 lines changed


examples/README.md

Lines changed: 1 addition & 0 deletions

@@ -28,6 +28,7 @@ The SFT algorithm supports training language models on supervised datasets with
 - [SFT with Phi 4 Mini](scripts/sft_phi_example.py) - Single-node multi-GPU training example with Phi 4 Mini Instruct
 - [SFT with GPT-OSS 20B](scripts/sft_gpt_oss_example.py) - Single-node multi-GPU training example with GPT-OSS 20B
 - [SFT with Granite 3.3 8B](scripts/sft_granite_example.py) - Single-node multi-GPU training example with Granite 3.3 8B Instruct
+- [SFT with Granite 4.0](scripts/sft_granite4_example.py) - Single-node multi-GPU training example with Granite 4.0 models
 
 **Quick Example:**
 ```python

examples/scripts/osft_granite_example.py

Lines changed: 25 additions & 23 deletions

@@ -37,6 +37,7 @@
 granite_example = {
     "model_name": "Granite 3.3 8B Instruct",
     "model_path": "ibm-granite/granite-3.3-8b-instruct", # HuggingFace model name or local path
+    "example_min_nproc_per_node": 2,
     "example_unfreeze_rank_ratio": 0.3, # Balanced preservation vs adaptation
     "example_max_tokens_per_gpu": 10000,
     "example_max_seq_len": 4096,
@@ -49,11 +50,12 @@
 
 model_name = selected_example['model_name']
 default_model_path = selected_example['model_path']
-default_unfreeze_rank_ratio = selected_example["example_unfreeze_rank_ratio"]
-default_max_tokens_per_gpu = selected_example['example_max_tokens_per_gpu']
-default_max_seq_len = selected_example['example_max_seq_len']
-default_batch_size = selected_example['example_batch_size']
-default_learning_rate = selected_example['example_learning_rate']
+example_min_nproc_per_node = selected_example['example_min_nproc_per_node']
+example_unfreeze_rank_ratio = selected_example['example_unfreeze_rank_ratio']
+example_max_tokens_per_gpu = selected_example['example_max_tokens_per_gpu']
+example_max_seq_len = selected_example['example_max_seq_len']
+example_batch_size = selected_example['example_batch_size']
+example_learning_rate = selected_example['example_learning_rate']
 default_num_epochs = 3
 default_nproc_per_node = torch.cuda.device_count() if torch.cuda.is_available() else 0
 default_model_weight = 0.5
@@ -67,8 +69,8 @@
 timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
 full_experiment_name = f"{experiment_name}_{timestamp}"
 
-# data_output_dir=f"data/{full_experiment_name}" # Directory for processed data
-data_output_dir=f"/dev/shm/data/{full_experiment_name}" # Directory for processed data (RAM disk for speed)
+data_output_dir=f"data/{full_experiment_name}" # Directory for processed data
+# data_output_dir=f"/dev/shm/data/{full_experiment_name}" # Directory for processed data (RAM disk for speed)
 
 
 def find_most_recent_checkpoint(output_dir):
@@ -111,27 +113,27 @@ def main():
                         help=f'Model path or HuggingFace name (default: {default_model_path})')
     parser.add_argument('--num-epochs', type=int, default=default_num_epochs,
                         help=f'Number of training epochs (default: {default_num_epochs})')
-    parser.add_argument('--unfreeze-rank-ratio', type=float, default=default_unfreeze_rank_ratio,
-                        help=f'Unfreeze rank ratio for OSFT (0.0-1.0, default: {default_unfreeze_rank_ratio})')
-    parser.add_argument('--max-tokens-per-gpu', type=int, default=default_max_tokens_per_gpu,
-                        help=f'Max tokens per GPU (default: {default_max_tokens_per_gpu})')
     parser.add_argument('--nproc-per-node', type=int, default=default_nproc_per_node,
                         help=f'Number of GPUs (default: {default_nproc_per_node})')
-    parser.add_argument('--learning-rate', type=float, default=default_learning_rate,
-                        help=f'Learning rate for training (default: {default_learning_rate})')
     parser.add_argument('--unmask-messages', action='store_true', default=False,
                         help='Unmask all non-system messages during training, otherwise only unmasks assistant messages (default: False)')
-    parser.add_argument('--batch-size', type=int, default=default_batch_size,
-                        help=f'Effective batch size for training (default: {default_batch_size})')
-    parser.add_argument('--max-seq-len', type=int, default=default_max_seq_len,
-                        help=f'Max sequence length (default: {default_max_seq_len})')
+    parser.add_argument('--unfreeze-rank-ratio', type=float, default=example_unfreeze_rank_ratio,
+                        help=f'Unfreeze rank ratio for OSFT (0.0-1.0, default: {example_unfreeze_rank_ratio})')
+    parser.add_argument('--max-tokens-per-gpu', type=int, default=example_max_tokens_per_gpu,
+                        help=f'Max tokens per GPU (default: {example_max_tokens_per_gpu})')
+    parser.add_argument('--learning-rate', type=float, default=example_learning_rate,
+                        help=f'Learning rate for training (default: {example_learning_rate})')
+    parser.add_argument('--batch-size', type=int, default=example_batch_size,
+                        help=f'Effective batch size for training (default: {example_batch_size})')
+    parser.add_argument('--max-seq-len', type=int, default=example_max_seq_len,
+                        help=f'Max sequence length (default: {example_max_seq_len})')
     parser.add_argument('--model-weight', type=float, default=default_model_weight,
                         help=f'Weight for trained model for interpolation (0.0-1.0, default: {default_model_weight})')
 
     args = parser.parse_args()
 
-    if args.nproc_per_node < 4:
-        raise ValueError("NPROC_PER_NODE must be larger than or equal to 4")
+    if args.nproc_per_node < example_min_nproc_per_node:
+        print(f"💡 Try --nproc-per-node {example_min_nproc_per_node} or larger if you see OOM errors")
 
     # Granite 3.3 8B Instruct OSFT configuration
     print(f"🚀 OSFT Training: {model_name}")
@@ -156,7 +158,7 @@ def main():
     start_time = time.time()
 
     try:
-        result = osft(
+        osft(
             # Model and data
             model_path=args.model_path,
             data_path=args.data_path,
@@ -189,9 +191,9 @@ def main():
             # Single-node multi-GPU setup
             nproc_per_node=args.nproc_per_node,
             nnodes=1,
-            node_rank=0,
-            rdzv_id=102,
-            rdzv_endpoint="127.0.0.1:29500",
+            # node_rank=0,
+            # rdzv_id=102,
+            # rdzv_endpoint="127.0.0.1:29500",
         )
 
         end_time = time.time()

examples/scripts/sft_granite4_example.py

Lines changed: 245 additions & 0 deletions

@@ -0,0 +1,245 @@
#!/usr/bin/env python3
"""
SFT Training Example: Granite 4.0

This script demonstrates SFT training with Granite 4.0 models
using a single-node, multi-GPU setup with training_hub.

After the training, the script also creates a merged model with linear interpolation.

Example usage:
    python sft_granite4_example.py \\
        --data-path /path/to/data.jsonl \\
        --ckpt-output-dir /path/to/checkpoints
"""

import os
import sys
import time
import glob
from datetime import datetime
import argparse
import torch

from instructlab.training import FSDPOptions
from training_hub import sft


# =============================================================================
# MODEL CONFIGURATION EXAMPLE
# =============================================================================

# Derived from generic_7b_example in examples/notebooks/sft_comprehensive_tutorial.ipynb
granite4_example_template = {
    "example_max_tokens_per_gpu": 25000,
    "example_max_seq_len": 20000,
    "example_batch_size": 256,
    "example_learning_rate": 2e-5,
    "notes": "Good baseline for most 7B instruction-tuned models",
}

granite4hs_example = {
    **granite4_example_template,
    "model_name": "Granite-4.0-H-Small",
    "model_path": "ibm-granite/granite-4.0-h-small", # HuggingFace model name or local path
    "example_min_nproc_per_node": 8,
    "example_batch_size": 128,
    "kwargs": {
        "fsdp_options": FSDPOptions(cpu_offload_params=True),
    },
}
granite4ht_example = {
    **granite4_example_template,
    "model_name": "Granite-4.0-H-Tiny",
    "model_path": "ibm-granite/granite-4.0-h-tiny", # HuggingFace model name or local path
    "example_min_nproc_per_node": 2,
}
granite4hm_example = {
    **granite4_example_template,
    "model_name": "Granite-4.0-H-Micro",
    "model_path": "ibm-granite/granite-4.0-h-micro", # HuggingFace model name or local path
    "example_min_nproc_per_node": 2,
}
granite4m_example = {
    **granite4_example_template,
    "model_name": "Granite-4.0-Micro",
    "model_path": "ibm-granite/granite-4.0-micro", # HuggingFace model name or local path
    "example_min_nproc_per_node": 2,
}

selected_example = granite4hs_example # Change this to your preferred example
# selected_example = granite4ht_example # Change this to your preferred example
# selected_example = granite4hm_example # Change this to your preferred example
# selected_example = granite4m_example # Change this to your preferred example

model_name = selected_example['model_name']
default_model_path = selected_example['model_path']
example_min_nproc_per_node = selected_example['example_min_nproc_per_node']
example_max_tokens_per_gpu = selected_example['example_max_tokens_per_gpu']
example_max_seq_len = selected_example['example_max_seq_len']
example_batch_size = selected_example['example_batch_size']
example_learning_rate = selected_example['example_learning_rate']
kwargs = selected_example.get('kwargs', {})
default_num_epochs = 3
default_nproc_per_node = torch.cuda.device_count() if torch.cuda.is_available() else 0
default_model_weight = 0.5

# =============================================================================
# COMPLETE SFT PARAMETER CONFIGURATION
# =============================================================================

# Experiment identification
experiment_name = "sft_granite4_example"
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
full_experiment_name = f"{experiment_name}_{timestamp}"

data_output_dir=f"data/{full_experiment_name}" # Directory for processed data
# data_output_dir=f"/dev/shm/data/{full_experiment_name}" # Directory for processed data (RAM disk for speed)


# Copied from examples/scripts/osft_continual_learning_example.py
def find_most_recent_checkpoint(output_dir):
    """
    Find the most recent checkpoint in the training output directory.

    Args:
        output_dir (str): Training output directory containing hf_format/ subdirectory

    Returns:
        str: Path to the most recent checkpoint

    Raises:
        ValueError: If no checkpoints are found
    """
    # Get all checkpoint directories under hf_format
    checkpoint_pattern = os.path.join(output_dir, "hf_format", "samples_*")
    checkpoint_dirs = glob.glob(checkpoint_pattern)

    if not checkpoint_dirs:
        raise ValueError(f"No checkpoints found in {os.path.join(output_dir, 'hf_format')}")

    # Find the most recently created checkpoint
    most_recent_checkpoint = max(checkpoint_dirs, key=os.path.getctime)

    return most_recent_checkpoint


def main():
    parser = argparse.ArgumentParser(description=f'SFT Training Example: {model_name}')

    # Required parameters
    parser.add_argument('--data-path', required=True,
                        help='Path to training data (JSONL format)')
    parser.add_argument('--ckpt-output-dir', required=True,
                        help='Directory to save checkpoints')

    # Optional overrides
    parser.add_argument('--model-path', default=default_model_path,
                        help=f'Model path or HuggingFace name (default: {default_model_path})')
    parser.add_argument('--num-epochs', type=int, default=default_num_epochs,
                        help=f'Number of training epochs (default: {default_num_epochs})')
    parser.add_argument('--nproc-per-node', type=int, default=default_nproc_per_node,
                        help=f'Number of GPUs (default: {default_nproc_per_node})')
    parser.add_argument('--max-tokens-per-gpu', type=int, default=example_max_tokens_per_gpu,
                        help=f'Max tokens per GPU (default: {example_max_tokens_per_gpu})')
    parser.add_argument('--batch-size', type=int, default=example_batch_size,
                        help=f'Effective batch size for training (default: {example_batch_size})')
    parser.add_argument('--learning-rate', type=float, default=example_learning_rate,
                        help=f'Learning rate for training (default: {example_learning_rate})')
    parser.add_argument('--max-seq-len', type=int, default=example_max_seq_len,
                        help=f'Max sequence length (default: {example_max_seq_len})')
    parser.add_argument('--model-weight', type=float, default=default_model_weight,
                        help=f'Weight for trained model for interpolation (0.0-1.0, default: {default_model_weight})')

    args = parser.parse_args()

    if args.nproc_per_node < example_min_nproc_per_node:
        print(f"💡 Try --nproc-per-node {example_min_nproc_per_node} or larger if you see OOM errors")

    # Granite 4.0 configuration
    print(f"🚀 SFT Training: {model_name}")
    print("=" * 50)
    print(f"Model: {args.model_path}")
    print(f"Data: {args.data_path}")
    print(f"Output: {args.ckpt_output_dir}")
    print(f"GPUs: {args.nproc_per_node}")
    print(f"Max tokens per GPU: {args.max_tokens_per_gpu:,}")
    print(f"Epochs: {args.num_epochs}")
    print(f"Batch size: {args.batch_size}")
    print(f"Learning rate: {args.learning_rate}")
    print(f"Max sequence length: {args.max_seq_len:,}")
    print(f"Model weight: {args.model_weight}")
    print()

    # Training configuration optimized for Granite 4.0
    start_time = time.time()

    try:
        sft(
            # Model and data
            model_path=args.model_path,
            data_path=args.data_path,
            ckpt_output_dir=args.ckpt_output_dir,

            # Training parameters optimized for Granite 4.0
            num_epochs=args.num_epochs,
            effective_batch_size=args.batch_size,
            learning_rate=args.learning_rate,
            max_seq_len=args.max_seq_len,
            max_tokens_per_gpu=args.max_tokens_per_gpu,

            # Data processing
            data_output_dir=data_output_dir,
            warmup_steps=100,
            save_samples=0, # 0 disables sample-based checkpointing, use epoch-based only

            # Checkpointing
            checkpoint_at_epoch=True,
            accelerate_full_state_at_epoch=False, # Disable for smaller checkpoints (no auto-resumption)

            # Single-node multi-GPU setup
            nproc_per_node=args.nproc_per_node,
            nnodes=1,
            # node_rank=0,
            # rdzv_id=102,
            # rdzv_endpoint="127.0.0.1:29500",

            # Additional parameters to the backend
            **kwargs
        )

        end_time = time.time()
        duration = end_time - start_time

        print("=" * 50)
        print("✅ Training completed successfully!")
        print(f"⏱️ Duration: {duration/3600:.2f} hours")
        print(f"📁 Checkpoints: {args.ckpt_output_dir}/hf_format/")

        most_recent_checkpoint = find_most_recent_checkpoint(args.ckpt_output_dir)
        print(f"   Most recent checkpoint: {most_recent_checkpoint}")

        trained_model_weight = args.model_weight
        if 0.0 < trained_model_weight and trained_model_weight < 1.0:
            from interpolator import interpolate_models

            interp_model_path = interpolate_models(args.model_path, most_recent_checkpoint, trained_model_weight=trained_model_weight)

            print("=" * 50)
            print("✅ Interpolation completed successfully!")
            print(f"   Interpolated model checkpoint: {interp_model_path}")

    except Exception as e:
        end_time = time.time()
        duration = end_time - start_time

        print("=" * 50)
        print(f"❌ Training failed after {duration/60:.1f} minutes")
        print(f"Error: {e}")
        print()
        print("💡 Try reducing --max-tokens-per-gpu if you see OOM errors")
        sys.exit(1)


if __name__ == "__main__":
    main()
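
Note on the interpolation step: the `interpolator` module imported near the end of the script is not part of this commit, so the merge logic itself is not shown here. As a rough, hypothetical sketch of what linear interpolation between the base model and the fine-tuned checkpoint amounts to (the function name, dtype choice, and output directory below are assumptions for illustration, not the project's implementation):

```python
# Hypothetical sketch only -- the script's actual interpolator.interpolate_models
# is not included in this commit. Parameter-wise linear interpolation:
#     merged = (1 - w) * base + w * trained
import torch
from transformers import AutoModelForCausalLM


def interpolate_models_sketch(base_model_path, trained_checkpoint,
                              trained_model_weight=0.5, output_dir="interpolated_model"):
    base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.bfloat16)
    trained = AutoModelForCausalLM.from_pretrained(trained_checkpoint, torch_dtype=torch.bfloat16)

    trained_sd = trained.state_dict()
    merged_sd = {}
    for name, base_param in base.state_dict().items():
        if base_param.is_floating_point():
            # Weighted average of each parameter tensor
            merged_sd[name] = (
                (1.0 - trained_model_weight) * base_param
                + trained_model_weight * trained_sd[name]
            )
        else:
            merged_sd[name] = base_param  # leave non-float buffers untouched

    base.load_state_dict(merged_sd)
    base.save_pretrained(output_dir)  # tokenizer/config handling omitted for brevity
    return output_dir
```

In the script above, `trained_model_weight` comes from `--model-weight`, so the default of 0.5 blends the base model and the most recent SFT checkpoint equally.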
