
Commit d953902

navsud authored and facebook-github-bot committed
create staticmethod for quantizing weights of QATLinear and QATEmbedding
Summary: For saving the quantized weights, we have been using ad hoc notebooks with copy-pasted code from the convert method. This has been a source of numerical discrepancies. To avoid this, this diff separates the weight quantization logic into standalone staticmethods so that it can be reused.

Reviewed By: jerryzh168

Differential Revision: D73201409
1 parent: 34421b1

2 files changed: 69 additions, 51 deletions

torchao/quantization/qat/embedding.py

Lines changed: 31 additions & 20 deletions
@@ -4,7 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Optional
+from typing import Any, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
@@ -15,9 +15,7 @@
 
 from .api import FakeQuantizeConfig
 from .fake_quantizer import FakeQuantizer
-from .utils import (
-    _get_qmin_qmax,
-)
+from .utils import _get_qmin_qmax
 
 
 class FakeQuantizedEmbedding(torch.nn.Embedding):
@@ -196,15 +194,40 @@ def convert(
         """
         self._convert_helper(model)
         return model
+
+    @staticmethod
+    def quantize_weights(
+        weight: torch.Tensor,
+        bit_width: int,
+        group_size: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Helper function to quantize weights
+        """
+        (qmin, qmax) = _get_qmin_qmax(bit_width)
+        (s, zp) = get_group_qparams_symmetric(
+            weight, bit_width, group_size
+        )
+        from torchao._executorch_ops import (
+            _quantized_decomposed_quantize_per_channel_group_wrapper,
+        )
+        q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
+            weight,
+            s,
+            zp,
+            qmin,
+            qmax,
+            torch.int8,
+            group_size,
+        )
+        return (q_weight, s, zp)
+
 
     def _convert_helper(self, module: torch.nn.Module):
         """
         Helper function to recursively swap `Int4WeightOnlyQATEmbedding`
         modules with `Int4WeightOnlyEmbedding`
         """
-        from torchao._executorch_ops import (
-            _quantized_decomposed_quantize_per_channel_group_wrapper,
-        )
 
         for name, child in module.named_children():
             if isinstance(child, Int4WeightOnlyQATEmbedding):
@@ -230,20 +253,8 @@ def _convert_helper(self, module: torch.nn.Module):
                 )
                 setattr(module, name, quantized_embedding)
 
+                q_weight, s, zp = self.quantize_weights(child.weight, self.bit_width, group_size)
                 # Load weights and qparams into quantized embedding
-                (qmin, qmax) = _get_qmin_qmax(self.bit_width)
-                (s, zp) = get_group_qparams_symmetric(
-                    child.weight, self.bit_width, group_size
-                )
-                q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
-                    child.weight,
-                    s,
-                    zp,
-                    qmin,
-                    qmax,
-                    torch.int8,
-                    group_size,
-                )
                 quantized_embedding.weight = q_weight
                 quantized_embedding.scale = s.to(scale_precision)
                 quantized_embedding.zero_point = zp.to(zero_point_precision)
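For reference, a minimal sketch of how the new embedding-side staticmethod can be called directly, replacing the ad hoc notebook code mentioned in the summary. The quantizer class name Int4WeightOnlyEmbeddingQATQuantizer and the example tensor shape are assumptions for illustration; the diff above only defines the method itself.

import torch
# Assumed host class of the new staticmethod; the diff shows only the method body.
from torchao.quantization.qat.embedding import Int4WeightOnlyEmbeddingQATQuantizer

# Hypothetical embedding table: 128 rows of dimension 256, 4-bit groups of 32
weight = torch.randn(128, 256)

# Reuse the same quantization path that convert() now uses internally
q_weight, scale, zero_point = Int4WeightOnlyEmbeddingQATQuantizer.quantize_weights(
    weight, bit_width=4, group_size=32
)
# q_weight is an int8 tensor holding 4-bit values in [qmin, qmax];
# scale and zero_point are the per-group qparams assigned to the quantized embedding.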

torchao/quantization/qat/linear.py

Lines changed: 38 additions & 31 deletions
@@ -4,33 +4,28 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Optional
+from typing import Any, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
 
 from torchao.dtypes.utils import is_device
 from torchao.quantization.GPTQ import (
-    Int8DynActInt4WeightLinear,
-    WeightOnlyInt4Linear,
     _check_linear_int4_k,
     _replace_linear_8da4w,
     _replace_linear_int4,
     groupwise_affine_quantize_tensor,
+    Int8DynActInt4WeightLinear,
+    WeightOnlyInt4Linear,
 )
-from torchao.quantization.quant_primitives import (
-    TorchAODType,
-    ZeroPointDomain,
-)
+from torchao.quantization.quant_primitives import TorchAODType, ZeroPointDomain
 from torchao.quantization.unified import TwoStepQuantizer
 from torchao.quantization.utils import get_group_qparams_symmetric
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_6
 
 from .api import FakeQuantizeConfig
 from .fake_quantizer import FakeQuantizer
-from .utils import (
-    _get_qmin_qmax,
-)
+from .utils import _get_qmin_qmax
 
 
 class FakeQuantizedLinear(torch.nn.Linear):
@@ -197,6 +192,36 @@ def convert(
     ) -> torch.nn.Module:
         self._convert_qat_linear_8da4w(model)
         return model
+
+    @staticmethod
+    def quantize_weights(
+        weight: torch.Tensor,
+        group_size: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Helper function to quantize weights
+        """
+        # Load weights and qparams into quantized linear
+        n_bit = 4
+        (qmin, qmax) = _get_qmin_qmax(n_bit)
+        (s, zp) = get_group_qparams_symmetric(
+            weight, n_bit, group_size
+        )
+        from torchao._executorch_ops import (
+            _quantized_decomposed_quantize_per_channel_group_wrapper,
+        )
+
+        q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
+            weight,
+            s,
+            zp,
+            qmin,
+            qmax,
+            torch.int8,
+            group_size,
+        )
+        return (q_weight, s, zp)
+
 
     def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
         """
@@ -215,28 +240,10 @@ def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
                 )
                 setattr(module, name, quantized_linear)
 
-                # Load weights and qparams into quantized linear
-                n_bit = 4
-                (qmin, qmax) = _get_qmin_qmax(n_bit)
-                (s, zp) = get_group_qparams_symmetric(
-                    child.weight, n_bit, config.group_size
-                )
-                from torchao._executorch_ops import (
-                    _quantized_decomposed_quantize_per_channel_group_wrapper,
-                )
-
-                q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
-                    child.weight,
-                    s,
-                    zp,
-                    qmin,
-                    qmax,
-                    torch.int8,
-                    config.group_size,
-                )
+                q_weight, scales, zeros = self.quantize_weights(child.weight, config.group_size)
                 quantized_linear.weight = q_weight
-                quantized_linear.scales = s
-                quantized_linear.zeros = zp
+                quantized_linear.scales = scales
+                quantized_linear.zeros = zeros
                 if child.bias is not None:
                     quantized_linear.bias = child.bias
                 else:
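Likewise, a minimal sketch of the linear-side call; bit width is fixed at n_bit = 4 inside the method, so only group_size is passed. The quantizer class name Int8DynActInt4WeightQATQuantizer is an assumption inferred from the _convert_qat_linear_8da4w helper shown above.

import torch
# Assumed host class of the new staticmethod; the diff shows only the method body.
from torchao.quantization.qat.linear import Int8DynActInt4WeightQATQuantizer

# Hypothetical linear weight: out_features=64, in_features=256, group size 32
weight = torch.randn(64, 256)

q_weight, scales, zeros = Int8DynActInt4WeightQATQuantizer.quantize_weights(weight, group_size=32)

# These match what _convert_qat_linear_8da4w now assigns to the quantized module:
#   quantized_linear.weight = q_weight
#   quantized_linear.scales = scales
#   quantized_linear.zeros  = zeros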
