import unittest
import torch
import os
from pathlib import Path
from torchao._models.llama.tokenizer import get_tokenizer
from torchao._models.llama.model import Transformer, prepare_inputs_for_model
from torchao.quantization.GPTQ_MT import Int4WeightOnlyGPTQQuantizer, MultiTensor
import sys
from safetensors.torch import load_file  # Import safetensors loader
import torch.nn.functional as F

from torchao.quantization.utils import _lm_eval_available

if _lm_eval_available:
    import lm_eval

    try:  # lm_eval version 0.4
        from lm_eval.evaluator import evaluate
        from lm_eval.models.huggingface import HFLM as eval_wrapper
        from lm_eval.tasks import get_task_dict
    except ImportError:  # lm_eval version 0.3
        from lm_eval import base, evaluator, tasks

        eval_wrapper = base.BaseLM
        get_task_dict = tasks.get_task_dict
        evaluate = evaluator.evaluate


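# InputRecorder masquerades as an lm-eval model: evaluate() feeds it tokenized
# calibration text, but _model_call only records the prepared model inputs and
# returns dummy logits, so no real forward pass runs during calibration.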
class InputRecorder(eval_wrapper):
    def __init__(
        self,
        tokenizer,
        calibration_seq_length,
        input_prep_func=None,
        pad_calibration_inputs=False,
        vocab_size=32000,
        pad_token=0,
        device="cpu",
    ):
        try:
            super().__init__()
        except TypeError:
            # lm_eval 0.4.2 removed the default init
            super().__init__("gpt2", device="cpu")

        self.tokenizer = tokenizer
        self._device = torch.device(device)
        self.vocab_size = vocab_size
        self._max_seq_length = calibration_seq_length
        self.calibration_seq_length = calibration_seq_length

        self.input_prep_func = (
            input_prep_func if input_prep_func is not None
            else lambda x: (x,)
        )

        self.pad_calibration_inputs = pad_calibration_inputs
        self.pad_token = pad_token

        self.inputs = []

    @property
    def eot_token_id(self):
        # eos_id/bos_id may be a method or a plain attribute depending on the tokenizer
        try:
            return self.tokenizer.eos_id()
        except:
            return self.tokenizer.eos_id

    @property
    def max_length(self):
        return self._max_seq_length

    @property
    def max_gen_toks(self):
        return 50

    @property
    def batch_size(self):
        return 1

    @property
    def device(self):
        return self._device

    def tok_encode(self, string: str, **kwargs):
        tokens = self.tokenizer.encode(string)
        if hasattr(self.tokenizer, "bos_id"):
            try:
                tokens = [self.tokenizer.bos_id()] + tokens
            except:
                tokens = [self.tokenizer.bos_id] + tokens
        return tokens

    def tok_decode(self, tokens):
        decoded = self.tokenizer.decode(tokens)
        return decoded

    def add_input(self, args):
        self.inputs.append(args)

    def record_inputs(
        self,
        calibration_tasks,
        calibration_limit,
    ):
        try:
            # some lm_eval versions require explicit task initialization
            lm_eval.tasks.initialize_tasks()
        except:
            pass

        task_dict = get_task_dict(calibration_tasks)
        print("Obtaining GPTQ calibration inputs on:", calibration_tasks)

        evaluate(
            self,
            task_dict,
            limit=calibration_limit,
        )
        return self

    def get_inputs(self):
        return self.inputs
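
    # Called by lm-eval for every calibration sample. Instead of running a model,
    # it pads/truncates the tokens to calibration_seq_length, stores the prepared
    # inputs, and returns random logits of the expected shape to keep eval going.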
    def _model_call(self, inps):
        inps = inps.squeeze(0)
        T = len(inps)
        if (
            # can't use inputs that are too short when padding is disabled
            (T < self.calibration_seq_length and not self.pad_calibration_inputs)
            or
            # can't use inputs that actually contain the token we use for padding
            (self.pad_calibration_inputs and self.pad_token in inps)
        ):
            # give random output
            return torch.randn(
                (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device
            )

        # pad or truncate to the right size
        if T >= self.calibration_seq_length:
            inps = inps[: self.calibration_seq_length]
        else:
            # right-pad with the pad token up to the calibration length
            inps = F.pad(
                inps, (0, self.calibration_seq_length - T), value=self.pad_token
            )

        inps = inps.unsqueeze(0)
        model_in = self.input_prep_func(inps)

        self.add_input(model_in)

        # output `something` with the correct shape to keep eval going
        return torch.randn(
            (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device
        )

    def _model_generate(self, context, max_length, eos_token_id):
        raise Exception("unimplemented")

import logging
import time

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TransformerEvalWrapper(InputRecorder):
    """
    A wrapper class for GPTFast, providing integration with the lm-evaluation-harness library.
    """

    def __init__(
        self,
        model,
        tokenizer,
        max_seq_length,
        input_prep_func=None,
        device="cuda",
    ):
        super().__init__(tokenizer, None)
        self._model = model
        # self.tokenizer = tokenizer
        self._device = torch.device(device)
        self._max_seq_length = max_seq_length

        # need to take inps and convert them to the correct input
        # for the model
        self.input_prep_func = (
            input_prep_func if input_prep_func is not None
            else lambda x: (x,)
        )

    def _model_call(self, inps):
        print("Entering _model_call")
        print(f"Input shape: {inps.shape}")

        input = self.input_prep_func(inps)
        print(f"Processed input shapes: {[x.shape for x in input]}")

        input = [x.to(self._device) for x in input]
        print(f"Inputs moved to device: {self._device}")

        max_seq_length = min(max(inps.size()), self.max_length)
        print(f"Max sequence length: {max_seq_length}")

        print("Setting up caches")
        with torch.device(self._device):
            print(f"Device: {self._device}")
            print(f"Batch size: {self.batch_size}")
            print(f"Max sequence length: {max_seq_length}")
            self._model.setup_caches(self.batch_size, max_seq_length)
        print("Caches set up")

        print("Running model")
        torch.save(input, "input.pt")
        logits = self._model(*input)
        print(f"Model run complete. Logits shape: {logits.shape}")
        return logits

    def _model_generate(self, context, max_length, eos_token_id):
        raise Exception("unimplemented")

    def run_eval(self, tasks, limit):
        logger.info(f"Starting evaluation on tasks: {tasks}")
        logger.info(f"Evaluation limit: {limit}")

        try:
            logger.info("Initializing lm_eval tasks")
            lm_eval.tasks.initialize_tasks()
        except Exception as e:
            logger.warning(f"Failed to initialize tasks: {e}")
            logger.info("Continuing without initialization")

        try:
            logger.info("Getting task dictionary")
            task_dict = get_task_dict(tasks)
            logger.info(f"Task dictionary: {task_dict}")
        except Exception as e:
            logger.error(f"Failed to get task dictionary: {e}")
            raise

        logger.info("Starting evaluation")
        start_time = time.time()

        try:
            with torch.no_grad():
                result = evaluate(
                    self,
                    task_dict,
                    limit=limit,
                    verbosity="DEBUG",
                )
        except Exception as e:
            logger.error(f"Evaluation failed: {e}")
            raise

        end_time = time.time()
        logger.info(f"Evaluation completed in {end_time - start_time:.2f} seconds")

        logger.info("Evaluation results:")
        for task, res in result["results"].items():
            logger.info(f"{task}: {res}")

        return result


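# Script: end-to-end GPTQ flow for Llama-2-7b-chat-hf.
# 1) load the bf16 checkpoint, 2) record calibration inputs from wikitext,
# 3) run int4 weight-only GPTQ, 4) reload the quantized state dict into a
# fresh Transformer, 5) evaluate wikitext through TransformerEvalWrapper.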
precision = torch.bfloat16
device = "cuda"
print("Loading model")
checkpoint_path = Path("/teamspace/studios/this_studio/ao/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
model = Transformer.from_name(checkpoint_path.parent.name)
checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
model.load_state_dict(checkpoint, assign=True)
model = model.to(dtype=precision, device="cpu")
model.eval()
print("Model loaded")
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), tokenizer_path
tokenizer = get_tokenizer(  # pyre-ignore[28]
    tokenizer_path,
    "Llama-2-7b-chat-hf",
)
print("Tokenizer loaded")

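# GPTQ calibration settings: `blocksize` is the size of the column block updated
# per GPTQ step, `percdamp` the Hessian dampening fraction, and `groupsize` the
# int4 quantization group size; one wikitext sample of 100 tokens is recorded.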
blocksize = 128
percdamp = 0.01
groupsize = 64
calibration_tasks = ["wikitext"]
calibration_limit = 1
calibration_seq_length = 100
input_prep_func = prepare_inputs_for_model
pad_calibration_inputs = False
print("Recording inputs")
inputs = InputRecorder(
    tokenizer,
    calibration_seq_length,
    input_prep_func,
    pad_calibration_inputs,
    model.config.vocab_size,
    device="cpu",
).record_inputs(
    calibration_tasks,
    calibration_limit,
).get_inputs()
print("Inputs recorded")
quantizer = Int4WeightOnlyGPTQQuantizer(
    blocksize,
    percdamp,
    groupsize,
)

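# Each recorded input is a (tokens, input_pos) pair; MultiTensor bundles the
# per-sample tensors so the quantizer can replay every calibration sample
# through the model.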
model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
multi = [
    MultiTensor([inp for inp, _ in inputs]),
    MultiTensor([inds for _, inds in inputs]),
]
print("Quantizing model")
model = quantizer.quantize(model, multi).cuda()
print("Model quantized")
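
# After quantization some state-dict entries are still MultiTensor wrappers;
# unwrap them to plain tensors before reloading into a fresh Transformer.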
print("Saving model and fixing state dict")
regular_state_dict = model.state_dict()  # defaultdict(torch.tensor)
for key, value in model.state_dict().items():
    if isinstance(value, MultiTensor):
        regular_state_dict[key] = value.values[0]
    else:
        regular_state_dict[key] = value

model = Transformer.from_name(checkpoint_path.parent.name)
remove = [k for k in regular_state_dict if "kv_cache" in k]
for k in remove:
    del regular_state_dict[k]

model.load_state_dict(regular_state_dict, assign=True)
torch.save(model.state_dict(), "model.pth")
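
# Evaluate the quantized model on wikitext (single sample, on CPU).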
print("Running evaluation")
result = TransformerEvalWrapper(
    model.to("cpu"),
    tokenizer,
    model.config.block_size,
    prepare_inputs_for_model,
    "cpu",
).run_eval(
    ["wikitext"],
    1,
)
print(result)