
Commit 9dc2c11

eval script fixes (#414)
Additional script fixes.

Summary:
- int4wo had an issue with device swap after the quantization API (the device needs to be set before quantize).
- int4wo-gptq had an issue with the kv_cache model var not being set correctly (now set in the GPTQ code).
- eval in general had an issue with lm_eval 0.4.2 (updates to tokenizer handling and the eval harness), #404.
- [not eval] autoquant docs were not showing up (added __all__ to autoquant); made the autoquant low-level APIs private.

Test Plan:
python eval.py -q int4wo-64 --compile

wikitext: {'word_perplexity,none': 12.842987954345306, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.611855472207904, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.6887223897240059, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent bc8599f commit 9dc2c11
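
Note: a minimal sketch of the int4wo fix described in the summary, i.e. moving the model to its target device before calling the quantization API. The toy module and group size are placeholders; the quantize / int4_weight_only names follow the eval.py diff below.

import torch
from torchao.quantization.quant_api import quantize, int4_weight_only

# Toy stand-in for the llama checkpoint loaded by eval.py.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024, bias=False)).to(torch.bfloat16)

device = "cuda"
model = model.to(device)                          # set the device first (the fix)
quantize(model, int4_weight_only(group_size=64))  # then quantize in place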

8 files changed: +53 −38 lines changed


test/quantization/test_quant_api.py

−2 lines changed

@@ -238,7 +238,6 @@ def test_8da4w_quantizer(self):
     def test_8da4w_gptq_quantizer(self):
         from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer
         from torchao._models._eval import InputRecorder, TransformerEvalWrapper
-        torchao._models.llama.model.use_index_put_for_kv_cache = True
         # should be similar to TorchCompileDynamicQuantizer
         precision = torch.bfloat16
         device = "cpu"
@@ -338,7 +337,6 @@ def test_8da4w_quantizer_eval(self):
     def test_gptq_quantizer_int4_weight_only(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
         from torchao._models._eval import InputRecorder, TransformerEvalWrapper
-        torchao._models.llama.model.use_index_put_for_kv_cache = True
         precision = torch.bfloat16
         device = "cuda"
         checkpoint_path = Path("../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")

torchao/_models/_eval.py

+16 −11 lines changed

@@ -52,8 +52,13 @@ def __init__(
         pad_token=0,
         device="cpu",
     ):
-        super().__init__()
-        self._tokenizer = tokenizer
+        try:
+            super().__init__()
+        except TypeError:
+            # lm_eval 0.4.2 removed the default init
+            super().__init__("gpt2", device="cpu")
+
+        self.tokenizer = tokenizer
         self._device = torch.device(device)
         self.vocab_size = vocab_size
         self._max_seq_length = calibration_seq_length
@@ -74,9 +79,9 @@ def __init__(
     @property
     def eot_token_id(self):
         try:
-            return self._tokenizer.eos_id()
+            return self.tokenizer.eos_id()
         except:
-            return self._tokenizer.eos_id
+            return self.tokenizer.eos_id

     @property
     def max_length(self):
@@ -96,16 +101,16 @@ def device(self):

     def tok_encode(self, string: str, **kwargs):
         # TODO: verify this for multi-batch as well
-        tokens = self._tokenizer.encode(string)
-        if hasattr(self._tokenizer, "bos_id"):
+        tokens = self.tokenizer.encode(string)
+        if hasattr(self.tokenizer, "bos_id"):
             try:
-                tokens = [self._tokenizer.bos_id()] + tokens
+                tokens = [self.tokenizer.bos_id()] + tokens
             except:
-                tokens = [self._tokenizer.bos_id] + tokens
+                tokens = [self.tokenizer.bos_id] + tokens
         return tokens

     def tok_decode(self, tokens):
-        decoded = self._tokenizer.decode(tokens)
+        decoded = self.tokenizer.decode(tokens)
         return decoded

     def add_input(self, args):
@@ -185,9 +190,9 @@ def __init__(
         input_prep_func=None,
         device="cuda"
     ):
-        super().__init__(None, None)
+        super().__init__(tokenizer, None)
         self._model = model
-        self._tokenizer = tokenizer
+        # self.tokenizer = tokenizer
         self._device = torch.device(device)
         self._max_seq_length = max_seq_length

torchao/_models/llama/eval.py

+7 −6 lines changed

@@ -13,15 +13,15 @@

 )
 from torchao.quantization.quant_api import (
-    quantize, int4wo, int8wo, int8da_int8w, unwrap_tensor_subclass
+    quantize, int4_weight_only, int8_weight_only, int8_dynamic_activation_int8_weight, unwrap_tensor_subclass

 )
 from torchao._models._eval import TransformerEvalWrapper, InputRecorder

 from tokenizer import get_tokenizer
 import time
 from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
-from model import prepare_inputs_for_model
+from torchao._models.llama.model import prepare_inputs_for_model

 torch._inductor.config.fx_graph_cache = True
 torch._inductor.config.force_fuse_int_mm_with_mul = True
@@ -60,17 +60,18 @@ def run_evaluation(

     if quantization:
         if "int8wo" in quantization:
-            quantize(model, int8wo())
+            quantize(model, int8_weight_only())
         if "int8dq" in quantization:
-            quantize(model, int8da_int8w())
+            quantize(model, int8_dynamic_activation_int8_weight())
         if "int4wo" in quantization and not "gptq" in quantization:
             groupsize=int(quantization.split("-")[-1])
             assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
-            quantize(model, int4wo(groupsize=groupsize))
+            quantize(model.to(device), int4_weight_only(group_size=groupsize))
         if "int4wo" in quantization and "gptq" in quantization:
             groupsize=int(quantization.split("-")[-2])
             assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
             assert precision==torch.bfloat16, f"{quantization} requires precision or bfloat16 but got {precision}"
+            assert "cuda" in device, "int4 gptq quantization only works on cuda"
             inputs = InputRecorder(
                 tokenizer,
                 calibration_seq_length,
@@ -83,7 +84,7 @@
                 calibration_limit,
             ).get_inputs()

-            quantizer = Int4WeightOnlyGPTQQuantizer(groupsize=groupsize)
+            quantizer = Int4WeightOnlyGPTQQuantizer(groupsize=groupsize, device=device)
             model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
             model = quantizer.quantize(model, inputs).to(device)
         else:

torchao/_models/llama/generate.py

+8 −8 lines changed

@@ -35,8 +35,8 @@ def device_sync(device):
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))

-from model import Transformer, prepare_inputs_for_model
-from tokenizer import get_tokenizer
+from torchao._models.llama.model import Transformer, prepare_inputs_for_model
+from torchao._models.llama.tokenizer import get_tokenizer

 def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
     q = torch.empty_like(probs_sort).exponential_(1)
@@ -189,21 +189,21 @@ def main(
     if quantization:
         from torchao.quantization.quant_api import (
             quantize,
-            int8wo,
-            int8da_int8w,
-            int4wo,
+            int8_weight_only,
+            int8_dynamic_activation_int8_weight,
+            int4_weight_only,
             autoquant,
             unwrap_tensor_subclass
         )

         if "int8wo" in quantization:
-            quantize(model, int8wo())
+            quantize(model, int8_weight_only())
         if "int8dq" in quantization:
-            quantize(model, int8da_int8w())
+            quantize(model, int8_dynamic_activation_int8_weight())
         if "int4wo" in quantization:
             groupsize=int(quantization.split("-")[-1])
             assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
-            quantize(model, int4wo(groupsize=groupsize))
+            quantize(model, int4_weight_only(groupsize=groupsize))
         if "autoquant" == quantization:
             model = autoquant(model, manual=True)

torchao/_models/llama/tokenizer.py

+4 lines changed

@@ -30,6 +30,8 @@ class SentencePieceWrapper(TokenizerInterface):
     def __init__(self, model_path):
         super().__init__(model_path)
         self.processor = spm.SentencePieceProcessor(str(model_path))
+        self.bos_token_id = self.bos_id()
+        self.eos_token_id = self.eos_id()

     def encode(self, text):
         return self.processor.EncodeAsIds(text)
@@ -86,6 +88,8 @@ def __init__(self, model_path):
         # BOS / EOS token IDs
         self._bos_id: int = self.special_tokens["<|begin_of_text|>"]
         self._eos_id: int = self.special_tokens["<|end_of_text|>"]
+        self.bos_token_id = self.bos_id()
+        self.eos_token_id = self.eos_id()

     def encode(self, text):
         return self.model.encode(text)

torchao/quantization/GPTQ.py

+7 −4 lines changed

@@ -81,6 +81,9 @@ def __init__(

         # trace model for one input
         one_input = [multi.values[0].cpu() for multi in inputs] # pyre-ignore[16]
+        # needed for GPTQ on the torchao llama model
+        import torchao
+        torchao._models.llama.model.use_index_put_for_kv_cache = True
         exported_model = torch._dynamo.export(
             model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake"
         )(*one_input)
@@ -95,7 +98,7 @@ def __init__(
         self.groupsize = groupsize
         self.inputs = inputs
         self.gptq_done = False
-        self.debug = False
+        self.debug = True

     def configure_quantization_mode(
         self,
@@ -672,9 +675,9 @@ def quantize(
 class Int4WeightOnlyGPTQQuantizer(GPTQQuantizer):
     def __init__(
         self,
-        blocksize,
-        percdamp,
-        groupsize,
+        blocksize=128,
+        percdamp=0.01,
+        groupsize=64,
         inner_k_tiles=8,
         padding_allowed=True,
         device: torch.device = torch.device("cuda"),
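
Note: with the defaults added above, Int4WeightOnlyGPTQQuantizer no longer requires blocksize/percdamp/groupsize to be passed positionally. A minimal construction-only sketch (calibration input recording and the quantize() call are omitted):

import torch
from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer

# Defaults now come from the signature: blocksize=128, percdamp=0.01, groupsize=64.
quantizer = Int4WeightOnlyGPTQQuantizer()
# Or override the group size and target device explicitly, as eval.py now does.
quantizer = Int4WeightOnlyGPTQQuantizer(groupsize=128, device=torch.device("cuda"))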

torchao/quantization/__init__.py

−2 lines changed

@@ -17,8 +17,6 @@
     "swap_conv2d_1x1_to_linear"
     "safe_int_mm",
     "autoquant",
-    "change_linears_to_autoquantizable",
-    "change_autoquantizable_to_quantized",
     "get_scale",
     "SmoothFakeDynQuantMixin",
     "SmoothFakeDynamicallyQuantizedLinear",

torchao/quantization/autoquant.py

+11 −5 lines changed

@@ -17,6 +17,12 @@
 except:
     from torch._inductor.runtime.runtime_utils import do_bench

+__all__ = [
+    "AutoQuantizableLinearWeight",
+    "autoquant",
+]
+
+
 aten = torch.ops.aten

 AUTOQUANT_CACHE = {}
@@ -382,11 +388,11 @@ def from_float(cls, weight):
     AQInt8DynamicallyQuantizedLinearWeight,
 ]

-def change_linears_to_autoquantizable(model, **kwargs):
+def _change_linears_to_autoquantizable(model, **kwargs):
     """
     Converts all linear weight tensors to the
     AutoQuantizableLinearWeight tensor subclass. Expectation is that this is followed
-    by running the model and then calling change_autoquantizable_to_quantized
+    by running the model and then calling _change_autoquantizable_to_quantized
     """
     from torchao.quantization.quant_api import _is_linear
     filter_fn = kwargs.pop("filter_fn", _is_linear)
@@ -401,7 +407,7 @@ def change_linears_to_autoquantizable(model, **kwargs):
         filter_fn if filter_fn is not None else _is_linear,
     )

-def change_autoquantizable_to_quantized(model, **kwargs):
+def _change_autoquantizable_to_quantized(model, **kwargs):
     """
     Converts AutoQuantizableLinearWeight tensor subclasses
     to various quantized/non-quantized tensor subclasses depending
@@ -490,7 +496,7 @@ def autoquant(

     # perform initial swap from linear weights
     # to AutoQuantizableLinearWeight
-    change_linears_to_autoquantizable(
+    _change_linears_to_autoquantizable(
         model,
         filter_fn=filter_fn,
         qtensor_class_list=qtensor_class_list,
@@ -531,7 +537,7 @@ def autoquant_prehook(module, args, kwargs):
     # note the torch.compile wrapper (eval_frame) moves the assignment of any assigned
     # attributes to the inner model that didn't exist before, so we have to call delattr on the inner model
     def finalize_autoquant():
-        change_autoquantizable_to_quantized(
+        _change_autoquantizable_to_quantized(
             real_model,
             **aq_kwargs,
         )
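
Note: with change_linears_to_autoquantizable and change_autoquantizable_to_quantized now private, the supported entry point is autoquant itself. A hedged usage sketch of the manual flow used by generate.py above; the toy model, example input, and the model.finalize_autoquant() call are assumptions based on the manual-mode wiring in this diff, not a verbatim snippet from the repo.

import torch
from torchao.quantization import autoquant

# Toy stand-ins; any nn.Module with Linear layers and a matching input work.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to("cuda").to(torch.bfloat16)
example_input = torch.randn(1, 1024, device="cuda", dtype=torch.bfloat16)

# manual=True defers the final weight swap until finalize_autoquant() is called.
model = autoquant(model, manual=True)

# Run representative inputs so each AutoQuantizableLinearWeight can benchmark
# its candidate quantized kernels.
model(example_input)

# Swap the autoquantizable weights for their chosen quantized form (internally
# this goes through the now-private _change_autoquantizable_to_quantized).
model.finalize_autoquant()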
