 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Any, Optional
+from typing import Any, Optional, Tuple

 import torch
 import torch.nn.functional as F

 from torchao.dtypes.utils import is_device
 from torchao.quantization.GPTQ import (
-    Int8DynActInt4WeightLinear,
-    WeightOnlyInt4Linear,
     _check_linear_int4_k,
     _replace_linear_8da4w,
     _replace_linear_int4,
     groupwise_affine_quantize_tensor,
+    Int8DynActInt4WeightLinear,
+    WeightOnlyInt4Linear,
 )
-from torchao.quantization.quant_primitives import (
-    TorchAODType,
-    ZeroPointDomain,
-)
+from torchao.quantization.quant_primitives import TorchAODType, ZeroPointDomain
 from torchao.quantization.unified import TwoStepQuantizer
 from torchao.quantization.utils import get_group_qparams_symmetric
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_6

 from .api import FakeQuantizeConfig
 from .fake_quantizer import FakeQuantizer
-from .utils import (
-    _get_qmin_qmax,
-)
+from .utils import _get_qmin_qmax


 class FakeQuantizedLinear(torch.nn.Linear):
@@ -197,6 +192,36 @@ def convert(
     ) -> torch.nn.Module:
         self._convert_qat_linear_8da4w(model)
         return model
+
+    @staticmethod
+    def quantize_weights(
+        weight: torch.Tensor,
+        group_size: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Helper function to quantize weights
+        """
+        # Load weights and qparams into quantized linear
+        n_bit = 4
+        (qmin, qmax) = _get_qmin_qmax(n_bit)
+        (s, zp) = get_group_qparams_symmetric(
+            weight, n_bit, group_size
+        )
+        from torchao._executorch_ops import (
+            _quantized_decomposed_quantize_per_channel_group_wrapper,
+        )
+
+        q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
+            weight,
+            s,
+            zp,
+            qmin,
+            qmax,
+            torch.int8,
+            group_size,
+        )
+        return (q_weight, s, zp)
+

     def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
         """
@@ -215,28 +240,10 @@ def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
                 )
                 setattr(module, name, quantized_linear)

-                # Load weights and qparams into quantized linear
-                n_bit = 4
-                (qmin, qmax) = _get_qmin_qmax(n_bit)
-                (s, zp) = get_group_qparams_symmetric(
-                    child.weight, n_bit, config.group_size
-                )
-                from torchao._executorch_ops import (
-                    _quantized_decomposed_quantize_per_channel_group_wrapper,
-                )
-
-                q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
-                    child.weight,
-                    s,
-                    zp,
-                    qmin,
-                    qmax,
-                    torch.int8,
-                    config.group_size,
-                )
+                q_weight, scales, zeros = self.quantize_weights(child.weight, config.group_size)
                 quantized_linear.weight = q_weight
-                quantized_linear.scales = s
-                quantized_linear.zeros = zp
+                quantized_linear.scales = scales
+                quantized_linear.zeros = zeros
                 if child.bias is not None:
                     quantized_linear.bias = child.bias
             else:
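
Usage note: the sketch below shows how the extracted quantize_weights staticmethod could be called on its own, independent of the module-swapping convert path. It assumes the enclosing class is Int8DynActInt4WeightQATQuantizer in torchao.quantization.qat.linear (the class name is not visible in these hunks), and the tensor shapes are illustrative only.

    import torch

    from torchao.quantization.qat.linear import Int8DynActInt4WeightQATQuantizer

    # Illustrative float weight of shape [out_features, in_features];
    # group_size must evenly divide in_features.
    weight = torch.randn(8, 32)
    group_size = 32

    # Staticmethod, so no quantizer instance is needed; returns the int8 tensor
    # holding the 4-bit values plus the per-group scales and zero points.
    q_weight, scales, zeros = Int8DynActInt4WeightQATQuantizer.quantize_weights(
        weight, group_size
    )
    print(q_weight.dtype, scales.shape, zeros.shape)

Factoring the quantization math out of _convert_qat_linear_8da4w keeps the convert path a thin loop over child modules and lets other callers reuse the same weight/qparam computation.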