
Commit f485b6c

Add granite training example (#16)
* Add granite training example
* Add scripts to the README.md
* Address bot comments
* Address bot comment
* Address bot's comments
1 parent 4b19ab3 commit f485b6c

File tree

4 files changed: +604 -0 lines changed

examples/README.md

21 additions & 0 deletions

````diff
@@ -27,6 +27,7 @@ The SFT algorithm supports training language models on supervised datasets with
 - [SFT with Llama 3.1 8B](scripts/sft_llama_example.py) - Single-node multi-GPU training example with Llama 3.1 8B Instruct
 - [SFT with Phi 4 Mini](scripts/sft_phi_example.py) - Single-node multi-GPU training example with Phi 4 Mini Instruct
 - [SFT with GPT-OSS 20B](scripts/sft_gpt_oss_example.py) - Single-node multi-GPU training example with GPT-OSS 20B
+- [SFT with Granite 3.3 8B](scripts/sft_granite_example.py) - Single-node multi-GPU training example with Granite 3.3 8B Instruct
 
 **Quick Example:**
 ```python
@@ -60,6 +61,7 @@ The OSFT algorithm supports continual training of pre-trained or instruction-tun
 - [OSFT with Llama 3.1 8B](scripts/osft_llama_example.py) - Single-node multi-GPU training example with Llama 3.1 8B Instruct
 - [OSFT with Phi 4 Mini](scripts/osft_phi_example.py) - Single-node multi-GPU training example with Phi 4 Mini Instruct
 - [OSFT with GPT-OSS 20B](scripts/osft_gpt_oss_example.py) - Single-node multi-GPU training example with GPT-OSS 20B
+- [OSFT with Granite 3.3 8B](scripts/osft_granite_example.py) - Single-node multi-GPU training example with Granite 3.3 8B Instruct
 - [OSFT Continual Learning Example](scripts/osft_continual_learning_example.py) - Example script demonstrating continual learning without catastrophic forgetting
 
 **Quick Example:**
@@ -101,6 +103,25 @@ estimate(training_method='osft',
 )
 ```
 
+### Model Interpolation (Experimental / In-Development)
+
+training_hub has a utility for merging two checkpoints of the same model into one with linear interpolation.
+
+**Script:**
+- [interpolator.py](scripts/interpolator.py) - Python script for model interpolation
+
+**Command-Line Example:**
+```bash
+python interpolator.py --model-path /path/to/base/model --trained-model-path /path/to/trained/checkpoint
+```
+
+**Python Example:**
+```python
+from interpolator import interpolate_models
+
+interpolate_models("/path/to/base/model", "/path/to/trained/checkpoint")
+```
+
 ## Getting Started
 
 1. **For detailed parameter documentation**: Check the relevant guide in `docs/`
````
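The interpolation the script performs is a plain weighted average of corresponding tensors: `merged = (1 - w) * base + w * trained`, where `w` is `trained_model_weight`. A minimal sketch of that arithmetic, using toy tensors as stand-ins for real checkpoint weights:

```python
import torch

# Toy stand-ins for one matching weight tensor from each checkpoint.
base = torch.tensor([1.0, 2.0, 3.0])
trained = torch.tensor([3.0, 2.0, 1.0])

w = 0.5  # trained_model_weight: 0.0 keeps the base model, 1.0 keeps the trained one
merged = (1 - w) * base + w * trained
print(merged)  # tensor([2., 2., 2.])
```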

examples/scripts/interpolator.py

137 additions & 0 deletions (new file)

```python
"""
Simple Model Interpolator

The script takes two checkpoints of the same model and outputs a merged checkpoint with linear interpolation.

Example usage:
python interpolator.py \\
    --model-path /path/to/base/model \\
    --trained-model-path /path/to/trained/checkpoint
"""
# Standard
import argparse

# Third Party
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch


def interpolate_models(
    model_path: str,
    trained_model_path: str,
    trained_model_weight: float = 0.5,
    output_model_path: str | None = None,
    torch_dtype: str | torch.dtype | None = "bfloat16",
) -> str:
    if output_model_path is None:
        output_model_path = f"{trained_model_path}_interp"

    if not (0.0 <= trained_model_weight <= 1.0):
        raise ValueError(f"trained_model_weight must be in [0,1], got {trained_model_weight}")

    model_kwargs = {}
    if torch_dtype is not None:
        if isinstance(torch_dtype, str):
            _torch_dtype = torch_dtype.lower()
            if _torch_dtype == "auto":
                model_kwargs["torch_dtype"] = "auto"
            else:
                _map = {
                    "bfloat16": torch.bfloat16, "bf16": torch.bfloat16,
                    "float16": torch.float16, "fp16": torch.float16,
                    "float32": torch.float32, "fp32": torch.float32,
                }
                if _torch_dtype not in _map:
                    raise ValueError(f"Unsupported --torch-dtype: {torch_dtype}")
                model_kwargs["torch_dtype"] = _map[_torch_dtype]
        else:
            model_kwargs["torch_dtype"] = torch_dtype

    # load base model
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        **model_kwargs,
    )
    state_dict = model.state_dict()
    base_model_weight = 1 - trained_model_weight
    for key in state_dict.keys():
        state_dict[key] = state_dict[key] * base_model_weight

    # load trained model
    trained_model = AutoModelForCausalLM.from_pretrained(
        trained_model_path,
        **model_kwargs,
    )
    trained_state_dict = trained_model.state_dict()
    for key in state_dict.keys():
        state_dict[key] += trained_state_dict[key] * trained_model_weight

    # save merged model
    model.save_pretrained(output_model_path, state_dict=state_dict)

    # copy tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.save_pretrained(output_model_path)

    print(f"Merged model saved at {output_model_path}")

    return output_model_path


def parse_arguments():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--model-path",
        type=str,
        required=True,
        help="Path to the base model",
    )
    parser.add_argument(
        "--trained-model-path",
        type=str,
        required=True,
        help="Path to the trained model",
    )
    parser.add_argument(
        "--trained-model-weight",
        type=float,
        default=0.5,
        help="Weight for the trained model",
    )
    parser.add_argument(
        "--output-model-path",
        type=str,
        default=None,
        help="Path to the output model",
    )
    parser.add_argument(
        "--torch-dtype",
        type=str,
        default="bfloat16",
        help="Torch dtype",
    )
    args = parser.parse_args()
    return args


def main():
    args = parse_arguments()
    model_path: str = args.model_path
    trained_model_path: str = args.trained_model_path
    trained_model_weight: float = args.trained_model_weight
    output_model_path: str | None = args.output_model_path
    torch_dtype: str | None = args.torch_dtype

    interpolate_models(
        model_path,
        trained_model_path,
        trained_model_weight=trained_model_weight,
        output_model_path=output_model_path,
        torch_dtype=torch_dtype,
    )


if __name__ == "__main__":
    main()
```
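Beyond the CLI, the function above can be called directly with its keyword arguments. A minimal sketch with placeholder paths, exercising the non-default knobs in the signature shown above:

```python
# Placeholder paths; interpolate_models is the function defined in interpolator.py above.
from interpolator import interpolate_models

merged_path = interpolate_models(
    "/path/to/base/model",
    "/path/to/trained/checkpoint",
    trained_model_weight=0.3,                   # 70% base, 30% trained
    output_model_path="/path/to/merged/model",  # otherwise defaults to <trained>_interp
    torch_dtype="bf16",                         # alias resolved by the dtype map above
)
print(merged_path)  # /path/to/merged/model
```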
