Skip to content

Commit e718c1a

Browse files
committed
add 16a4w_hqq quant mode
Pull Request resolved: #3752 Prerequistie: install hqq following https://github.com/mobiusml/hqq Step 1: use hqq to quantize weight to 4bit Step 2: use static quant to quantize activation to 16bit Currently the graph calibration is too slow, so adding the the quant oberserver to the eager model for faster iteration command: ``` python -m examples.models.llama2.eval_llama -t /data/users/chenlai/models/llama2/tokenizer.model -p /data/users/chenlai/models/llama2/params.json -c /data/users/chenlai/models/llama2/consolidated.00.pth --max_seq_len 129 -qmode 16a4w-hqq --limit 5 2>&1 | tee hqq_16a4w.log ``` ghstack-source-id: 228003732 Differential Revision: [D57849772](https://our.internmc.facebook.com/intern/diff/D57849772/)
1 parent c8df1ab commit e718c1a

File tree

3 files changed

+298
-2
lines changed

3 files changed

+298
-2
lines changed

examples/models/llama2/builder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
3636
from torch.nn.attention import SDPBackend
3737

38+
from examples.portable.utils import export_to_edge, save_pte_program
3839
from ..model_factory import EagerModelFactory
3940

4041
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"

examples/models/llama2/export_llama_lib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def build_args_parser() -> argparse.ArgumentParser:
119119
"--quantization_mode",
120120
type=str,
121121
default=None,
122-
choices=["int8", "8da4w", "8da4w-gptq"],
122+
choices=["int8", "8da4w", "8da4w-gptq", "16a4w-hqq"],
123123
help="type of quantization",
124124
)
125125

examples/models/llama2/source_transformation/quantize.py

Lines changed: 296 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,21 @@
66

77
from functools import partial
88
from pathlib import Path
9-
from typing import Any, Dict, Optional
9+
from typing import Any, Dict, Optional, Union
1010

1111
import torch
1212
import torch.nn as nn
1313
import torch.nn.functional as F
14+
from executorch.examples.models.llama2.tokenizer.tiktoken import Tokenizer as Tiktoken
15+
from executorch.examples.models.llama2.tokenizer.tokenizer import (
16+
Tokenizer,
17+
Tokenizer as SentencePieceTokenizer,
18+
)
19+
from hqq.core.quantize import BaseQuantizeConfig, HQQLinear
20+
from lm_eval.api.model import LM
21+
from lm_eval.evaluator import evaluate
22+
from lm_eval.models.huggingface import HFLM as eval_wrapper
23+
from lm_eval.tasks import get_task_dict
1424

1525
from sentencepiece import SentencePieceProcessor
1626

@@ -33,6 +43,232 @@
3343
fsLinear = nn.Linear
3444

3545

