
Commit 98aeee5

Commit message: final tests

1 parent c499891

8 files changed: 37 additions & 29 deletions

test/quantization/test_quant_api.py

Lines changed: 0 additions & 2 deletions

@@ -238,7 +238,6 @@ def test_8da4w_quantizer(self):
     def test_8da4w_gptq_quantizer(self):
         from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer
         from torchao._models._eval import InputRecorder, TransformerEvalWrapper
-        torchao._models.llama.model.use_index_put_for_kv_cache = True
         # should be similar to TorchCompileDynamicQuantizer
         precision = torch.bfloat16
         device = "cpu"
@@ -338,7 +337,6 @@ def test_8da4w_quantizer_eval(self):
     def test_gptq_quantizer_int4wo(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
         from torchao._models._eval import InputRecorder, TransformerEvalWrapper
-        torchao._models.llama.model.use_index_put_for_kv_cache = True
         precision = torch.bfloat16
         device = "cuda"
         checkpoint_path = Path("../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")

torchao/_models/_eval.py

Lines changed: 16 additions & 11 deletions

@@ -52,8 +52,13 @@ def __init__(
         pad_token=0,
         device="cpu",
     ):
-        super().__init__()
-        self._tokenizer = tokenizer
+        try:
+            super().__init__()
+        except TypeError:
+            # lm_eval 0.4.2 removed the default init
+            super().__init__("gpt2", device="cpu")
+
+        self.tokenizer = tokenizer
         self._device = torch.device(device)
         self.vocab_size = vocab_size
         self._max_seq_length = calibration_seq_length
@@ -74,9 +79,9 @@ def __init__(
     @property
     def eot_token_id(self):
         try:
-            return self._tokenizer.eos_id()
+            return self.tokenizer.eos_id()
         except:
-            return self._tokenizer.eos_id
+            return self.tokenizer.eos_id

     @property
     def max_length(self):
@@ -96,16 +101,16 @@ def device(self):

     def tok_encode(self, string: str, **kwargs):
         # TODO: verify this for multi-batch as well
-        tokens = self._tokenizer.encode(string)
-        if hasattr(self._tokenizer, "bos_id"):
+        tokens = self.tokenizer.encode(string)
+        if hasattr(self.tokenizer, "bos_id"):
             try:
-                tokens = [self._tokenizer.bos_id()] + tokens
+                tokens = [self.tokenizer.bos_id()] + tokens
             except:
-                tokens = [self._tokenizer.bos_id] + tokens
+                tokens = [self.tokenizer.bos_id] + tokens
         return tokens

     def tok_decode(self, tokens):
-        decoded = self._tokenizer.decode(tokens)
+        decoded = self.tokenizer.decode(tokens)
         return decoded

     def add_input(self, args):
@@ -185,9 +190,9 @@ def __init__(
         input_prep_func=None,
         device="cuda"
     ):
-        super().__init__(None, None)
+        super().__init__(tokenizer, None)
         self._model = model
-        self._tokenizer = tokenizer
+        # self.tokenizer = tokenizer
         self._device = torch.device(device)
         self._max_seq_length = max_seq_length
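
The two changes above work together: lm_eval 0.4.2 made the base wrapper's constructor require a pretrained-model argument, so the zero-argument super().__init__() now raises TypeError and the code falls back to initializing with a stub "gpt2" model; the tokenizer also moves from the private _tokenizer to the public tokenizer attribute, presumably the name the harness base class reads. A minimal, self-contained sketch of the same fallback pattern (the Base class here is a hypothetical stand-in for lm_eval's wrapper, not the real API):

class Base:
    # stand-in for lm_eval's model wrapper; newer releases require a
    # pretrained model name where older ones accepted no arguments
    def __init__(self, pretrained, device="cpu"):
        self.pretrained = pretrained
        self.device = device

class EvalWrapper(Base):
    def __init__(self, tokenizer, device="cpu"):
        try:
            super().__init__()  # succeeds on the old zero-arg API
        except TypeError:
            # new API: pass a stub model name, as the diff does with "gpt2"
            super().__init__("gpt2", device=device)
        self.tokenizer = tokenizer  # public name the harness reads

w = EvalWrapper(tokenizer=object())
assert w.pretrained == "gpt2" and w.tokenizer is not None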

torchao/_models/llama/eval.py

Lines changed: 1 addition & 1 deletion

@@ -21,7 +21,7 @@
 from tokenizer import get_tokenizer
 import time
 from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
-from model import prepare_inputs_for_model
+from torchao._models.llama.model import prepare_inputs_for_model

 torch._inductor.config.fx_graph_cache = True
 torch._inductor.config.force_fuse_int_mm_with_mul = True

torchao/_models/llama/generate.py

Lines changed: 2 additions & 2 deletions

@@ -36,8 +36,8 @@ def device_sync(device):
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))

-from model import Transformer, prepare_inputs_for_model
-from tokenizer import get_tokenizer
+from torchao._models.llama.model import Transformer, prepare_inputs_for_model
+from torchao._models.llama.tokenizer import get_tokenizer

 def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
     q = torch.empty_like(probs_sort).exponential_(1)
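
Both eval.py and generate.py make the same fix: the bare from model import ... style only resolves when the scripts run from their own directory, while the fully qualified form works wherever the torchao package is importable. A short sketch of what the new imports allow (assuming an editable install of the repo, e.g. pip install -e .):

# the llama helpers now resolve from any working directory
from torchao._models.llama.model import prepare_inputs_for_model
from torchao._models.llama.tokenizer import get_tokenizer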

torchao/_models/llama/tokenizer.py

Lines changed: 4 additions & 0 deletions

@@ -30,6 +30,8 @@ class SentencePieceWrapper(TokenizerInterface):
     def __init__(self, model_path):
         super().__init__(model_path)
         self.processor = spm.SentencePieceProcessor(str(model_path))
+        self.bos_token_id = self.bos_id()
+        self.eos_token_id = self.eos_id()

     def encode(self, text):
         return self.processor.EncodeAsIds(text)
@@ -86,6 +88,8 @@ def __init__(self, model_path):
         # BOS / EOS token IDs
         self._bos_id: int = self.special_tokens["<|begin_of_text|>"]
         self._eos_id: int = self.special_tokens["<|end_of_text|>"]
+        self.bos_token_id = self.bos_id()
+        self.eos_token_id = self.eos_id()

     def encode(self, text):
         return self.model.encode(text)
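
Both wrappers already expose bos_id()/eos_id() methods; the new bos_token_id/eos_token_id attributes mirror them under the Hugging Face-style names, so callers written against either convention work. A toy sketch of the aliasing (the ToyTokenizer class and its IDs are invented for illustration):

class ToyTokenizer:
    # method-style API, as in the wrappers above
    def bos_id(self) -> int:
        return 1

    def eos_id(self) -> int:
        return 2

    def __init__(self):
        # attribute-style aliases in the Hugging Face naming convention
        self.bos_token_id = self.bos_id()
        self.eos_token_id = self.eos_id()

tok = ToyTokenizer()
assert tok.bos_token_id == tok.bos_id()  # both spellings now resolve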

torchao/quantization/GPTQ.py

Lines changed: 3 additions & 6 deletions

@@ -39,11 +39,6 @@
 )
 aten = torch.ops.aten

-# need this to fix the model so it works for GPTQ
-import torchao
-from torchao._models.llama.model import use_index_put_for_kv_cache
-torchao._models.llama.model.use_index_put_for_kv_cache = True
-
 if not _lm_eval_available:
     logging.info("lm_eval is not installed, GPTQ may not be usable")

@@ -86,7 +81,9 @@ def __init__(

         # trace model for one input
         one_input = [multi.values[0].cpu() for multi in inputs] # pyre-ignore[16]
-        model.cpu()(*one_input)
+        # needed for GPTQ on the torchao llama model
+        import torchao
+        torchao._models.llama.model.use_index_put_for_kv_cache = True
         exported_model = torch._dynamo.export(
             model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake"
         )(*one_input)
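
This moves a global side effect from import time to use time: previously, merely importing GPTQ.py flipped use_index_put_for_kv_cache for every user of the llama model, and now the flag is set only when a quantizer actually traces the model. A small sketch of the general pattern (settings and trace_model are hypothetical stand-ins, not torchao APIs):

import types

# stand-in for a module holding a global flag
settings = types.SimpleNamespace(use_index_put_for_kv_cache=False)

def trace_model(model_fn, *inputs):
    # flip the flag at the call site that needs it, not at import time
    settings.use_index_put_for_kv_cache = True
    return model_fn(*inputs)

assert settings.use_index_put_for_kv_cache is False  # untouched until tracing
trace_model(lambda x: x + 1, 1)
assert settings.use_index_put_for_kv_cache is True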

torchao/quantization/__init__.py

Lines changed: 0 additions & 2 deletions

@@ -17,8 +17,6 @@
     "swap_conv2d_1x1_to_linear"
     "safe_int_mm",
     "autoquant",
-    "change_linears_to_autoquantizable",
-    "change_autoquantizable_to_quantized",
     "get_scale",
     "SmoothFakeDynQuantMixin",
     "SmoothFakeDynamicallyQuantizedLinear",

torchao/quantization/autoquant.py

Lines changed: 11 additions & 5 deletions

@@ -17,6 +17,12 @@
 except:
     from torch._inductor.runtime.runtime_utils import do_bench

+__all__ = [
+    "AutoQuantizableLinearWeight",
+    "autoquant",
+]
+
+
 aten = torch.ops.aten

 AUTOQUANT_CACHE = {}
@@ -382,11 +388,11 @@ def from_float(cls, weight):
     # TODO this gets picked in places where it makes perf worse, why?
 ]

-def change_linears_to_autoquantizable(model, **kwargs):
+def _change_linears_to_autoquantizable(model, **kwargs):
     """
     Converts all linear weight tensors to the
     AutoQuantizableLinearWeight tensor subclass. Expectation is that this is followed
-    by running the model and then calling change_autoquantizable_to_quantized
+    by running the model and then calling _change_autoquantizable_to_quantized
     """
     from torchao.quantization.quant_api import _is_linear
     filter_fn = kwargs.pop("filter_fn", _is_linear)
@@ -401,7 +407,7 @@ def change_linears_to_autoquantizable(model, **kwargs):
         filter_fn if filter_fn is not None else _is_linear,
     )

-def change_autoquantizable_to_quantized(model, **kwargs):
+def _change_autoquantizable_to_quantized(model, **kwargs):
     """
     Converts AutoQuantizableLinearWeight tensor subclasses
     to various quantized/non-quantized tensor subclasses depending
@@ -461,7 +467,7 @@ def autoquant(model, example_input=None, qtensor_class_list=DEFAULT_CLASS_LIST,
     # autoquantization
     def autoquant_prehook(module, args, kwargs):
         module.forward_log_only(*args, **kwargs)
-        change_autoquantizable_to_quantized(
+        _change_autoquantizable_to_quantized(
             module,
             **aq_kwargs,
         )
@@ -470,7 +476,7 @@ def autoquant_prehook(module, args, kwargs):

     # perform initial swap from linear weights
     # to AutoQuantizableLinearWeight
-    change_linears_to_autoquantizable(
+    _change_linears_to_autoquantizable(
         model,
         filter_fn=filter_fn,
         qtensor_class_list=qtensor_class_list,
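
With the helpers renamed to _change_linears_to_autoquantizable and _change_autoquantizable_to_quantized, the new __all__ leaves autoquant (plus the tensor subclass) as the public surface. A hedged usage sketch, based on the signature visible in the diff (assumes a CUDA build of torchao; defaults may differ across versions):

import torch
from torchao.quantization import autoquant

model = torch.nn.Sequential(torch.nn.Linear(64, 64)).cuda()
example_input = torch.randn(1, 64, device="cuda")

# swaps linear weights for AutoQuantizableLinearWeight and registers a
# prehook; per the diff, the prehook finalizes quantization via
# _change_autoquantizable_to_quantized when the model first runs
model = autoquant(model, example_input)
out = model(example_input)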
