Skip to content

Commit d3f4a70

Browse files
authored
Initial support for 8da4w QAT (#138)
Summary: This commit adds support for QAT, where linear layers are fake quantized with int8 per token dynamic activations (8da) and int4 grouped per channel weights (4w). This initial implementation uses the same module swap approach as 8da4w PTQ for simplicity and code reuse. In the future, we may wish to consider migrating both flows to use tensor subclasses for better composability with other PyTorch features. Test Plan: python test/quantization/test_qat.py -k test_fake_quantize_per_channel_group python test/quantization/test_qat.py -k test_fake_quantize_per_token python test/quantization/test_qat.py -k test_qat_8da4w_linear python test/quantization/test_qat.py -k test_qat_8da4w_quantizer Reviewers: jerryzh168, cpuhrsch, HDCharles Subscribers: jerryzh168, cpuhrsch, HDCharles, supriyar Tasks: #86
1 parent 2bc1617 commit d3f4a70

File tree

5 files changed

+460
-12
lines changed

5 files changed

+460
-12
lines changed

test/quantization/test_qat.py

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# mypy: ignore-errors
8+
# This test takes a long time to run
9+
10+
import copy
11+
import unittest
12+
13+
import torch
14+
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
15+
from torchao.quantization.prototype.qat import (
16+
_choose_qparams_per_token_asymmetric,
17+
fake_quantize_per_channel_group,
18+
fake_quantize_per_token,
19+
)
20+
from torchao.quantization.quant_primitives import get_group_qparams_symmetric
21+
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
22+
23+
24+
# TODO: put this in a common test utils file
25+
class M(torch.nn.Module):
26+
def __init__(self):
27+
super().__init__()
28+
self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
29+
self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float)
30+
31+
def example_inputs(self):
32+
return (torch.randn(1, 64).to(torch.float),)
33+
34+
def forward(self, x):
35+
x = self.linear1(x)
36+
x = self.linear2(x)
37+
return x
38+
39+
40+
class TestQAT(unittest.TestCase):
41+
SEED = 123
42+
43+
def _get_qmin_qmax(self, n_bit: int):
44+
qmin = -(2 ** (n_bit - 1))
45+
qmax = 2 ** (n_bit - 1) - 1
46+
return (qmin, qmax)
47+
48+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
49+
def test_fake_quantize_per_channel_group(self):
50+
n_bit = 4
51+
(qmin, qmax) = self._get_qmin_qmax(n_bit)
52+
group_size = 128
53+
54+
torch.manual_seed(self.SEED)
55+
x = torch.randn(100, 256).requires_grad_()
56+
(s, zp) = get_group_qparams_symmetric(x, n_bit, group_size)
57+
x2 = copy.deepcopy(x)
58+
59+
# fake quant op
60+
out = fake_quantize_per_channel_group(
61+
x, s, zp, qmin, qmax, group_size,
62+
)
63+
out.sum().backward()
64+
65+
# compare against PTQ ops
66+
out_ptq = torch.ops.quantized_decomposed.quantize_per_channel_group(
67+
x2, s, zp, qmin, qmax, torch.int8, group_size,
68+
)
69+
out_ptq = torch.ops.quantized_decomposed.dequantize_per_channel_group(
70+
out_ptq, s, zp, qmin, qmax, torch.int8, group_size, torch.float32,
71+
)
72+
torch.testing.assert_close(out, out_ptq, atol=0, rtol=0)
73+
74+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
75+
def test_fake_quantize_per_token(self):
76+
(qmin, qmax) = self._get_qmin_qmax(8)
77+
78+
torch.manual_seed(self.SEED)
79+
x = torch.randn(100, 256).requires_grad_()
80+
x2 = copy.deepcopy(x)
81+
# TODO: use torch.ops.aten.quantized_decomposed version instead
82+
(s, zp) = _choose_qparams_per_token_asymmetric(
83+
x,
84+
torch.int8, # not used
85+
)
86+
87+
# fake quant op
88+
out = fake_quantize_per_token(x, s, zp, qmin, qmax)
89+
out.sum().backward()
90+
91+
# compare against PTQ ops
92+
out_ptq = torch.ops.quantized_decomposed.quantize_per_token(
93+
x2, s, zp, qmin, qmax, torch.int8,
94+
)
95+
out_ptq = torch.ops.quantized_decomposed.dequantize_per_token(
96+
out_ptq, s, zp, qmin, qmax, torch.int8, torch.float32,
97+
)
98+
torch.testing.assert_close(out, out_ptq, atol=0, rtol=0)
99+
100+
def _set_ptq_weight(
101+
self,
102+
ptq_linear: "Int8DynActInt4WeightLinear",
103+
fp32_weight: torch.Tensor,
104+
group_size: int,
105+
):
106+
"""
107+
Set the weight to the quantized version of the given fp32 weights,
108+
for making linear outputs comparable with QAT.
109+
"""
110+
n_bit = 4
111+
(qmin, qmax) = self._get_qmin_qmax(n_bit)
112+
(s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size)
113+
q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
114+
fp32_weight, s, zp, qmin, qmax, torch.int8, group_size,
115+
)
116+
ptq_linear.weight = q_weight
117+
ptq_linear.scales = s
118+
ptq_linear.zeros = zp
119+
120+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
121+
def test_qat_8da4w_linear(self):
122+
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATLinear
123+
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
124+
125+
group_size = 128
126+
torch.manual_seed(self.SEED)
127+
qat_linear = Int8DynActInt4WeightQATLinear(
128+
256, 688, bias=False, groupsize=group_size,
129+
)
130+
ptq_linear = Int8DynActInt4WeightLinear(
131+
256, 688, bias=False, groupsize=group_size,
132+
)
133+
134+
# Force the weights to be the same
135+
self._set_ptq_weight(ptq_linear, qat_linear.weight, group_size)
136+
137+
# Compare linear values
138+
torch.manual_seed(self.SEED)
139+
x = torch.randn(100, 256)
140+
x2 = copy.deepcopy(x)
141+
qat_out = qat_linear(x)
142+
ptq_out = ptq_linear(x2)
143+
torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)
144+
145+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
146+
def test_qat_8da4w_quantizer(self):
147+
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
148+
from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer
149+
150+
group_size = 16
151+
torch.manual_seed(self.SEED)
152+
m = M()
153+
m2 = copy.deepcopy(m)
154+
qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
155+
ptq_quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size)
156+
qat_model = qat_quantizer.prepare(m)
157+
ptq_model = ptq_quantizer.quantize(m2)
158+
159+
# Force the weights to be the same
160+
self._set_ptq_weight(
161+
ptq_model.linear1, qat_model.linear1.weight, group_size,
162+
)
163+
self._set_ptq_weight(
164+
ptq_model.linear2, qat_model.linear2.weight, group_size,
165+
)
166+
167+
# Compare model values
168+
torch.manual_seed(self.SEED)
169+
x = m.example_inputs()
170+
x2 = copy.deepcopy(x)
171+
qat_out = qat_model(*x)
172+
ptq_out = ptq_model(*x2)
173+
torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)
174+
175+
176+
if __name__ == "__main__":
177+
unittest.main()

