Skip to content

Commit fa52f16

Browse files
committed
Initial support for 8da4w QAT
Summary: This commit adds support for QAT, where linear layers are fake quantized with int8 per token dynamic activations (8da) and int4 grouped per channel weights (4w). This initial implementation uses the same module swap approach as 8da4w PTQ for simplicity and code reuse. In the future, we may wish to consider migrating both flows to use tensor subclasses for better composability with other PyTorch features. Test Plan: python test/quantization/test_qat.py -k test_fake_quantize_per_channel_group python test/quantization/test_qat.py -k test_fake_quantize_per_token python test/quantization/test_qat.py -k test_qat_8da4w_linear python test/quantization/test_qat.py -k test_qat_8da4w_quantizer Reviewers: jerryzh168, cpuhrsch, HDCharles Subscribers: jerryzh168, cpuhrsch, HDCharles, supriyar Tasks: #86
1 parent 5401df0 commit fa52f16

File tree

3 files changed

+467
-2
lines changed

3 files changed

+467
-2
lines changed

test/quantization/test_qat.py

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# mypy: ignore-errors
8+
# This test takes a long time to run
9+
10+
import copy
11+
import unittest
12+
13+
import torch
14+
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
15+
from torchao.quantization._prototype.qat import (
16+
_choose_qparams_per_token_asymmetric,
17+
fake_quantize_per_channel_group,
18+
fake_quantize_per_token,
19+
Int8DynActInt4WeightQATLinear,
20+
Int8DynActInt4WeightQATQuantizer,
21+
)
22+
from torchao.quantization.quant_primitives import (
23+
get_group_qparams_symmetric,
24+
group_quantize_tensor_symmetric,
25+
per_token_dynamic_quant,
26+
)
27+
from torchao.quantization.utils import (
28+
TORCH_VERSION_AFTER_2_3,
29+
)
30+
from torchao.quantization.GPTQ import (
31+
Int8DynActInt4WeightLinear,
32+
Int8DynActInt4WeightQuantizer,
33+
)
34+
35+
36+
# TODO: put this in a common test utils file
37+
class M(torch.nn.Module):
38+
def __init__(self):
39+
super().__init__()
40+
self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
41+
self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float)
42+
43+
def example_inputs(self):
44+
return (torch.randn(1, 64).to(torch.float),)
45+
46+
def forward(self, x):
47+
x = self.linear1(x)
48+
x = self.linear2(x)
49+
return x
50+
51+
52+
class TestQAT(unittest.TestCase):
53+
SEED = 123
54+
55+
def _get_qmin_qmax(self, n_bit: int):
56+
qmin = -(2 ** (n_bit - 1))
57+
qmax = 2 ** (n_bit - 1) - 1
58+
return (qmin, qmax)
59+
60+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
61+
def test_fake_quantize_per_channel_group(self):
62+
n_bit = 4
63+
(qmin, qmax) = self._get_qmin_qmax(n_bit)
64+
group_size = 128
65+
66+
torch.manual_seed(self.SEED)
67+
x = torch.randn(100, 256).requires_grad_()
68+
(s, zp) = get_group_qparams_symmetric(x, n_bit, group_size)
69+
x2 = copy.deepcopy(x)
70+
71+
# fake quant op
72+
out = fake_quantize_per_channel_group(
73+
x, s, zp, qmin, qmax, group_size,
74+
)
75+
out.sum().backward()
76+
77+
# compare against PTQ ops
78+
out_ptq = torch.ops.quantized_decomposed.quantize_per_channel_group(
79+
x2, s, zp, qmin, qmax, torch.int8, group_size,
80+
)
81+
out_ptq = torch.ops.quantized_decomposed.dequantize_per_channel_group(
82+
out_ptq, s, zp, qmin, qmax, torch.int8, group_size, torch.float32,
83+
)
84+
torch.testing.assert_close(out, out_ptq, atol=0, rtol=0)
85+
86+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
87+
def test_fake_quantize_per_token(self):
88+
(qmin, qmax) = self._get_qmin_qmax(8)
89+
90+
torch.manual_seed(self.SEED)
91+
x = torch.randn(100, 256).requires_grad_()
92+
x2 = copy.deepcopy(x)
93+
# TODO: use torch.ops.aten.quantized_decomposed version instead
94+
(s, zp) = _choose_qparams_per_token_asymmetric(
95+
x,
96+
torch.int8, # not used
97+
)
98+
99+
# fake quant op
100+
out = fake_quantize_per_token(x, s, zp, qmin, qmax)
101+
out.sum().backward()
102+
103+
# compare against PTQ ops
104+
out_ptq = torch.ops.quantized_decomposed.quantize_per_token(
105+
x2, s, zp, qmin, qmax, torch.int8,
106+
)
107+
out_ptq = torch.ops.quantized_decomposed.dequantize_per_token(
108+
out_ptq, s, zp, qmin, qmax, torch.int8, torch.float32,
109+
)
110+
torch.testing.assert_close(out, out_ptq, atol=0, rtol=0)
111+
112+
def _set_ptq_weight(
113+
self,
114+
ptq_linear: Int8DynActInt4WeightLinear,
115+
fp32_weight: torch.Tensor,
116+
group_size: int,
117+
):
118+
"""
119+
Set the weight to the quantized version of the given fp32 weights,
120+
for making linear outputs comparable with QAT.
121+
"""
122+
n_bit = 4
123+
(qmin, qmax) = self._get_qmin_qmax(n_bit)
124+
(s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size)
125+
q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
126+
fp32_weight, s, zp, qmin, qmax, torch.int8, group_size,
127+
)
128+
ptq_linear.weight = q_weight
129+
ptq_linear.scales = s
130+
ptq_linear.zeros = zp
131+
132+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
133+
def test_qat_8da4w_linear(self):
134+
group_size = 128
135+
torch.manual_seed(self.SEED)
136+
qat_linear = Int8DynActInt4WeightQATLinear(256, 688, bias=False, groupsize=group_size)
137+
ptq_linear = Int8DynActInt4WeightLinear(256, 688, bias=False, groupsize=group_size)
138+
139+
# Force the weights to be the same
140+
self._set_ptq_weight(ptq_linear, qat_linear.weight, group_size)
141+
142+
# Compare linear values
143+
torch.manual_seed(self.SEED)
144+
x = torch.randn(100, 256)
145+
x2 = copy.deepcopy(x)
146+
qat_out = qat_linear(x)
147+
ptq_out = ptq_linear(x2)
148+
torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)
149+
150+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
151+
def test_qat_8da4w_quantizer(self):
152+
group_size = 16
153+
torch.manual_seed(self.SEED)
154+
m = M()
155+
m2 = copy.deepcopy(m)
156+
qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
157+
ptq_quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size)
158+
qat_model = qat_quantizer.prepare(m)
159+
ptq_model = ptq_quantizer.quantize(m2)
160+
161+
# Force the weights to be the same
162+
self._set_ptq_weight(
163+
ptq_model.linear1, qat_model.linear1.weight, group_size,
164+
)
165+
self._set_ptq_weight(
166+
ptq_model.linear2, qat_model.linear2.weight, group_size,
167+
)
168+
169+
# Compare model values
170+
torch.manual_seed(self.SEED)
171+
x = m.example_inputs()
172+
x2 = copy.deepcopy(x)
173+
qat_out = qat_model(*x)
174+
ptq_out = ptq_model(*x2)
175+
torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)
176+
177+
178+
if __name__ == "__main__":
179+
unittest.main()

test/quantization/test_quant_primitives.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99
import unittest
1010
import torch
1111
from torchao.quantization.quant_primitives import get_group_qparams_symmetric
12-
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
12+
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
1313

1414
class TestQuantPrimitives(unittest.TestCase):
1515
SEED = 123
1616

17-
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.3 or lower")
17+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
1818
def test_get_group_qparams_symmetric(self):
1919
"""
2020
Test that `get_group_qparams_symmetric` produces the exact same scales as

0 commit comments

Comments
 (0)