Skip to content

Commit 210031b

Browse files
committed
Address bot comments
1 parent ec90208 commit 210031b

File tree

1 file changed

+24
-9
lines changed

1 file changed

+24
-9
lines changed

examples/scripts/interpolator.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,36 @@
1313

1414
# Third Party
1515
from transformers import AutoModelForCausalLM, AutoTokenizer
16+
import torch
1617

1718

1819
def interpolate_models(
1920
model_path: str,
2021
trained_model_path: str,
2122
trained_model_weight: float = 0.5,
2223
output_model_path: str | None = None,
23-
torch_dtype: str | None = "bfloat16",
24+
torch_dtype: str | torch.dtype | None = "bfloat16",
2425
) -> str:
2526
if output_model_path is None:
2627
output_model_path = f"{trained_model_path}_interp"
2728

28-
model_kwargs: dict[str, any] = {}
29-
if torch_dtype is not None and torch_dtype != "auto":
30-
model_kwargs["torch_dtype"] = torch_dtype
29+
model_kwargs = {}
30+
if torch_dtype is not None:
31+
if isinstance(torch_dtype, str):
32+
_torch_dtype = torch_dtype.lower()
33+
if _torch_dtype == "auto":
34+
model_kwargs["torch_dtype"] = "auto"
35+
else:
36+
_map = {
37+
"bfloat16": torch.bfloat16, "bf16": torch.bfloat16,
38+
"float16": torch.float16, "fp16": torch.float16,
39+
"float32": torch.float32, "fp32": torch.float32,
40+
}
41+
if _torch_dtype not in _map:
42+
raise ValueError(f"Unsupported --torch-dtype: {torch_dtype}")
43+
model_kwargs["torch_dtype"] = _map[_torch_dtype]
44+
else:
45+
model_kwargs["torch_dtype"] = torch_dtype
3146

3247
# load original model
3348
model = AutoModelForCausalLM.from_pretrained(
@@ -66,31 +81,31 @@ def parse_arguments():
6681
"--model-path",
6782
type=str,
6883
required=True,
69-
help="path to the original model",
84+
help="Path to the original model",
7085
)
7186
parser.add_argument(
7287
"--trained-model-path",
7388
type=str,
7489
required=True,
75-
help="path to the trained model",
90+
help="Path to the trained model",
7691
)
7792
parser.add_argument(
7893
"--trained-model-weight",
7994
type=float,
8095
default=0.5,
81-
help="weight for the trained model",
96+
help="Weight for the trained model",
8297
)
8398
parser.add_argument(
8499
"--output-model-path",
85100
type=str,
86101
default=None,
87-
help="path to the output model",
102+
help="Path to the output model",
88103
)
89104
parser.add_argument(
90105
"--torch-dtype",
91106
type=str,
92107
default="bfloat16",
93-
help="torch dtype",
108+
help="Torch dtype",
94109
)
95110
args = parser.parse_args()
96111
return args

0 commit comments

Comments
 (0)