@@ -141,8 +141,7 @@ def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
141141 )
142142
143143 if (
144- self .name () == "wint4"
145- and _ENABLE_MACHETE
144+ _ENABLE_MACHETE
146145 and envs .FD_USE_MACHETE == "1"
147146 and layer .weight_shape [1 ]
148147 and layer .weight_shape [1 ] % 128 == 0
@@ -219,12 +218,22 @@ def create_weights(self, layer, **extra_weight_attrs):
219218 quant_attrs ,
220219 )
221220 else :
222- # The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
223- weight_scale_shape = [layer .weight_shape [1 ]]
224- layer .weight_shape .reverse ()
225- if self .quant_config .name () == "wint4" :
226- layer .weight_shape [0 ] //= 2
227- layer .weight_dtype = "int8"
221+ if isinstance (self , MacheteWeightOnlyLinearMethod ):
222+ # Using group scale for machete, group size is 128
223+ weight_scale_shape = [(layer .weight_shape [0 ] + 127 ) // 128 , layer .weight_shape [1 ]]
224+ if self .quant_config .name () == "wint4" :
225+ layer .weight_shape [0 ] //= 8
226+ else :
227+ layer .weight_shape [0 ] //= 4
228+ layer .weight_dtype = "int32"
229+ else :
230+ # The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
231+ weight_scale_shape = [layer .weight_shape [1 ]]
232+ layer .weight_shape .reverse ()
233+ if self .quant_config .name () == "wint4" :
234+ layer .weight_shape [0 ] //= 2
235+ layer .weight_dtype = "int8"
236+
228237 layer .weight = layer .create_parameter (
229238 shape = layer .weight_shape ,
230239 dtype = layer .weight_dtype ,
@@ -260,17 +269,30 @@ def create_weights(self, layer, **extra_weight_attrs):
260269 def process_weights_after_loading (self , layer ) -> None :
261270 if not layer .fd_config .load_config .load_choices == "default_v1" :
262271 return
263- quanted_weight_tensor , weight_scale_tensor = weight_quantize (
264- layer .weight ,
265- algo = self .quant_config .algo ,
266- arch = self .quant_config .weight_only_linear_arch ,
267- )
272+ if isinstance (self , MacheteWeightOnlyLinearMethod ):
273+ from fastdeploy .model_executor .layers .quantization .ops import (
274+ machete_quantize_and_pack ,
275+ )
276+
277+ # Using group scale for machete, group size is 128
278+ quanted_weight_tensor , weight_scale_tensor = machete_quantize_and_pack (
279+ w = layer .weight ,
280+ atype = layer ._dtype ,
281+ quant_type = "uint4b8" if self .quant_config .name () == "wint4" else "uint8b128" ,
282+ group_size = 128 ,
283+ )
284+ else :
285+ quanted_weight_tensor , weight_scale_tensor = weight_quantize (
286+ layer .weight ,
287+ algo = self .quant_config .algo ,
288+ arch = self .quant_config .weight_only_linear_arch ,
289+ )
268290
269291 free_tensor (layer .weight )
270292
271293 layer .weight = layer .create_parameter (
272294 shape = quanted_weight_tensor .shape ,
273- dtype = "int8" ,
295+ dtype = "int8" if not isinstance ( self , MacheteWeightOnlyLinearMethod ) else "int32" ,
274296 is_bias = False ,
275297 default_initializer = paddle .nn .initializer .Constant (0 ),
276298 )
@@ -361,32 +383,6 @@ def __init__(
361383 ) -> None :
362384 super ().__init__ (quant_config )
363385
364- def create_weights (self , layer , ** extra_weight_attrs ):
365-
366- assert layer .bias is None , "Machete weight only linear method does not support bias."
367- assert self .quant_config .name () == "wint4" , "Machete weight only linear method only supports wint4."
368-
369- # The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
370- weight_scale_shape = [1 , layer .weight_shape [1 ]]
371-
372- # layer.weight_shape.reverse()
373- if self .quant_config .name () == "wint4" :
374- layer .weight_shape [0 ] //= 8
375- layer .weight_dtype = "int32"
376-
377- layer .weight = layer .create_parameter (
378- shape = layer .weight_shape ,
379- dtype = layer .weight_dtype ,
380- is_bias = False ,
381- default_initializer = paddle .nn .initializer .Constant (0 ),
382- )
383-
384- layer .weight_scale = layer .create_parameter (
385- shape = weight_scale_shape ,
386- dtype = layer ._dtype ,
387- is_bias = False ,
388- )
389-
390386 def process_prequanted_weights (self , layer , state_dict ) -> None :
391387 pass
392388
@@ -395,24 +391,27 @@ def process_loaded_weights(self, layer, weight) -> None:
395391 machete_quantize_and_pack ,
396392 )
397393
394+ # Using group scale for machete, group size is 128
398395 quanted_weight_tensor , weight_scale_tensor = machete_quantize_and_pack (
399396 w = weight ,
400397 atype = layer ._dtype ,
401- quant_type = "uint4b8" ,
398+ quant_type = "uint4b8" if self .quant_config .name () == "wint4" else "uint8b128" ,
399+ group_size = 128 ,
402400 )
403401 layer .weight .set_value (quanted_weight_tensor )
404402 layer .weight_scale .set_value (weight_scale_tensor .astype (paddle .get_default_dtype ()))
405403
406404 def apply (self , layer , x ):
407- assert layer .bias is None , "Machete weight only linear method does not support bias."
408- assert self .quant_config .name () == "wint4" , "Machete weight only linear method only supports wint4."
409405 from fastdeploy .model_executor .layers .quantization .ops import machete_wint_mm
410406
407+ # Using group scale for machete, group size is 128
411408 linear_out = machete_wint_mm (
412409 x ,
413410 w_prepack = layer .weight ,
414411 w_g_s = layer .weight_scale ,
415- weight_dtype = "uint4b8" ,
412+ weight_dtype = "uint4b8" if self .quant_config .name () == "wint4" else "uint8b128" ,
413+ group_size = 128 ,
416414 )
417-
415+ if layer .with_bias :
416+ linear_out = paddle .add (linear_out , layer .bias )
418417 return linear_out