@@ -267,7 +267,7 @@ def __init__(self):
267
267
self .custom_quant_annotations : Sequence [Callable ] = []
268
268
self .discard_nodes : Set [str ] = set ()
269
269
270
- self .enable_per_channel_conv_quant : bool = True
270
+ self .use_per_channel_weight_quant_ops : Set [ OpOverload ] = set ()
271
271
# the weight quantized for activation 8 bits and 16 bits
272
272
self .per_channel_weight_dtype : Dict = {
273
273
"8bit_act" : torch .int8 ,
@@ -290,16 +290,13 @@ def _annotate_custom_annotation(self, gm: GraphModule) -> None:
290
290
def _get_quant_config (self , op : str | OpOverload ) -> Optional [QuantizationConfig ]:
291
291
"""
292
292
Priority:
293
- 1. per channel config when enable_per_channel_conv_quant is True
293
+ 1. op is one of use_per_channel_weight_quant_ops
294
294
2. int8 / int16 config
295
295
"""
296
296
if type (op ) == str :
297
297
return
298
298
299
- if self .enable_per_channel_conv_quant and op in [
300
- torch .ops .aten .conv1d .default ,
301
- torch .ops .aten .conv2d .default ,
302
- ]:
299
+ if op in self .use_per_channel_weight_quant_ops :
303
300
if op in self .bit16_quant_ops :
304
301
return get_ptq_per_channel_weight_config (
305
302
torch .uint16 , self .per_channel_weight_dtype ["16bit_act" ]
@@ -316,6 +313,12 @@ def _get_quant_config(self, op: str | OpOverload) -> Optional[QuantizationConfig
316
313
317
314
print (f"No quant config is implemented for op, { op } " )
318
315
316
+ def _update_per_channel_weight_quant_ops (self , ops : Set [OpOverload ], enable : bool ):
317
+ if enable :
318
+ self .use_per_channel_weight_quant_ops .update (ops )
319
+ else :
320
+ self .use_per_channel_weight_quant_ops .difference (ops )
321
+
319
322
def add_16bit_quant_ops (self , ops : Set [OpOverload ]) -> None :
320
323
for op in ops :
321
324
assert (
@@ -368,8 +371,15 @@ def set_per_channel_weight_dtype(
368
371
if weight_dtype_for_16bit_act :
369
372
self .per_channel_weight_dtype ["16bit_act" ] = weight_dtype_for_16bit_act
370
373
371
- def set_per_channel_quant (self , enable : bool ) -> None :
372
- self .enable_per_channel_conv_quant = enable
374
def set_per_channel_conv_quant(self, enable: bool) -> None:
    """Enable or disable per-channel weight quantization for aten conv1d/conv2d."""
    ops_to_toggle = {
        torch.ops.aten.conv1d.default,
        torch.ops.aten.conv2d.default,
    }
    self._update_per_channel_weight_quant_ops(ops_to_toggle, enable)
378
def set_per_channel_linear_quant(self, enable: bool) -> None:
    """Enable or disable per-channel weight quantization for the aten linear op."""
    self._update_per_channel_weight_quant_ops({torch.ops.aten.linear.default}, enable)
373
383
374
384
def transform_for_annotation (self , model : GraphModule ) -> GraphModule :
375
385
model = RemoveClone ()(model ).graph_module
0 commit comments