Commit 51011b4

danielpatrickhug authored and HDCharles committed
add testing script for gptq Multitensor
1 parent dadeb85 commit 51011b4

File tree

1 file changed: +343 −0 lines changed


torchao/quantization/test_gptq_mt.py

Lines changed: 343 additions & 0 deletions
@@ -0,0 +1,343 @@
import unittest
import torch
import os
from pathlib import Path
from torchao._models.llama.tokenizer import get_tokenizer
from torchao._models.llama.model import Transformer, prepare_inputs_for_model
from torchao.quantization.GPTQ_MT import Int4WeightOnlyGPTQQuantizer, MultiTensor
import sys
from safetensors.torch import load_file  # Import safetensors loader
import torch.nn.functional as F

from torchao.quantization.utils import _lm_eval_available

if _lm_eval_available:
    import lm_eval

    try:  # lm_eval version 0.4
        from lm_eval.evaluator import evaluate
        from lm_eval.models.huggingface import HFLM as eval_wrapper
        from lm_eval.tasks import get_task_dict
    except:  # lm_eval version 0.3
        from lm_eval import base, evaluator, tasks

        eval_wrapper = base.BaseLM
        get_task_dict = tasks.get_task_dict
        evaluate = evaluator.evaluate

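# InputRecorder masquerades as an lm_eval model: the harness tokenizes the calibration
# tasks and calls _model_call, and instead of scoring anything the recorder stashes the
# prepped inputs so they can be replayed later for GPTQ calibration.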
class InputRecorder(eval_wrapper):
    def __init__(
        self,
        tokenizer,
        calibration_seq_length,
        input_prep_func=None,
        pad_calibration_inputs=False,
        vocab_size=32000,
        pad_token=0,
        device="cpu",
    ):
        try:
            super().__init__()
        except TypeError:
            # lm_eval 0.4.2 removed the default init
            super().__init__("gpt2", device="cpu")

        self.tokenizer = tokenizer
        self._device = torch.device(device)
        self.vocab_size = vocab_size
        self._max_seq_length = calibration_seq_length
        self.calibration_seq_length = calibration_seq_length

        self.input_prep_func = (
            input_prep_func if input_prep_func is not None
            else lambda x: (x,)
        )

        self.pad_calibration_inputs = pad_calibration_inputs
        self.pad_token = pad_token

        self.inputs = []

    @property
    def eot_token_id(self):
        try:
            return self.tokenizer.eos_id()
        except:
            return self.tokenizer.eos_id

    @property
    def max_length(self):
        return self._max_seq_length

    @property
    def max_gen_toks(self):
        return 50

    @property
    def batch_size(self):
        return 1

    @property
    def device(self):
        return self._device

    def tok_encode(self, string: str, **kwargs):
        tokens = self.tokenizer.encode(string)
        if hasattr(self.tokenizer, "bos_id"):
            try:
                tokens = [self.tokenizer.bos_id()] + tokens
            except:
                tokens = [self.tokenizer.bos_id] + tokens
        return tokens

    def tok_decode(self, tokens):
        decoded = self.tokenizer.decode(tokens)
        return decoded

    def add_input(self, args):
        self.inputs.append(args)

    def record_inputs(
        self,
        calibration_tasks,
        calibration_limit,
    ):
        try:
            lm_eval.tasks.initialize_tasks()
        except:
            pass

        task_dict = get_task_dict(calibration_tasks)
        print("Obtaining GPTQ calibration inputs on: ", calibration_tasks)

        evaluate(
            self,
            task_dict,
            limit=calibration_limit,
        )
        return self

    def get_inputs(self):
        return self.inputs

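    # _model_call does the actual recording: sequences that are too short (or that contain
    # the pad token when padding is enabled) are skipped, everything else is padded or
    # truncated to calibration_seq_length, prepped, and stored; random logits are returned
    # so the harness keeps iterating.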
    def _model_call(self, inps):
        inps = inps.squeeze(0)
        T = len(inps)
        if (
            # can't use inputs that are too short when padding disabled
            (T < self.calibration_seq_length and not self.pad_calibration_inputs)
            or
            # can't use inputs that actually use token we use for padding
            (self.pad_calibration_inputs and self.pad_token in inps)
        ):
            # give random output
            return torch.randn(
                (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device
            )

        # pad or truncate to the right size
        if T >= self.calibration_seq_length:
            inps = inps[: self.calibration_seq_length]
        else:
            # right-pad with the pad token up to calibration_seq_length
            inps = F.pad(inps, (0, self.calibration_seq_length - T), value=self.pad_token)

        inps = inps.unsqueeze(0)
        model_in = self.input_prep_func(inps)

        self.add_input(model_in)

        # output `something` with correct shape to keep eval going
        return torch.randn(
            (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device
        )

    def _model_generate(self, context, max_length, eos_token_id):
        raise Exception("unimplemented")

import logging
import time

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

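# TransformerEvalWrapper reuses InputRecorder's lm_eval plumbing, but its _model_call
# actually runs the model and returns real logits, so it can score the quantized model.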
class TransformerEvalWrapper(InputRecorder):
    """
    A wrapper class for GPTFast, providing integration with the lm-evaluation-harness library.
    """
    def __init__(
        self,
        model,
        tokenizer,
        max_seq_length,
        input_prep_func=None,
        device="cuda",
    ):
        super().__init__(tokenizer, None)
        self._model = model
        # self.tokenizer = tokenizer
        self._device = torch.device(device)
        self._max_seq_length = max_seq_length

        # need to take inps and convert to correct input
        # for model
        self.input_prep_func = (
            input_prep_func if input_prep_func is not None
            else lambda x: (x,)
        )

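    # Debug-heavy forward pass: prep the inputs, set up the KV caches for this sequence
    # length, dump the prepped inputs to input.pt for inspection, then run the model.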
    def _model_call(self, inps):
        print("Entering _model_call")
        print(f"Input shape: {inps.shape}")

        input = self.input_prep_func(inps)
        print(f"Processed input shapes: {[x.shape for x in input]}")

        input = [x.to(self._device) for x in input]
        print(f"Inputs moved to device: {self._device}")

        max_seq_length = min(max(inps.size()), self.max_length)
        print(f"Max sequence length: {max_seq_length}")

        print("Setting up caches")
        with torch.device(self._device):
            print(f"Device: {self._device}")
            print(f"Batch size: {self.batch_size}")
            print(f"Max sequence length: {max_seq_length}")
            self._model.setup_caches(self.batch_size, max_seq_length)
        print("Caches set up")

        print("Running model")
        torch.save(input, "input.pt")
        logits = self._model(*input)
        print(f"Model run complete. Logits shape: {logits.shape}")
        return logits

    def _model_generate(self, context, max_length, eos_token_id):
        raise Exception('unimplemented')

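    # run_eval wires the wrapper into lm_eval's evaluate(), with logging around task
    # initialization, task lookup, and the evaluation loop itself.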
    def run_eval(self, tasks, limit):
        logger.info(f"Starting evaluation on tasks: {tasks}")
        logger.info(f"Evaluation limit: {limit}")

        try:
            logger.info("Initializing lm_eval tasks")
            lm_eval.tasks.initialize_tasks()
        except Exception as e:
            logger.warning(f"Failed to initialize tasks: {e}")
            logger.info("Continuing without initialization")

        try:
            logger.info("Getting task dictionary")
            task_dict = get_task_dict(tasks)
            logger.info(f"Task dictionary: {task_dict}")
        except Exception as e:
            logger.error(f"Failed to get task dictionary: {e}")
            raise

        logger.info("Starting evaluation")
        start_time = time.time()

        try:
            with torch.no_grad():
                result = evaluate(
                    self,
                    task_dict,
                    limit=limit,
                    verbosity="DEBUG",
                )
        except Exception as e:
            logger.error(f"Evaluation failed: {e}")
            raise

        end_time = time.time()
        logger.info(f"Evaluation completed in {end_time - start_time:.2f} seconds")

        logger.info("Evaluation results:")
        for task, res in result["results"].items():
            logger.info(f"{task}: {res}")

        return result

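# End-to-end test script: load the bf16 Llama-2-7b-chat checkpoint, record wikitext
# calibration inputs, GPTQ-quantize the model through MultiTensor, rebuild a plain
# state dict, then evaluate the quantized model on wikitext.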
precision = torch.bfloat16
device = "cuda"
print("Loading model")
checkpoint_path = Path("/teamspace/studios/this_studio/ao/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
model = Transformer.from_name(checkpoint_path.parent.name)
checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
model.load_state_dict(checkpoint, assign=True)
model = model.to(dtype=precision, device="cpu")
model.eval()
print("Model loaded")
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), tokenizer_path
tokenizer = get_tokenizer(  # pyre-ignore[28]
    tokenizer_path,
    "Llama-2-7b-chat-hf",
)
print("Tokenizer loaded")

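# GPTQ hyperparameters (blocksize, percdamp, groupsize) and calibration settings
# (one wikitext sample, 100-token sequences, no padding).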
blocksize = 128
percdamp = 0.01
groupsize = 64
calibration_tasks = ["wikitext"]
calibration_limit = 1
calibration_seq_length = 100
input_prep_func = prepare_inputs_for_model
pad_calibration_inputs = False
print("Recording inputs")
inputs = InputRecorder(
    tokenizer,
    calibration_seq_length,
    input_prep_func,
    pad_calibration_inputs,
    model.config.vocab_size,
    device="cpu",
).record_inputs(
    calibration_tasks,
    calibration_limit,
).get_inputs()
print("Inputs recorded")
quantizer = Int4WeightOnlyGPTQQuantizer(
    blocksize,
    percdamp,
    groupsize,
)

model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
multi = [
    MultiTensor([inp for inp, _ in inputs]),
    MultiTensor([inds for _, inds in inputs]),
]
print("Quantizing model")
model = quantizer.quantize(model, multi).cuda()
print("Model quantized")
print("Saving model and fixing state dict")
regular_state_dict = model.state_dict()  # defaultdict(torch.tensor)
for key, value in model.state_dict().items():
    if isinstance(value, MultiTensor):
        regular_state_dict[key] = value.values[0]
    else:
        regular_state_dict[key] = value

model = Transformer.from_name(checkpoint_path.parent.name)
remove = [k for k in regular_state_dict if "kv_cache" in k]
for k in remove:
    del regular_state_dict[k]

model.load_state_dict(regular_state_dict, assign=True)
torch.save(model.state_dict(), "model.pth")
print("Running evaluation")
result = TransformerEvalWrapper(
    model.to("cpu"),
    tokenizer,
    model.config.block_size,
    prepare_inputs_for_model,
    "cpu",
).run_eval(
    ["wikitext"],
    1,
)
print(result)

0 commit comments
