@@ -124,6 +124,10 @@ def dynamically_quantize_per_channel(
     return quant, scales, zero_points
 
 
+#########################################################################
+### QuantHandler API definition ###
+
+
 class QuantHandler:
     def __init__(self, mod):
         self.mod = mod
@@ -134,8 +138,15 @@ def create_quantized_state_dict(self) -> Dict: # "StateDict"
     def convert_for_runtime(self) -> nn.Module:
         pass
 
+    def quantized_model(self) -> nn.Module:
+        model_updated_state_dict = self.create_quantized_state_dict()
+        self.convert_for_runtime()
+        self.mod.load_state_dict(model_updated_state_dict)
+        return self.mod
 
-##### Weight-only int8 per-channel quantized code ######
+
+#########################################################################
+### Weight-only int8 per-channel quantized code ###
 
 
 def replace_linear_weight_only_int8_per_channel(module, node_type):
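For orientation, here is a minimal sketch (the `model` variable is an assumption for illustration, not part of this commit) of how the `QuantHandler` convenience method added above is intended to be driven:

```python
# Hypothetical driver for the QuantHandler API; `model` is any nn.Module you own.
handler = WeightOnlyInt8QuantHandler(model, device="cpu", node_type="*")
quantized = handler.quantized_model()
# quantized_model() chains the three steps defined on the base class:
#   1. create_quantized_state_dict()  -- compute int8 weights and scales
#   2. convert_for_runtime()          -- swap eligible submodules for quantized ones
#   3. load_state_dict(...)           -- install the quantized tensors into the swapped modules
```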
@@ -153,16 +164,17 @@ def replace_linear_weight_only_int8_per_channel(module, node_type):
                 setattr(
                     module,
                     name,
-                    WeightOnlyInt8Linear(child.in_features, child.out_features),
+                    WeightOnlyInt8Linear("cpu", child.in_features, child.out_features),
                 )
         else:
             replace_linear_weight_only_int8_per_channel(child, node_type)
 
 
-class WeightOnlyInt8QuantHandler:
+class WeightOnlyInt8QuantHandler(QuantHandler):
     def __init__(
         self,
         mod,
+        device="cpu",
         *,
         node_type: str = "*",
         bitwidth: Optional[int] = None,
@@ -202,7 +214,7 @@ def create_quantized_state_dict(self) -> Dict:
                     )
                 ):
                     print(
-                        f"quantize {self.node_type} {fqn, mod} with groupsize {self.group_size}, bitwidth {self.bitwidth}"
+                        f"quantize {self.node_type} {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
                     )
 
                     # print(f"initial weight shape {mod.weight.shape}")
@@ -219,7 +231,7 @@ def create_quantized_state_dict(self) -> Dict:
                     )
 
                     cur_state_dict[f"{fqn}.weight"] = weight
-                    # squeeze makes groupsize=rowsize unidimensional
+                    # squeeze makes group_size=rowsize unidimensional
                     cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)
 
         return cur_state_dict
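The `squeeze(dim=-1)` comment above refers to a shape convention: with group_size equal to the row size there is exactly one scale per output row, so the (out_features, 1) scale tensor is stored as a 1-D buffer. A tiny illustration (the size 4096 is made up):

```python
import torch

# One quantization group per row -> scales come back as (out_features, 1);
# squeeze(dim=-1) stores them as a flat (out_features,) buffer.
scales = torch.ones(4096, 1)
assert scales.squeeze(dim=-1).shape == (4096,)
```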
@@ -243,10 +255,10 @@ class WeightOnlyInt8Linear(torch.nn.Module):
 
     def __init__(
         self,
+        device,
         in_features: int,
         out_features: int,
         bias: bool = True,
-        device=None,
         dtype=None,
     ) -> None:
         super().__init__()
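Note that `device` becomes the first positional parameter of `WeightOnlyInt8Linear` instead of a trailing keyword defaulting to `None`; that is why the call site earlier in this diff now passes `"cpu"` explicitly. A sketch of the updated call (the feature sizes are made up):

```python
# Before this commit: WeightOnlyInt8Linear(in_features, out_features)
# After this commit: the device comes first.
layer = WeightOnlyInt8Linear("cpu", in_features=4096, out_features=4096)
```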
@@ -262,11 +274,12 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         # return F.linear(input, self.weight.to(dtype=input.dtype)) * se...
 
 
-##### embedding table quantization ######
+#########################################################################
+##### embedding table quantization ######
 
 
 def replace_embedding_weight_only_grouped_int8_per_channel(
-    module, bitwidth: int = 8, group_size: Optional[int] = None
+    module, device, bitwidth: int = 8, group_size: Optional[int] = None, packed=False
 ):
     for name, child in module.named_children():
         # print(f"name: {name}")
@@ -277,25 +290,41 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
                 module,
                 name,
                 QuantizedGroupEmbedding(
+                    device=device,
                     vocab_size=child.weight.shape[0],
                     embedding_dim=child.weight.shape[1],
                     group_size=group_size,
+                    packed=packed,
                 ),
             )
         else:
             replace_embedding_weight_only_grouped_int8_per_channel(
-                child, bitwidth, group_size
+                child, device, bitwidth, group_size, packed
             )
 
 
-class EmbeddingOnlyInt8QuantHandler:
-    def __init__(self, mod, *, bitwidth: int = 8, group_size: Optional[int] = None):
+class EmbeddingQuantHandler(QuantHandler):
+    def __init__(
+        self,
+        mod,
+        device="cpu",
+        *,
+        bitwidth: int = 8,
+        group_size: Optional[int] = None,
+        packed=False,
+    ):
+        if isinstance(packed, str):
+            packed = packed == "True"
         self.mod = mod
+        self.device = device
         self.group_size = group_size
         self.bitwidth = bitwidth
+        self.packed = packed
+        if (bitwidth != 4) and packed:
+            raise RuntimeError("pack only works with bitsize 4")
 
     @torch.no_grad()
-    def create_quantized_state_dict(self) -> Dict:
+    def create_quantized_state_dict(self, packed=False) -> Dict:
         cur_state_dict = self.mod.state_dict()
 
         if self.bitwidth == 4:
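A short usage sketch for the renamed handler (the `model` variable and the CLI-style string value for `packed` are assumptions for illustration):

```python
# 4-bit grouped embedding quantization with nibble packing.
handler = EmbeddingQuantHandler(model, device="cpu", bitwidth=4, group_size=32, packed=True)
model = handler.quantized_model()

# `packed` may also arrive as the string "True" (e.g. from a command-line flag);
# __init__ coerces it with `packed = packed == "True"`.
# Requesting packed=True with any bitwidth other than 4 raises a RuntimeError.
```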
@@ -308,18 +337,14 @@ def create_quantized_state_dict(self) -> Dict:
             raise ValueError(f"Unsupported bitwidth {self.bitwidth}")
 
         for fqn, mod in self.mod.named_modules():
-            if (
-                isinstance(mod, nn.Embedding)
-                or isinstance(mod, fsEmbedding)
-                or isinstance(mod, fsStandardEmbedding)
-            ):
+            if isinstance(mod, nn.Embedding):
                 # print("****")
                 # print(f"Embedding identified: {fqn, mod}")
                 # print(f"weights size: {mod.weight.size()}")
                 # print(f"quantize {fqn}...")
 
                 print(
-                    f"quantize {fqn, mod} with groupsize {self.group_size}, bitwidth {self.bitwidth}"
+                    f"quantize {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
                 )
                 weight, scales, _ = dynamically_quantize_per_channel(
                     mod.weight.float(),
@@ -330,21 +355,36 @@ def create_quantized_state_dict(self) -> Dict:
                     scales_dtype=mod.weight.dtype,
                 )
 
+                if packed:
+                    if weight.shape[-1] % 2 != 0:
+                        raise RuntimeError("automatic padding not implemented yet")
+
+                    weight_range_shifted = weight.add(8).view(torch.uint8)
+                    weight_view = weight_range_shifted.view(
+                        weight.shape[0], weight.shape[1] // 2, 2
+                    )
+                    weight_even = weight_view[:, :, 0] * 16  # left shift 4
+                    weight_odd = weight_view[:, :, 1]
+                    weight_packed = weight_even + weight_odd
+                    weight = weight_packed
+
+                weight = weight.to(device=self.device)
+                scales = scales.to(device=self.device)
                 # Update state dict
                 cur_state_dict[f"{fqn}.weight"] = weight
-                # squeeze makes groupsize=rowsize unidimensional
+                # squeeze makes group_size=rowsize unidimensional
                 cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)
 
         return cur_state_dict
 
     def convert_for_runtime(self) -> nn.Module:
         replace_embedding_weight_only_grouped_int8_per_channel(
-            self.mod, self.bitwidth, self.group_size
+            self.mod, self.device, self.bitwidth, self.group_size, self.packed
         )
         return self.mod
 
     def quantized_model(self) -> nn.Module:
-        model_updated_state_dict = self.create_quantized_state_dict()
+        model_updated_state_dict = self.create_quantized_state_dict(self.packed)
         self.convert_for_runtime()
         self.mod.load_state_dict(model_updated_state_dict)
         return self.mod
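The packing block above shifts each int4 value from [-8, 7] into [0, 15] and stores two adjacent values per byte as (even << 4) | odd. A self-contained sketch of the same arithmetic and its inverse (the unpacking shown here is for illustration only; at runtime it is done by the lowered operator):

```python
import torch

# Toy row of signed 4-bit weights.
w = torch.tensor([[-8, 7, 0, -1]], dtype=torch.int8)

# Pack: shift into [0, 15], then combine column pairs into one uint8 each.
shifted = (w + 8).to(torch.uint8)                    # [[0, 15, 8, 7]]
pairs = shifted.view(w.shape[0], w.shape[1] // 2, 2)
packed = pairs[:, :, 0] * 16 + pairs[:, :, 1]        # [[15, 135]]

# Unpack: split nibbles and undo the +8 shift.
high, low = packed // 16, packed % 16
unpacked = torch.stack([high, low], dim=-1).reshape(w.shape).to(torch.int8) - 8
assert torch.equal(unpacked, w)
```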
@@ -353,39 +393,53 @@ def quantized_model(self) -> nn.Module:
 class QuantizedGroupEmbedding(torch.nn.Module):
     def __init__(
         self,
+        device,
         vocab_size: int,
         embedding_dim: int,
         group_size: Optional[int] = None,
-        device=None,
         dtype=torch.half,
+        packed=False,
     ) -> None:
         super().__init__()
-        if group_size is None:
+        if group_size is None or group_size == 0:
             group_size = embedding_dim
         self.group_size = group_size
         self.dtype = dtype
-        self.register_buffer(
-            "weight", torch.empty((vocab_size, embedding_dim), dtype=torch.int8)
-        )
+        self.packed = packed
+        if not packed:
+            self.register_buffer(
+                "weight",
+                torch.empty(
+                    (vocab_size, embedding_dim), dtype=torch.int8, device=device
+                ),
+            )
+        else:  # packed
+            self.register_buffer(
+                "weight",
+                torch.empty(
+                    (vocab_size, embedding_dim // 2), dtype=torch.uint8, device=device
+                ),
+            )
         groups_per_row = (embedding_dim + group_size - 1) // group_size
         if groups_per_row > 1:
             self.register_buffer(
-                "scales", torch.ones((vocab_size, groups_per_row), dtype=torch.float16)
+                "scales",
+                torch.ones(
+                    (vocab_size, groups_per_row), dtype=torch.float16, device=device
+                ),
             )
         else:
             self.register_buffer(
-                "scales", torch.ones((vocab_size,), dtype=torch.float16)
+                "scales", torch.ones((vocab_size,), dtype=torch.float16, device=device)
             )
 
     @torch.no_grad()
     def forward(self, indices: torch.Tensor) -> torch.Tensor:
-        return torch.ops.quantized_decomposed.embedding_byte.dtype(
-            self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
-        )
-
-
-        # result_weights = self.weight.index_select(0, indices.view(-1))
-        # result_scales = self.scales.index_select(0, indices.view(-1))
-        #
-        # r = result_weights.to(dtype=result_scales.dtype) * result_scales
-        # return r.view(indices.size() + (-1,))
+        if not self.packed:  # 8bit
+            return torch.ops.quantized_decomposed.embedding_byte.dtype(
+                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+            )
+        else:  # 4bit packed
+            return torch.ops.quantized_decomposed.embedding_4bit.dtype(
+                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+            )
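For readers who want a plain-PyTorch mental model of the 8-bit path, the commented-out fallback that this commit deletes can be restated as below. This is illustrative only, covers only the group_size == embedding_dim case, and is not the actual kernel (the runtime dispatches to `torch.ops.quantized_decomposed.embedding_byte`):

```python
import torch

def embedding_byte_reference(weight, scales, indices, dtype=torch.half):
    # weight: (vocab_size, embedding_dim) int8; scales: (vocab_size,) per-row scales.
    rows = weight.index_select(0, indices.view(-1))
    row_scales = scales.index_select(0, indices.view(-1))
    out = rows.to(dtype) * row_scales.unsqueeze(-1).to(dtype)
    return out.view(indices.size() + (-1,))
```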