@@ -122,6 +122,10 @@ def dynamically_quantize_per_channel(
     return quant, scales, zero_points


+#########################################################################
+### QuantHandler API definition ###
+
+
 class QuantHandler:
     def __init__(self, mod):
         self.mod = mod
@@ -132,8 +136,15 @@ def create_quantized_state_dict(self) -> Dict:  # "StateDict"
     def convert_for_runtime(self) -> nn.Module:
         pass

+    def quantized_model(self) -> nn.Module:
+        model_updated_state_dict = self.create_quantized_state_dict()
+        self.convert_for_runtime()
+        self.mod.load_state_dict(model_updated_state_dict)
+        return self.mod

-##### Weight-only int8 per-channel quantized code ######
+
+#########################################################################
+### Weight-only int8 per-channel quantized code ###


 def replace_linear_weight_only_int8_per_channel(module, node_type):
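The quantized_model() method added to QuantHandler above bundles the three-step flow, create_quantized_state_dict() followed by convert_for_runtime() and load_state_dict(), into a single call. A minimal usage sketch, assuming a float model with nn.Linear layers is already in scope (the model variable is illustrative, not part of this diff):

# Sketch only: `model` is any float nn.Module; WeightOnlyInt8QuantHandler is defined below in this file.
handler = WeightOnlyInt8QuantHandler(model, device="cpu")
quantized = handler.quantized_model()  # quantize weights, swap in quantized modules, load the new state dict
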
@@ -151,16 +162,17 @@ def replace_linear_weight_only_int8_per_channel(module, node_type):
                 setattr(
                     module,
                     name,
-                    WeightOnlyInt8Linear(child.in_features, child.out_features),
+                    WeightOnlyInt8Linear("cpu", child.in_features, child.out_features),
                 )
         else:
             replace_linear_weight_only_int8_per_channel(child, node_type)


-class WeightOnlyInt8QuantHandler:
+class WeightOnlyInt8QuantHandler(QuantHandler):
     def __init__(
         self,
         mod,
+        device="cpu",
         *,
         node_type: str = "*",
         bitwidth: Optional[int] = None,
@@ -200,7 +212,7 @@ def create_quantized_state_dict(self) -> Dict:
                     )
                 ):
                     print(
-                        f"quantize {self.node_type} {fqn, mod} with groupsize {self.group_size}, bitwidth {self.bitwidth}"
+                        f"quantize {self.node_type} {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
                     )

                     # print(f"initial weight shape {mod.weight.shape}")
@@ -217,7 +229,7 @@ def create_quantized_state_dict(self) -> Dict:
                     )

                     cur_state_dict[f"{fqn}.weight"] = weight
-                    # squeeze makes groupsize=rowsize unidimensional
+                    # squeeze makes group_size=rowsize unidimensional
                     cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)

         return cur_state_dict
@@ -241,10 +253,10 @@ class WeightOnlyInt8Linear(torch.nn.Module):

     def __init__(
         self,
+        device,
         in_features: int,
         out_features: int,
         bias: bool = True,
-        device=None,
         dtype=None,
     ) -> None:
         super().__init__()
@@ -260,11 +272,12 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         # return F.linear(input, self.weight.to(dtype=input.dtype)) * se...


-##### embedding table quantization ######
+#########################################################################
+##### embedding table quantization ######


 def replace_embedding_weight_only_grouped_int8_per_channel(
-    module, bitwidth: int = 8, group_size: Optional[int] = None
+    module, device, bitwidth: int = 8, group_size: Optional[int] = None, packed=False
 ):
     for name, child in module.named_children():
         # print(f"name: {name}")
@@ -275,25 +288,41 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
                 module,
                 name,
                 QuantizedGroupEmbedding(
+                    device=device,
                     vocab_size=child.weight.shape[0],
                     embedding_dim=child.weight.shape[1],
                     group_size=group_size,
+                    packed=packed,
                 ),
             )
         else:
             replace_embedding_weight_only_grouped_int8_per_channel(
-                child, bitwidth, group_size
+                child, device, bitwidth, group_size, packed
             )


-class EmbeddingOnlyInt8QuantHandler:
-    def __init__(self, mod, *, bitwidth: int = 8, group_size: Optional[int] = None):
+class EmbeddingOnlyInt8QuantHandler(QuantHandler):
+    def __init__(
+        self,
+        mod,
+        device="cpu",
+        *,
+        bitwidth: int = 8,
+        group_size: Optional[int] = None,
+        packed=False,
+    ):
+        if isinstance(packed, str):
+            packed = packed == "True"
         self.mod = mod
+        self.device = device
         self.group_size = group_size
         self.bitwidth = bitwidth
+        self.packed = packed
+        if (bitwidth != 4) and packed:
+            raise RuntimeError("pack only works with bitsize 4")

     @torch.no_grad()
-    def create_quantized_state_dict(self) -> Dict:
+    def create_quantized_state_dict(self, packed=False) -> Dict:
         cur_state_dict = self.mod.state_dict()

         if self.bitwidth == 4:
@@ -306,18 +335,14 @@ def create_quantized_state_dict(self) -> Dict:
             raise ValueError(f"Unsupported bitwidth {self.bitwidth}")

         for fqn, mod in self.mod.named_modules():
-            if (
-                isinstance(mod, nn.Embedding)
-                or isinstance(mod, fsEmbedding)
-                or isinstance(mod, fsStandardEmbedding)
-            ):
+            if isinstance(mod, nn.Embedding):
                 # print("****")
                 # print(f"Embedding identified: {fqn, mod}")
                 # print(f"weights size: {mod.weight.size()}")
                 # print(f"quantize {fqn}...")

                 print(
-                    f"quantize {fqn, mod} with groupsize {self.group_size}, bitwidth {self.bitwidth}"
+                    f"quantize {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
                 )
                 weight, scales, _ = dynamically_quantize_per_channel(
                     mod.weight.float(),
@@ -328,21 +353,36 @@ def create_quantized_state_dict(self) -> Dict:
                     scales_dtype=mod.weight.dtype,
                 )

+                if packed:
+                    if weight.shape[-1] % 2 != 0:
+                        raise RuntimeError("automatic padding not implemented yet")
+
+                    weight_range_shifted = weight.add(8).view(torch.uint8)
+                    weight_view = weight_range_shifted.view(
+                        weight.shape[0], weight.shape[1] // 2, 2
+                    )
+                    weight_even = weight_view[:, :, 0] * 16  # left shift 4
+                    weight_odd = weight_view[:, :, 1]
+                    weight_packed = weight_even + weight_odd
+                    weight = weight_packed
+
+                weight = weight.to(device=self.device)
+                scales = scales.to(device=self.device)
                 # Update state dict
                 cur_state_dict[f"{fqn}.weight"] = weight
-                # squeeze makes groupsize=rowsize unidimensional
+                # squeeze makes group_size=rowsize unidimensional
                 cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)

         return cur_state_dict

     def convert_for_runtime(self) -> nn.Module:
         replace_embedding_weight_only_grouped_int8_per_channel(
-            self.mod, self.bitwidth, self.group_size
+            self.mod, self.device, self.bitwidth, self.group_size, self.packed
         )
         return self.mod

     def quantized_model(self) -> nn.Module:
-        model_updated_state_dict = self.create_quantized_state_dict()
+        model_updated_state_dict = self.create_quantized_state_dict(self.packed)
         self.convert_for_runtime()
         self.mod.load_state_dict(model_updated_state_dict)
         return self.mod
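The packed branch above stores two 4-bit values per byte: int8 values quantized into [-8, 7] are shifted into [0, 15], the even-indexed column becomes the high nibble and the odd-indexed column the low nibble. A self-contained round-trip sketch of that arithmetic (standalone illustration, not part of the diff):

import torch

# int8 values in [-8, 7], with an even number of columns as the diff requires
w = torch.tensor([[-8, 7, 0, -1], [3, -4, 5, 6]], dtype=torch.int8)

# pack, following the same steps as create_quantized_state_dict
shifted = w.add(8).view(torch.uint8)                  # shift [-8, 7] -> [0, 15]
pairs = shifted.view(w.shape[0], w.shape[1] // 2, 2)
packed = pairs[:, :, 0] * 16 + pairs[:, :, 1]         # high nibble = even column

# unpack and undo the shift to recover the original values
high = (packed >> 4).to(torch.int8) - 8
low = (packed & 0x0F).to(torch.int8) - 8
restored = torch.stack([high, low], dim=-1).reshape(w.shape)
assert torch.equal(restored, w)
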
@@ -351,39 +391,53 @@ def quantized_model(self) -> nn.Module:
 class QuantizedGroupEmbedding(torch.nn.Module):
     def __init__(
         self,
+        device,
         vocab_size: int,
         embedding_dim: int,
         group_size: Optional[int] = None,
-        device=None,
         dtype=torch.half,
+        packed=False,
     ) -> None:
         super().__init__()
-        if group_size is None:
+        if group_size is None or group_size == 0:
             group_size = embedding_dim
         self.group_size = group_size
         self.dtype = dtype
-        self.register_buffer(
-            "weight", torch.empty((vocab_size, embedding_dim), dtype=torch.int8)
-        )
+        self.packed = packed
+        if not packed:
+            self.register_buffer(
+                "weight",
+                torch.empty(
+                    (vocab_size, embedding_dim), dtype=torch.int8, device=device
+                ),
+            )
+        else:  # packed
+            self.register_buffer(
+                "weight",
+                torch.empty(
+                    (vocab_size, embedding_dim // 2), dtype=torch.uint8, device=device
+                ),
+            )
         groups_per_row = (embedding_dim + group_size - 1) // group_size
         if groups_per_row > 1:
             self.register_buffer(
-                "scales", torch.ones((vocab_size, groups_per_row), dtype=torch.float16)
+                "scales",
+                torch.ones(
+                    (vocab_size, groups_per_row), dtype=torch.float16, device=device
+                ),
             )
         else:
             self.register_buffer(
-                "scales", torch.ones((vocab_size,), dtype=torch.float16)
+                "scales", torch.ones((vocab_size,), dtype=torch.float16, device=device)
             )

     @torch.no_grad()
     def forward(self, indices: torch.Tensor) -> torch.Tensor:
-        return torch.ops.llama_quantized.DEPRECATED_DO_NOT_USE_embedding_byte.dtype(
-            self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
-        )
-
-
-        # result_weights = self.weight.index_select(0, indices.view(-1))
-        # result_scales = self.scales.index_select(0, indices.view(-1))
-        #
-        # r = result_weights.to(dtype=result_scales.dtype) * result_scales
-        # return r.view(indices.size() + (-1,))
+        if not self.packed:  # 8bit
+            return torch.ops.llama_quantized.DEPRECATED_DO_NOT_USE_embedding_byte.dtype(
+                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+            )
+        else:  # 4bit packed
+            return torch.ops.llama_quantized.embedding_4bit.dtype(
+                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+            )
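forward() now dispatches on self.packed: the existing llama_quantized embedding_byte op handles int8 rows and the new embedding_4bit op handles packed rows, both dequantizing to self.dtype. The commented-out pure-PyTorch reference deleted by this diff can be written as a standalone sketch for the 8-bit, one-group-per-row case (the function name and the unsqueeze on scales are added here for illustration, not part of the diff):

import torch

def embedding_byte_reference(weight, scales, indices):
    # weight: (vocab, dim) int8, scales: (vocab,) float, indices: any integer-shaped tensor
    w = weight.index_select(0, indices.view(-1))
    s = scales.index_select(0, indices.view(-1)).unsqueeze(-1)
    out = w.to(dtype=s.dtype) * s           # dequantize per row
    return out.view(indices.size() + (-1,))  # restore the leading index shape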