46+
class EagerEvalWrapper(eval_wrapper):
47+
"""
48+
A wrapper class based on GPTFast, providing integration with the lm-evaluation-harness library.
49+
"""
50+
51+
def __init__(
52+
self,
53+
model: torch.nn.Module,
54+
tokenizer: Union[SentencePieceTokenizer, Tiktoken],
55+
max_seq_length: Optional[int] = None,
56+
use_kv_cache: bool = False,
57+
):
58+
device = "cuda" if torch.cuda.is_available() else "cpu"
59+
super().__init__(device=device)
60+
self._model = model
61+
self._tokenizer = tokenizer
62+
self._device = torch.device(device)
63+
self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
64+
self._use_kv_cache = use_kv_cache
65+
66+
@property
67+
def eot_token_id(self):
68+
return self._tokenizer.eos_id
69+
70+
@property
71+
def max_length(self):
72+
return self._max_seq_length
73+
74+
@property
75+
def max_gen_toks(self):
76+
return 50
77+
78+
@property
79+
def batch_size(self):
80+
return 1
81+
82+
@property
83+
def device(self):
84+
return self._device
85+
86+
def tok_encode(self, string: str, **kwargs):
87+
tokens = self._tokenizer.encode(string, bos=True, eos=False)
88+
encoded = torch.tensor(tokens, dtype=torch.int, device=self.device)
89+
# encoded is a pytorch tensor, but some internal logic in the
90+
# eval harness expects it to be a list instead
91+
# TODO: verify this for multi-batch as well
92+
encoded = encoded.tolist()
93+
return encoded
94+
95+
def tok_decode(self, tokens):
96+
decoded = self._tokenizer.decode(tokens)
97+
return decoded
98+
99+
def _model_call(self, inps):
100+
bsz, seq_len = inps.shape
101+
if self._use_kv_cache:
102+
pos_tensor = torch.arange(
103+
self._max_seq_length, dtype=torch.int64, device=self.device
104+
)
105+
106+
logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
107+
return logits
108+
else:
109+
logits = self._model(inps)
110+
return logits
111+
112+
def _model_generate(self, context, max_length, eos_token_id):
113+
raise Exception("unimplemented")
114+
115+
116+
@torch.no_grad()
117+
def eval(
118+
eval_wrapper: LM,
119+
tasks: Optional[list] = None,
120+
limit: Optional[int] = None,
121+
) -> dict:
122+
"""
123+
Evaluates a language model on a specified task using the lm-evaluation-harness library.
124+
Args:
125+
eval_wrapper (LM): A LM wrapper class compatible with lm-evaluation-harness evaluation
126+
task (str): The name of the evaluation task to perform.
127+
limit (Optional[int]): The maximum number of samples to evaluate (None for all available).
128+
Returns:
129+
eval_results (dict): A dictionary of evaluation results for the specified task(s).
130+
"""
131+
if tasks is None:
132+
tasks = ["wikitext"]
133+
if "hendrycks_test" in tasks:
134+
tasks.remove("hendrycks_test")
135+
tasks += list(lm_eval.tasks.hendrycks_test.create_all_tasks().keys())
136+
task_dict = get_task_dict(tasks)
137+
eval_results = evaluate(
138+
eval_wrapper,
139+
task_dict,
140+
limit=limit,
141+
)
142+
return eval_results
143+
144+
145+
def run_wikitext_eval(m, tokenizer_path, seq_len):
146+
print("run_wikitext_eval calibration...")
147+
print("tokenizer_path: ", tokenizer_path)
148+
tokenizer = Tokenizer(str(tokenizer_path))
149+
eval_wrapper = EagerEvalWrapper(
150+
model=m,
151+
tokenizer=tokenizer,
152+
max_seq_length=seq_len,
153+
use_kv_cache=False,
154+
)
155+
eval_results = eval(
156+
eval_wrapper,
157+
tasks=["wikitext"],
158+
# limit=128,
159+
limit=5,
160+
# limit=1,
161+
)
162+
for task, res in eval_results["results"].items():
163+
print(f"{task}: {res}")
164+
165+
166+
class LinearActFakeQuant(torch.nn.Module):
167+
def __init__(self, linear):
168+
super().__init__()
169+
self.linear = linear
170+
self.input_activation_fake_quant = torch.quantization.FakeQuantize(
171+
observer=torch.quantization.MovingAverageMinMaxObserver,
172+
dtype=torch.int32,
173+
quant_min=torch.iinfo(torch.uint16).min,
174+
quant_max=torch.iinfo(torch.uint16).max,
175+
)
176+
self.output_activation_fake_quant = torch.quantization.FakeQuantize(
177+
observer=torch.quantization.MovingAverageMinMaxObserver,
178+
dtype=torch.int32,
179+
quant_min=torch.iinfo(torch.uint16).min,
180+
quant_max=torch.iinfo(torch.uint16).max,
181+
)
182+
183+
def forward(self, x):
184+
x = self.input_activation_fake_quant(x)
185+
return self.output_activation_fake_quant(self.linear(x))
186+
187+
188+
def get_quant_params(activation_fake_quant):
189+
quant_min = activation_fake_quant.quant_min
190+
quant_max = activation_fake_quant.quant_max
191+
qparams = activation_fake_quant.calculate_qparams()
192+
scale = qparams[0]
193+
zero_point = qparams[1]
194+
return (quant_min, quant_max, scale, zero_point)
195+
196+
197+
class LinearActQuant(torch.nn.Module):
198+
199+
def __init__(self, linear_fake_quant):
200+
super().__init__()
201+
self.linear_fake_quant = linear_fake_quant
202+
(
203+
self.input_quant_min,
204+
self.input_quant_max,
205+
self.input_scale,
206+
self.input_zero_point,
207+
) = get_quant_params(linear_fake_quant.input_activation_fake_quant)
208+
209+
(
210+
self.output_quant_min,
211+
self.output_quant_max,
212+
self.output_scale,
213+
self.output_zero_point,
214+
) = get_quant_params(linear_fake_quant.output_activation_fake_quant)
215+
216+
def forward(self, x):
217+
# Manually quantize the input tensor using observed min and max values
218+
q_tensor = torch.round(x / self.input_scale + self.input_zero_point)
219+
# Clip to ensure within the range [0, 255]
220+
q_tensor = torch.clamp(q_tensor, self.input_quant_min, self.input_quant_max)
221+
# Dequantize to the original scale
222+
dequantized_tensor = (q_tensor - self.input_zero_point) * self.input_scale
223+
224+
linear_output = self.linear_fake_quant.linear(dequantized_tensor)
225+
226+
# # Quantize the linear output tensor
227+
q_linear_output = torch.round(
228+
linear_output / self.output_scale + self.output_zero_point
229+
)
230+
q_linear_output = torch.clamp(
231+
q_linear_output, self.output_quant_min, self.output_quant_max
232+
)
233+
# Dequantize the linear output tensor
234+
dq_linear_output = (
235+
q_linear_output - self.output_zero_point
236+
) * self.output_scale
237+
238+
return dq_linear_output
239+
240+
241+
def _replace_linear_q_act(module: torch.nn.Module, stage: str):
242+
for name, child in module.named_children():
243+
if stage == "convert":
244+
if isinstance(child, LinearActFakeQuant):
245+
new_linear = LinearActQuant(child)
246+
setattr(module, name, new_linear)
247+
else:
248+
_replace_linear_q_act(child, stage)
249+
elif stage == "prepare":
250+
if isinstance(child, HQQLinear):
251+
new_linear = LinearActFakeQuant(child)
252+
setattr(module, name, new_linear)
253+
else:
254+
_replace_linear_q_act(child, stage)
255+
256+
257+
def replace_linear_q_act(module: torch.nn.Module, stage: str):
258+
_replace_linear_q_act(
259+
module,
260+
stage,
261+
)
262+
263+
264+
def prepare(model):
265+
replace_linear_q_act(model, "prepare")
266+
267+
268+
def convert(model):
269+
replace_linear_q_act(model, "convert")
270+
271+
36272
def quantize(
37273
model: torch.nn.Module,
38274
qmode: str,
@@ -127,6 +363,65 @@ def quantize(
127363
group_size,
128364
)
129365
model = gptq_quantizer.quantize(model, inputs)
366+
return model
367+
elif qmode == "16a4w-hqq":
368+
print("running 16a4w-hqq")
369+
from hqq.core.quantize import BaseQuantizeConfig, HQQLinear
370+
371+
def _replace_linear_16a4w_hqq(
372+
module: torch.nn.Module,
373+
quant_config,
374+
compute_dtype,
375+
del_orig=False,
376+
):
377+
for name, child in module.named_children():
378+
if isinstance(child, nn.Linear):
379+
new_linear = HQQLinear(
380+
child,
381+
quant_config,
382+
compute_dtype=compute_dtype,
383+
del_orig=True,
384+
device="cpu",
385+
)
386+
setattr(module, name, new_linear)
387+
else:
388+
_replace_linear_16a4w_hqq(
389+
child,
390+
quant_config,
391+
compute_dtype,
392+
del_orig=False,
393+
)
394+
395+
def replace_linear_16a4w_hqq(
396+
module: torch.nn.Module,
397+
quant_config,
398+
compute_dtype,
399+
del_orig=False,
400+
):
401+
_replace_linear_16a4w_hqq(
402+
module,
403+
quant_config,
404+
compute_dtype,
405+
del_orig=False,
406+
)
407+
408+
compute_dtype = torch.float32 # torch.bfloat16 #[torch.float16, torch.bfloat16]
409+
quant_config = BaseQuantizeConfig(
410+
quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False
411+
)
412+
print("before replace_linear_16a4w_hqq model: ", model)
413+
replace_linear_16a4w_hqq(model, quant_config, compute_dtype)
414+
print("after replace_linear_16a4w_hqq model: ", model)
415+
416+
print("model before prepare: ", model)
417+
prepare(model)
418+
print("model after prepare: ", model)
419+
420+
# Calibration with wikitext, currently only use 5 samples and can be fine tuned
421+
run_wikitext_eval(model, tokenizer_path, 128)
422+
print("model after calibrate: ", model)
423+
convert(model)
424+
130425
return model
131426
else:
132427
raise Exception(f"Unrecognized quantize mode: {qmode}")

0 commit comments

Comments
 (0)