
Commit 2cd1d2a

address bot's comments
1 parent b6ca02f commit 2cd1d2a

3 files changed: +12 -10 lines changed


examples/README.md (2 additions, 2 deletions)

````diff
@@ -112,14 +112,14 @@ training_hub has a utility for merging two checkpoints of the same model into one
 
 **Command-Line Example:**
 ```bash
-python interpolator.py --model-path ibm-granite/granite-3.3-8b-instruct --trained-model-path /path/to/checkpoint
+python interpolator.py --model-path /path/to/base/model --trained-model-path /path/to/trained/checkpoint
 ```
 
 **Python Example:**
 ```python
 from interpolator import interpolate_models
 
-interpolate_models("ibm-granite/granite-3.3-8b-instruct", "/path/to/checkpoint")
+interpolate_models("/path/to/base/model", "/path/to/trained/checkpoint")
 ```
 
 ## Getting Started
````

examples/scripts/interpolator.py (9 additions, 7 deletions)

````diff
@@ -5,8 +5,8 @@
 
 Example usage:
     python interpolator.py \\
-        --model-path ibm-granite/granite-3.3-8b-instruct \\
-        --trained-model-path /path/to/checkpoint
+        --model-path /path/to/base/model \\
+        --trained-model-path /path/to/trained/checkpoint
 """
 # Standard
 import argparse
@@ -47,15 +47,15 @@ def interpolate_models(
     else:
         model_kwargs["torch_dtype"] = torch_dtype
 
-    # load original model
+    # load base model
     model = AutoModelForCausalLM.from_pretrained(
         model_path,
         **model_kwargs,
     )
     state_dict = model.state_dict()
-    original_model_weight = 1 - trained_model_weight
+    base_model_weight = 1 - trained_model_weight
     for key in state_dict.keys():
-        state_dict[key] = state_dict[key] * original_model_weight
+        state_dict[key] = state_dict[key] * base_model_weight
 
     # load trained model
     trained_model = AutoModelForCausalLM.from_pretrained(
@@ -66,13 +66,15 @@ def interpolate_models(
     for key in state_dict.keys():
         state_dict[key] += trained_state_dict[key] * trained_model_weight
 
-    # save interpolated model
+    # save merged model
     model.save_pretrained(output_model_path, state_dict=state_dict)
 
     # copy tokenizer
     tokenizer = AutoTokenizer.from_pretrained(model_path)
     tokenizer.save_pretrained(output_model_path)
 
+    print(f"Merged model saved at {output_model_path}")
+
     return output_model_path
 
 
@@ -84,7 +86,7 @@ def parse_arguments():
         "--model-path",
         type=str,
         required=True,
-        help="Path to the original model",
+        help="Path to the base model",
     )
     parser.add_argument(
         "--trained-model-path",
````

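The edits above mainly rename "original" to "base", but the surrounding hunks also show what the utility actually does: a plain linear interpolation of the two checkpoints' weights. A minimal sketch of that idea follows, assuming two transformers checkpoints with the same architecture; `blend_state_dicts` and `w` are hypothetical names standing in for the script's `interpolate_models` and `trained_model_weight`:

```python
# Minimal sketch, not the repository's interpolator.py: linearly blend two
# checkpoints of the same architecture, tensor by tensor.
import torch
from transformers import AutoModelForCausalLM


def blend_state_dicts(base_path: str, trained_path: str, w: float = 0.5) -> dict:
    base = AutoModelForCausalLM.from_pretrained(base_path, torch_dtype=torch.bfloat16)
    trained = AutoModelForCausalLM.from_pretrained(trained_path, torch_dtype=torch.bfloat16)
    merged = base.state_dict()
    trained_sd = trained.state_dict()
    for key in merged:
        # merged = (1 - w) * base + w * trained, applied per tensor
        merged[key] = (1 - w) * merged[key] + w * trained_sd[key]
    return merged
```

With w = 0.5 the result is an equal average of the two models, while w = 1.0 reproduces the trained checkpoint unchanged.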
examples/scripts/osft_granite_example.py (1 addition, 1 deletion)

````diff
@@ -120,7 +120,7 @@ def main():
     parser.add_argument('--learning-rate', type=float, default=default_learning_rate,
                         help=f'Learning rate for training (default: {default_learning_rate})')
     parser.add_argument('--unmask-messages', action='store_true', default=False,
-                        help='Unmask messages during training (default: False)')
+                        help='Unmask all non-system messages during training, otherwise only unmasks assistant messages (default: False)')
     parser.add_argument('--batch-size', type=int, default=default_batch_size,
                         help=f'Effective batch size for training (default: {default_batch_size})')
     parser.add_argument('--max-seq-len', type=int, default=default_max_seq_len,
````
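The reworded help text pins down what the flag toggles: which chat turns contribute to the training loss. As an illustrative sketch only (not training_hub's implementation), this kind of masking is commonly done by setting the labels of excluded tokens to -100, which Hugging Face causal-LM losses ignore; `build_labels` and the role names below are hypothetical:

```python
# Illustrative sketch of role-based loss masking, not training_hub's code.
# Tokens from masked roles get label -100, which cross-entropy ignores.
IGNORE_INDEX = -100


def build_labels(token_ids, token_roles, unmask_messages=False):
    """token_roles holds one role per token, e.g. 'system', 'user', 'assistant'."""
    if unmask_messages:
        trainable_roles = {"user", "assistant"}   # everything except system turns
    else:
        trainable_roles = {"assistant"}           # default: assistant replies only
    return [
        tok if role in trainable_roles else IGNORE_INDEX
        for tok, role in zip(token_ids, token_roles)
    ]
```

Under such a scheme the default keeps only assistant tokens in the loss, and --unmask-messages widens it to every non-system turn, matching the new help string.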
