|
13 | 13 |
|
14 | 14 | # Third Party |
15 | 15 | from transformers import AutoModelForCausalLM, AutoTokenizer |
| 16 | +import torch |
16 | 17 |
|
17 | 18 |
|
18 | 19 | def interpolate_models( |
19 | 20 | model_path: str, |
20 | 21 | trained_model_path: str, |
21 | 22 | trained_model_weight: float = 0.5, |
22 | 23 | output_model_path: str | None = None, |
23 | | - torch_dtype: str | None = "bfloat16", |
| 24 | + torch_dtype: str | torch.dtype | None = "bfloat16", |
24 | 25 | ) -> str: |
25 | 26 | if output_model_path is None: |
26 | 27 | output_model_path = f"{trained_model_path}_interp" |
27 | 28 |
|
28 | | - model_kwargs: dict[str, any] = {} |
29 | | - if torch_dtype is not None and torch_dtype != "auto": |
30 | | - model_kwargs["torch_dtype"] = torch_dtype |
| 29 | + model_kwargs = {} |
| 30 | + if torch_dtype is not None: |
| 31 | + if isinstance(torch_dtype, str): |
| 32 | + _torch_dtype = torch_dtype.lower() |
| 33 | + if _torch_dtype == "auto": |
| 34 | + model_kwargs["torch_dtype"] = "auto" |
| 35 | + else: |
| 36 | + _map = { |
| 37 | + "bfloat16": torch.bfloat16, "bf16": torch.bfloat16, |
| 38 | + "float16": torch.float16, "fp16": torch.float16, |
| 39 | + "float32": torch.float32, "fp32": torch.float32, |
| 40 | + } |
| 41 | + if _torch_dtype not in _map: |
| 42 | + raise ValueError(f"Unsupported --torch-dtype: {torch_dtype}") |
| 43 | + model_kwargs["torch_dtype"] = _map[_torch_dtype] |
| 44 | + else: |
| 45 | + model_kwargs["torch_dtype"] = torch_dtype |
31 | 46 |
|
32 | 47 | # load original model |
33 | 48 | model = AutoModelForCausalLM.from_pretrained( |
@@ -66,31 +81,31 @@ def parse_arguments(): |
66 | 81 | "--model-path", |
67 | 82 | type=str, |
68 | 83 | required=True, |
69 | | - help="path to the original model", |
| 84 | + help="Path to the original model", |
70 | 85 | ) |
71 | 86 | parser.add_argument( |
72 | 87 | "--trained-model-path", |
73 | 88 | type=str, |
74 | 89 | required=True, |
75 | | - help="path to the trained model", |
| 90 | + help="Path to the trained model", |
76 | 91 | ) |
77 | 92 | parser.add_argument( |
78 | 93 | "--trained-model-weight", |
79 | 94 | type=float, |
80 | 95 | default=0.5, |
81 | | - help="weight for the trained model", |
| 96 | + help="Weight for the trained model", |
82 | 97 | ) |
83 | 98 | parser.add_argument( |
84 | 99 | "--output-model-path", |
85 | 100 | type=str, |
86 | 101 | default=None, |
87 | | - help="path to the output model", |
| 102 | + help="Path to the output model", |
88 | 103 | ) |
89 | 104 | parser.add_argument( |
90 | 105 | "--torch-dtype", |
91 | 106 | type=str, |
92 | 107 | default="bfloat16", |
93 | | - help="torch dtype", |
| 108 | + help="Torch dtype", |
94 | 109 | ) |
95 | 110 | args = parser.parse_args() |
96 | 111 | return args |
|
0 commit comments