
from functools import partial
from pathlib import Path
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
+from executorch.examples.models.llama2.tokenizer.tiktoken import Tokenizer as Tiktoken
+from executorch.examples.models.llama2.tokenizer.tokenizer import (
+    Tokenizer,
+    Tokenizer as SentencePieceTokenizer,
+)
+from hqq.core.quantize import BaseQuantizeConfig, HQQLinear
+import lm_eval  # the bare module is needed for the hendrycks_test expansion in eval()
+from lm_eval.api.model import LM
+from lm_eval.evaluator import evaluate
+from lm_eval.models.huggingface import HFLM as eval_wrapper
+from lm_eval.tasks import get_task_dict

from sentencepiece import SentencePieceProcessor

fsLinear = nn.Linear


+class EagerEvalWrapper(eval_wrapper):
+    """
+    A wrapper class based on GPTFast, providing integration with the lm-evaluation-harness library.
+    """
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
+        max_seq_length: Optional[int] = None,
+        use_kv_cache: bool = False,
+    ):
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        super().__init__(device=device)
+        self._model = model
+        self._tokenizer = tokenizer
+        self._device = torch.device(device)
+        self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
+        self._use_kv_cache = use_kv_cache
+
+    @property
+    def eot_token_id(self):
+        return self._tokenizer.eos_id
+
+    @property
+    def max_length(self):
+        return self._max_seq_length
+
+    @property
+    def max_gen_toks(self):
+        return 50
+
+    @property
+    def batch_size(self):
+        return 1
+
+    @property
+    def device(self):
+        return self._device
+
+    def tok_encode(self, string: str, **kwargs):
+        tokens = self._tokenizer.encode(string, bos=True, eos=False)
+        encoded = torch.tensor(tokens, dtype=torch.int, device=self.device)
+        # encoded is a pytorch tensor, but some internal logic in the
+        # eval harness expects it to be a list instead
+        # TODO: verify this for multi-batch as well
+        encoded = encoded.tolist()
+        return encoded
+
+    def tok_decode(self, tokens):
+        decoded = self._tokenizer.decode(tokens)
+        return decoded
+
+    def _model_call(self, inps):
+        bsz, seq_len = inps.shape
+        if self._use_kv_cache:
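+            # KV-cache path: the wrapped model's forward is expected to accept
+            # (tokens, input_pos). A position tensor covering the whole cache
+            # (0 .. max_seq_length - 1) is passed and the prompt is truncated
+            # to the cache length.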
+            pos_tensor = torch.arange(
+                self._max_seq_length, dtype=torch.int64, device=self.device
+            )
+
+            logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
+            return logits
+        else:
+            logits = self._model(inps)
+            return logits
+
+    def _model_generate(self, context, max_length, eos_token_id):
+        raise NotImplementedError("unimplemented")
+
+
+@torch.no_grad()
+def eval(
+    eval_wrapper: LM,
+    tasks: Optional[list] = None,
+    limit: Optional[int] = None,
+) -> dict:
+    """
+    Evaluates a language model on the specified tasks using the lm-evaluation-harness library.
+    Args:
+        eval_wrapper (LM): A LM wrapper class compatible with lm-evaluation-harness evaluation
+        tasks (Optional[list]): The names of the evaluation tasks to run; defaults to ["wikitext"].
+        limit (Optional[int]): The maximum number of samples to evaluate (None for all available).
+    Returns:
+        eval_results (dict): A dictionary of evaluation results for the specified task(s).
+    """
+    if tasks is None:
+        tasks = ["wikitext"]
+    if "hendrycks_test" in tasks:
+        tasks.remove("hendrycks_test")
+        tasks += list(lm_eval.tasks.hendrycks_test.create_all_tasks().keys())
+    task_dict = get_task_dict(tasks)
+    eval_results = evaluate(
+        eval_wrapper,
+        task_dict,
+        limit=limit,
+    )
+    return eval_results
+
+
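+# run_wikitext_eval below prints the per-task results returned by eval(); for
+# the "wikitext" task these typically include perplexity-style metrics
+# (word_perplexity, byte_perplexity, bits_per_byte), with the exact keys
+# depending on the installed lm-eval version.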
+def run_wikitext_eval(m, tokenizer_path, seq_len):
+    print("run_wikitext_eval calibration...")
+    print("tokenizer_path: ", tokenizer_path)
+    tokenizer = Tokenizer(str(tokenizer_path))
+    eval_wrapper = EagerEvalWrapper(
+        model=m,
+        tokenizer=tokenizer,
+        max_seq_length=seq_len,
+        use_kv_cache=False,
+    )
+    eval_results = eval(
+        eval_wrapper,
+        tasks=["wikitext"],
+        # Small sample for a quick calibration pass; raise it (e.g. to 128)
+        # for a fuller run at the cost of calibration time.
+        limit=5,
+    )
+    for task, res in eval_results["results"].items():
+        print(f"{task}: {res}")
+
+
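+# Running the harness through run_wikitext_eval doubles as calibration: every
+# forward pass pushes real wikitext activations through the FakeQuantize
+# observers inserted by prepare() below, so they learn per-tensor ranges.
+# LinearActFakeQuant is the observer wrapper that makes this possible.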
+class LinearActFakeQuant(torch.nn.Module):
+    def __init__(self, linear):
+        super().__init__()
+        self.linear = linear
+        self.input_activation_fake_quant = torch.quantization.FakeQuantize(
+            observer=torch.quantization.MovingAverageMinMaxObserver,
+            dtype=torch.int32,
+            quant_min=torch.iinfo(torch.uint16).min,
+            quant_max=torch.iinfo(torch.uint16).max,
+        )
+        self.output_activation_fake_quant = torch.quantization.FakeQuantize(
+            observer=torch.quantization.MovingAverageMinMaxObserver,
+            dtype=torch.int32,
+            quant_min=torch.iinfo(torch.uint16).min,
+            quant_max=torch.iinfo(torch.uint16).max,
+        )
+
+    def forward(self, x):
+        x = self.input_activation_fake_quant(x)
+        return self.output_activation_fake_quant(self.linear(x))
+
+
+def get_quant_params(activation_fake_quant):
+    quant_min = activation_fake_quant.quant_min
+    quant_max = activation_fake_quant.quant_max
+    qparams = activation_fake_quant.calculate_qparams()
+    scale = qparams[0]
+    zero_point = qparams[1]
+    return (quant_min, quant_max, scale, zero_point)
+
+
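+# LinearActQuant freezes the observed ranges and applies the standard affine
+# quantize/dequantize round trip around the (already HQQ-quantized) linear:
+#   q     = clamp(round(x / scale + zero_point), quant_min, quant_max)
+#   x_hat = (q - zero_point) * scale
+# e.g. with scale=0.01, zero_point=0: x=1.234 -> q=123 -> x_hat=1.23.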
+class LinearActQuant(torch.nn.Module):
+
+    def __init__(self, linear_fake_quant):
+        super().__init__()
+        self.linear_fake_quant = linear_fake_quant
+        (
+            self.input_quant_min,
+            self.input_quant_max,
+            self.input_scale,
+            self.input_zero_point,
+        ) = get_quant_params(linear_fake_quant.input_activation_fake_quant)
+
+        (
+            self.output_quant_min,
+            self.output_quant_max,
+            self.output_scale,
+            self.output_zero_point,
+        ) = get_quant_params(linear_fake_quant.output_activation_fake_quant)
+
+    def forward(self, x):
+        # Manually quantize the input tensor using observed min and max values
+        q_tensor = torch.round(x / self.input_scale + self.input_zero_point)
+        # Clip to the observed [quant_min, quant_max] range (uint16 bounds here)
+        q_tensor = torch.clamp(q_tensor, self.input_quant_min, self.input_quant_max)
+        # Dequantize to the original scale
+        dequantized_tensor = (q_tensor - self.input_zero_point) * self.input_scale
+
+        linear_output = self.linear_fake_quant.linear(dequantized_tensor)
+
+        # Quantize the linear output tensor
+        q_linear_output = torch.round(
+            linear_output / self.output_scale + self.output_zero_point
+        )
+        q_linear_output = torch.clamp(
+            q_linear_output, self.output_quant_min, self.output_quant_max
+        )
+        # Dequantize the linear output tensor
+        dq_linear_output = (
+            q_linear_output - self.output_zero_point
+        ) * self.output_scale
+
+        return dq_linear_output
+
+
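+# Two-stage post-training flow for the activation side:
+#   prepare(): wrap every HQQLinear in a LinearActFakeQuant so the observers
+#              can record activation ranges during calibration, then
+#   convert(): replace each LinearActFakeQuant with a LinearActQuant, freezing
+#              the observed scale/zero-point into a fixed quant/dequant pair.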
+def _replace_linear_q_act(module: torch.nn.Module, stage: str):
+    for name, child in module.named_children():
+        if stage == "convert":
+            if isinstance(child, LinearActFakeQuant):
+                new_linear = LinearActQuant(child)
+                setattr(module, name, new_linear)
+            else:
+                _replace_linear_q_act(child, stage)
+        elif stage == "prepare":
+            if isinstance(child, HQQLinear):
+                new_linear = LinearActFakeQuant(child)
+                setattr(module, name, new_linear)
+            else:
+                _replace_linear_q_act(child, stage)
+
+
+def replace_linear_q_act(module: torch.nn.Module, stage: str):
+    _replace_linear_q_act(
+        module,
+        stage,
+    )
+
+
+def prepare(model):
+    replace_linear_q_act(model, "prepare")
+
+
+def convert(model):
+    replace_linear_q_act(model, "convert")
+
+
def quantize(
    model: torch.nn.Module,
    qmode: str,
@@ -127,6 +363,65 @@ def quantize(
            group_size,
        )
        model = gptq_quantizer.quantize(model, inputs)
+        return model
+    elif qmode == "16a4w-hqq":
+        print("running 16a4w-hqq")
+        from hqq.core.quantize import BaseQuantizeConfig, HQQLinear
+
+        def _replace_linear_16a4w_hqq(
+            module: torch.nn.Module,
+            quant_config,
+            compute_dtype,
+            del_orig=False,
+        ):
+            for name, child in module.named_children():
+                if isinstance(child, nn.Linear):
+                    new_linear = HQQLinear(
+                        child,
+                        quant_config,
+                        compute_dtype=compute_dtype,
+                        del_orig=True,
+                        device="cpu",
+                    )
+                    setattr(module, name, new_linear)
+                else:
+                    _replace_linear_16a4w_hqq(
+                        child,
+                        quant_config,
+                        compute_dtype,
+                        del_orig=False,
+                    )
+
+        def replace_linear_16a4w_hqq(
+            module: torch.nn.Module,
+            quant_config,
+            compute_dtype,
+            del_orig=False,
+        ):
+            _replace_linear_16a4w_hqq(
+                module,
+                quant_config,
+                compute_dtype,
+                del_orig=False,
+            )
+
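+        # Weight side: HQQLinear replaces every nn.Linear. BaseQuantizeConfig
+        # is used with its default settings, which (as of current hqq releases)
+        # mean 4-bit grouped weight quantization -- the "4w" half of 16a4w.
+        # The activation side ("16a") is handled by prepare/calibrate/convert below.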
+        compute_dtype = torch.float32  # alternatives: torch.float16, torch.bfloat16
+        quant_config = BaseQuantizeConfig(
+            quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False
+        )
+        print("before replace_linear_16a4w_hqq model: ", model)
+        replace_linear_16a4w_hqq(model, quant_config, compute_dtype)
+        print("after replace_linear_16a4w_hqq model: ", model)
+
+        print("model before prepare: ", model)
+        prepare(model)
+        print("model after prepare: ", model)
+
+        # Calibration with wikitext; currently only 5 samples are used, which can be tuned
+        run_wikitext_eval(model, tokenizer_path, 128)
+        print("model after calibrate: ", model)
+        convert(model)
+
        return model
    else:
        raise Exception(f"Unrecognized quantize mode: {qmode}")