1919 Float8RowwiseParallel ,
2020)
2121from torchao .quantization import (
22- int4_weight_only ,
23- int8_dynamic_activation_int4_weight ,
22+ Int4WeightOnlyConfig ,
23+ Int8DynamicActivationInt4WeightConfig ,
2424 quantize_ ,
2525)
2626from torchao .quantization .qat import (
2727 Int4WeightOnlyQATQuantizer ,
2828 Int8DynActInt4WeightQATQuantizer ,
2929)
30- from torchao .quantization .qat .linear import (
31- disable_4w_fake_quant ,
32- disable_8da4w_fake_quant ,
33- enable_4w_fake_quant ,
34- enable_8da4w_fake_quant ,
35- )
30+
3631
3732from torchtune .modules .peft .lora import LoRALinear , QATLoRALinear
3833
# Maps a quantizer mode string (e.g. "8da4w-qat") to the callable used to
# re-enable fake quantization on matching modules during QAT.
_quantizer_mode_to_enable_fake_quant = {}
5954
6055
def _enable_linear_fake_quant(
    mod: torch.nn.Module,
    enabled: bool = True,
):
    """
    Helper function to enable fake quantization in `FakeQuantizedLinear`.

    No-op for any other module type. Toggles both the activation and the
    weight fake quantizers when they are present.
    """
    if not isinstance(mod, FakeQuantizedLinear):
        return
    for fake_quantizer in (mod.activation_fake_quantizer, mod.weight_fake_quantizer):
        if fake_quantizer is not None:
            fake_quantizer.enabled = enabled
68+
69+
def _disable_linear_fake_quant(mod: torch.nn.Module):
    """Disable fake quantization in `FakeQuantizedLinear` (see `_enable_linear_fake_quant`)."""
    return _enable_linear_fake_quant(mod, enabled=False)
72+
73+
6174# ========================================
6275# int8 dynamic activations + int4 weight |
6376# ========================================
@@ -73,15 +86,15 @@ def __init__(self, groupsize: int = 256):
7386 self .groupsize = groupsize
7487
7588 def quantize (self , model ):
76- quantize_fn = int8_dynamic_activation_int4_weight (self .groupsize )
89+ quantize_fn = Int8DynamicActivationInt4WeightConfig (self .groupsize )
7790 quantize_ (model , quantize_fn )
7891 return model
7992
8093
# Register the 8da4w (int8 dynamic activations + int4 weights) quantizers and
# their QAT fake-quant toggle callbacks.
_quantizer_to_mode[Int8DynActInt4WeightQuantizer] = "8da4w"
_quantizer_to_mode[Int8DynActInt4WeightQATQuantizer] = "8da4w-qat"
_quantizer_mode_to_disable_fake_quant["8da4w-qat"] = _disable_linear_fake_quant
_quantizer_mode_to_enable_fake_quant["8da4w-qat"] = _enable_linear_fake_quant
8598
8699
87100# ==================
@@ -101,15 +114,15 @@ def __init__(self, groupsize: int = 128, inner_k_tiles: int = 8):
101114
102115 def quantize (self , model ):
103116 layout_type = TensorCoreTiledLayout (self .inner_k_tiles )
104- quantize_fn = int4_weight_only (self .groupsize , layout_type )
117+ quantize_fn = Int4WeightOnlyConfig (self .groupsize , layout_type )
105118 quantize_ (model , quantize_fn )
106119 return model
107120
108121
# Register the 4w (int4 weight-only) quantizers and their QAT fake-quant
# toggle callbacks.
_quantizer_to_mode[Int4WeightOnlyQuantizer] = "4w"
_quantizer_to_mode[Int4WeightOnlyQATQuantizer] = "4w-qat"
_quantizer_mode_to_disable_fake_quant["4w-qat"] = _disable_linear_fake_quant
_quantizer_mode_to_enable_fake_quant["4w-qat"] = _enable_linear_fake_quant
113126
114127
115128# ====================== #
@@ -122,8 +135,8 @@ class Int4WeightOnlyQATQuantizerModuleSwap(Int4WeightOnlyQATQuantizer):
122135 pass
123136
124137
# Backwards-compatible aliases: the module-swap QAT path now shares the
# generic FakeQuantizedLinear enable/disable helpers.
disable_4w_fake_quant_module_swap = _disable_linear_fake_quant
enable_4w_fake_quant_module_swap = _enable_linear_fake_quant
_quantizer_to_mode[Int4WeightOnlyQATQuantizerModuleSwap] = "4w-qat-module-swap"
128141_quantizer_mode_to_disable_fake_quant [
129142 "4w-qat-module-swap"
@@ -138,8 +151,8 @@ class Int8DynActInt4WeightQATQuantizerModuleSwap(Int8DynActInt4WeightQATQuantize
138151 pass
139152
140153
# Backwards-compatible aliases: the module-swap QAT path now shares the
# generic FakeQuantizedLinear enable/disable helpers.
disable_8da4w_fake_quant_module_swap = _disable_linear_fake_quant
enable_8da4w_fake_quant_module_swap = _enable_linear_fake_quant
_quantizer_to_mode[Int8DynActInt4WeightQATQuantizerModuleSwap] = "8da4w-qat-module-swap"
144157_quantizer_mode_to_disable_fake_quant [
145158 "8da4w-qat-module-swap"
0 commit comments