 
 Example usage:
  python interpolator.py \\
-   --model-path ibm-granite/granite-3.3-8b-instruct \\
-   --trained-model-path /path/to/checkpoint
+   --model-path /path/to/base/model \\
+   --trained-model-path /path/to/trained/checkpoint
 """
 # Standard
 import argparse
@@ -47,15 +47,15 @@ def interpolate_models(
     else:
         model_kwargs["torch_dtype"] = torch_dtype
 
-    # load original model
+    # load base model
     model = AutoModelForCausalLM.from_pretrained(
         model_path,
         **model_kwargs,
     )
     state_dict = model.state_dict()
-    original_model_weight = 1 - trained_model_weight
+    base_model_weight = 1 - trained_model_weight
     for key in state_dict.keys():
-        state_dict[key] = state_dict[key] * original_model_weight
+        state_dict[key] = state_dict[key] * base_model_weight
 
     # load trained model
     trained_model = AutoModelForCausalLM.from_pretrained(
@@ -66,13 +66,15 @@ def interpolate_models(
     for key in state_dict.keys():
         state_dict[key] += trained_state_dict[key] * trained_model_weight
 
-    # save interpolated model
+    # save merged model
     model.save_pretrained(output_model_path, state_dict=state_dict)
 
     # copy tokenizer
     tokenizer = AutoTokenizer.from_pretrained(model_path)
     tokenizer.save_pretrained(output_model_path)
 
+    print(f"Merged model saved at {output_model_path}")
+
     return output_model_path
 
 
@@ -84,7 +86,7 @@ def parse_arguments():
8486 "--model-path" ,
8587 type = str ,
8688 required = True ,
87- help = "Path to the original model" ,
89+ help = "Path to the base model" ,
8890 )
8991 parser .add_argument (
9092 "--trained-model-path" ,
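The merge implemented above is plain linear interpolation of the two state dicts, merged = (1 - w) * base + w * trained, applied tensor by tensor; with w = 0 the result is the base model and with w = 1 it is the trained checkpoint, which is why the base weight is computed as 1 - trained_model_weight. A minimal standalone sketch of that arithmetic, using hypothetical toy tensors in place of real state dicts:

```python
# Minimal sketch of the linear interpolation above, on toy tensors.
import torch

trained_model_weight = 0.3  # hypothetical w
base_model_weight = 1 - trained_model_weight

# stand-ins for model.state_dict() and trained_model.state_dict()
base_sd = {"layer.weight": torch.ones(2, 2)}
trained_sd = {"layer.weight": torch.full((2, 2), 3.0)}

merged = {}
for key in base_sd.keys():
    # same two steps as the patch: scale the base copy, then add the scaled trained weights
    merged[key] = base_sd[key] * base_model_weight
    merged[key] += trained_sd[key] * trained_model_weight

print(merged["layer.weight"])  # every entry is 0.7 * 1.0 + 0.3 * 3.0 = 1.6
```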