test/quantization/test_quant_primitives.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99
import unittest
1010
import torch
1111
from torchao.quantization.quant_primitives import get_group_qparams_symmetric
12-
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
12+
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
1313

1414
class TestQuantPrimitives(unittest.TestCase):
1515
SEED = 123
1616

17-
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.3 or lower")
17+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
1818
def test_get_group_qparams_symmetric(self):
1919
"""
2020
Test that `get_group_qparams_symmetric` produces the exact same scales as

torchao/quantization/GPTQ.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# LICENSE file in the root directory of this source tree.
1010

1111
import logging
12-
from typing import Optional, List
12+
from typing import Optional, List, Type
1313

1414
import torch
1515

@@ -1120,21 +1120,21 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
11201120
self.precision,
11211121
)
11221122

1123-
1124-
def replace_linear_8da4w(
1125-
module,
1126-
groupsize,
1127-
padding_allowed,
1128-
precision,
1129-
scales_precision,
1123+
def _replace_linear_8da4w(
1124+
module: torch.nn.Module,
1125+
groupsize: int,
1126+
padding_allowed: bool,
1127+
precision: torch.dtype,
1128+
scales_precision: torch.dtype,
1129+
linear_class: Type[torch.nn.Module],
11301130
):
11311131
for name, child in module.named_children():
11321132
if isinstance(child, nn.Linear):
11331133
if _check_linear_int4_k(child.in_features, groupsize) or padding_allowed:
11341134
setattr(
11351135
module,
11361136
name,
1137-
Int8DynActInt4WeightLinear(
1137+
linear_class(
11381138
child.in_features,
11391139
child.out_features,
11401140
bias=False,
@@ -1144,14 +1144,30 @@ def replace_linear_8da4w(
11441144
),
11451145
)
11461146
else:
1147-
replace_linear_8da4w(
1147+
_replace_linear_8da4w(
11481148
child,
11491149
groupsize,
11501150
padding_allowed,
11511151
precision,
11521152
scales_precision,
11531153
)
11541154

1155+
def replace_linear_8da4w(
1156+
module: torch.nn.Module,
1157+
groupsize: int,
1158+
padding_allowed: bool,
1159+
precision: torch.dtype,
1160+
scales_precision: torch.dtype,
1161+
):
1162+
_replace_linear_8da4w(
1163+
module,
1164+
groupsize,
1165+
padding_allowed,
1166+
precision,
1167+
scales_precision,
1168+
Int8DynActInt4WeightLinear,
1169+
)
1170+
11551171
class Int8DynActInt4WeightQuantizer(Quantizer):
11561172
def __init__(
11571173
self,

torchao/quantization/prototype/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)