
Commit 8bd3c69

Match QAT prepare and convert numerics exactly for bf16
**Summary:** The previous PR #1964 got the prepare and convert steps to match for fp32, but there were two additional sources of numerical discrepancy with bf16:

1. QAT's asymmetric per-token choose-qparams diverged from `choose_qparams_affine`, which has simpler logic.
2. QAT's per-token fake quantize cast the input to fp32 before fake quantizing it.

Both are resolved in this commit: (1) QAT now uses `choose_qparams_affine` instead of the custom function for asymmetric per-token quantization, which is now deleted, and (2) QAT no longer casts the input to fp32. The result is an exact match in numerics between the prepare and convert steps for both fp32 and bf16.

**Test Plan:**

python test/quantization/test_qat.py -k test_fake_quantize_per_token_vs_convert_fp32
python test/quantization/test_qat.py -k test_fake_quantize_per_token_vs_convert_bf16
python test/quantization/test_qat.py -k test_qat_8da4w_prepare_vs_convert_fp32
python test/quantization/test_qat.py -k test_qat_8da4w_prepare_vs_convert_bf16
1 parent 31f119e commit 8bd3c69
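For orientation, here is a minimal sketch of the per-token asymmetric flow this commit converges on, stitched together from the APIs that appear in the diffs below. Treat it as an illustration rather than code from the commit; the import paths and call signature are taken from the changed files.

```python
import torch

from torchao.quantization.qat.utils import _fake_quantize_per_token
from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine
from torchao.quantization.utils import _get_per_token_block_size

# Per-token asymmetric int8 qparams via the shared affine primitive,
# computed directly on a bf16 activation (no fp32 upcast anymore).
x = torch.randn(1, 235, 2048, dtype=torch.bfloat16)
scale, zero_point = choose_qparams_affine(
    x,
    mapping_type=MappingType.ASYMMETRIC,
    block_size=_get_per_token_block_size(x),
    target_dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    scale_dtype=torch.float32,
    zero_point_dtype=torch.int32,
)
x_fq = _fake_quantize_per_token(x, scale, zero_point, -128, 127)
assert x_fq.dtype == torch.bfloat16  # output stays in the input dtype
```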

File tree

3 files changed: +75 −77 lines

test/quantization/test_qat.py
torchao/quantization/qat/fake_quantizer.py
torchao/quantization/qat/utils.py

test/quantization/test_qat.py

Lines changed: 60 additions & 20 deletions
@@ -40,7 +40,6 @@
     Int8DynActInt4WeightQATLinear,
 )
 from torchao.quantization.qat.utils import (
-    _choose_qparams_per_token_asymmetric,
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
     _GenericFakeQuantize,
@@ -53,12 +52,16 @@
     MappingType,
     TorchAODType,
     ZeroPointDomain,
+    choose_qparams_affine,
+    dequantize_affine,
     fake_quantize_affine,
+    quantize_affine,
 )
 from torchao.quantization.unified import (
     TwoStepQuantizer,
 )
 from torchao.quantization.utils import (
+    _get_per_token_block_size,
     get_group_qparams_symmetric,
     get_groupwise_affine_qparams,
     groupwise_affine_quantize_tensor,
@@ -134,12 +137,13 @@ def forward(self, x):


 class M4(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, dtype: torch.dtype = torch.float32):
         super().__init__()
-        self.linear = torch.nn.Linear(512, 256, bias=False).to(torch.float)
+        self.dtype = dtype
+        self.linear = torch.nn.Linear(512, 256, bias=False).to(dtype)

     def example_inputs(self):
-        return (torch.randn(1, 512).to(torch.float),)
+        return (torch.randn(1, 512).to(self.dtype),)

     def forward(self, x):
         return self.linear(x)
@@ -219,30 +223,41 @@ def test_fake_quantize_per_token(self):
         torch.manual_seed(self.SEED)
         x = torch.randn(100, 256).requires_grad_()
         x2 = copy.deepcopy(x)
-        # TODO: use torch.ops.aten.quantized_decomposed version instead
-        (s, zp) = _choose_qparams_per_token_asymmetric(x, torch.float32, torch.int32)
+        block_size = _get_per_token_block_size(x)
+        (s, zp) = choose_qparams_affine(
+            x,
+            mapping_type=MappingType.ASYMMETRIC,
+            block_size=block_size,
+            target_dtype=torch.int8,
+            quant_min=-128,
+            quant_max=127,
+            scale_dtype=torch.float32,
+            zero_point_dtype=torch.int32,
+        )

         # fake quant op
         out = _fake_quantize_per_token(x, s, zp, qmin, qmax)
         out.sum().backward()

         # compare against PTQ ops
-        out_ptq = torch.ops.quantized_decomposed.quantize_per_token(
+        out_ptq = quantize_affine(
             x2,
+            block_size,
             s,
             zp,
+            torch.int8,
             qmin,
             qmax,
-            torch.int8,
         )
-        out_ptq = torch.ops.quantized_decomposed.dequantize_per_token(
+        out_ptq = dequantize_affine(
             out_ptq,
+            block_size,
             s,
             zp,
+            torch.int8,
             qmin,
             qmax,
-            torch.int8,
-            torch.float32,
+            output_dtype=torch.float32,
         )
         torch.testing.assert_close(out, out_ptq, atol=0, rtol=0)

@@ -1004,8 +1019,15 @@ def linear_forward_8da4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
     Baseline for int8 dynamic per token asymmetric + int4 per group symmetric quant.
     """
     # activations
-    (s, zp) = _choose_qparams_per_token_asymmetric(
-        x, torch.float32, torch.int32
+    (s, zp) = choose_qparams_affine(
+        x,
+        mapping_type=MappingType.ASYMMETRIC,
+        block_size=_get_per_token_block_size(x),
+        target_dtype=torch.int8,
+        quant_min=-128,
+        quant_max=127,
+        scale_dtype=torch.float32,
+        zero_point_dtype=torch.int32,
     )
     (qmin, qmax) = _get_qmin_qmax(8)
     x_fq = _fake_quantize_per_token(x, s, zp, qmin, qmax)
@@ -1427,10 +1449,7 @@ def test_qat_linear_bias(self):
         example_inputs = m.example_inputs()
         m(*example_inputs)

-    @unittest.skipIf(
-        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
-    )
-    def test_fake_quantize_per_token_vs_convert(self):
+    def _test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype):
         """
         Test that the following produce the exact same numerics:
         1. FakeQuantizer with asymmetric per_token config
@@ -1439,7 +1458,7 @@ def test_fake_quantize_per_token_vs_convert(self):
         from torchao.quantization.utils import per_token_dynamic_quant

         torch.manual_seed(self.SEED)
-        x = torch.randn(1, 235, 2048)
+        x = torch.randn(1, 235, 2048).to(dtype)
         config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
         fake_quantizer = FakeQuantizer(config)
         fake_quantizer_out = fake_quantizer(x)
@@ -1449,7 +1468,16 @@
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
-    def test_qat_8da4w_prepare_vs_convert(self):
+    def test_fake_quantize_per_token_vs_convert_fp32(self):
+        self._test_fake_quantize_per_token_vs_convert(torch.float32)
+
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_fake_quantize_per_token_vs_convert_bf16(self):
+        self._test_fake_quantize_per_token_vs_convert(torch.bfloat16)
+
+    def _test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype):
         """
         Test that the prepare and convert steps of Int8DynActInt4QATQuantizer produces
         numerics that match exactly over N trials.
@@ -1463,7 +1491,7 @@ def test_qat_8da4w_prepare_vs_convert(self):

         for seed in range(self.SEED, self.SEED + num_trials):
             torch.manual_seed(seed)
-            m = M4()
+            m = M4(dtype)
             torch.manual_seed(seed)
             x = m.example_inputs()

@@ -1486,6 +1514,18 @@ def test_qat_8da4w_prepare_vs_convert(self):
         )
         self.assertEqual(len(non_inf_sqnr), 0, fail_message)

+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_qat_8da4w_prepare_vs_convert_fp32(self):
+        self._test_qat_8da4w_prepare_vs_convert(torch.float32)
+
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_qat_8da4w_prepare_vs_convert_bf16(self):
+        self._test_qat_8da4w_prepare_vs_convert(torch.bfloat16)
+

 if __name__ == "__main__":
     unittest.main()
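For readers unfamiliar with the helper, `_get_per_token_block_size` (imported above from `torchao.quantization.utils`) just maps an activation shape to the per-token block size expected by the affine primitives. A rough sketch of the idea, under a name of my own and not necessarily the exact torchao implementation:

```python
from typing import List

import torch


def per_token_block_size(x: torch.Tensor) -> List[int]:
    # One block per token: every leading dimension is 1 and the last dimension
    # spans the full hidden size, so each row gets its own scale/zero_point.
    return [1] * (x.dim() - 1) + [x.shape[-1]]


assert per_token_block_size(torch.empty(1, 235, 2048)) == [1, 1, 2048]
```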

torchao/quantization/qat/fake_quantizer.py

Lines changed: 13 additions & 5 deletions
@@ -16,8 +16,11 @@
 from torchao.quantization.quant_primitives import (
     _DTYPE_TO_BIT_WIDTH,
     _DTYPE_TO_QVALUE_BOUNDS,
+    MappingType,
+    choose_qparams_affine,
 )
 from torchao.quantization.utils import (
+    _get_per_token_block_size,
     get_group_qparams_symmetric,
     get_groupwise_affine_qparams,
 )
@@ -26,7 +29,6 @@
     FakeQuantizeConfig,
 )
 from .utils import (
-    _choose_qparams_per_token_asymmetric,
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
 )
@@ -69,13 +71,19 @@ def _per_token_forward(self, x: torch.Tensor):
         """
         if self.config.is_symmetric:
             raise NotImplementedError("Symmetric per token is not supported yet")
+
+        qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype]
         if self._should_compute_qparams():
-            (self.scale, self.zero_point) = _choose_qparams_per_token_asymmetric(
+            self.scale, self.zero_point = choose_qparams_affine(
                 x,
-                self.config.scale_precision,
-                self.config.zero_point_precision,
+                mapping_type=MappingType.ASYMMETRIC,
+                block_size=_get_per_token_block_size(x),
+                target_dtype=self.config.dtype,
+                quant_min=qmin,
+                quant_max=qmax,
+                scale_dtype=self.config.scale_precision,
+                zero_point_dtype=self.config.zero_point_precision,
             )
-        qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype]
         return _fake_quantize_per_token(x, self.scale, self.zero_point, qmin, qmax)

     def _per_channel_or_group_forward(self, x: torch.Tensor):
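As a usage sketch of the updated `_per_token_forward` path, the snippet below exercises the same `FakeQuantizeConfig`/`FakeQuantizer` pattern as the tests above on a bf16 input. The import locations are assumptions on my part; this diff only shows the class names.

```python
import torch

# NOTE: import paths are assumed; they are not shown in this commit's diff.
from torchao.quantization.qat.api import FakeQuantizeConfig
from torchao.quantization.qat.fake_quantizer import FakeQuantizer

# Asymmetric per-token int8 fake quantization on a bf16 activation.
config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
fake_quantizer = FakeQuantizer(config)

x = torch.randn(1, 235, 2048, dtype=torch.bfloat16)
out = fake_quantizer(x)
assert out.dtype == torch.bfloat16  # no fp32 upcast in the per-token path
```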

torchao/quantization/qat/utils.py

Lines changed: 2 additions & 52 deletions
@@ -4,7 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import List, Tuple
+from typing import List

 import torch

@@ -126,9 +126,8 @@ def _fake_quantize_per_token(

     _per_token_quant_qparam_dim_check(input, scales, zero_points)
     block_size = _get_per_token_block_size(input)
-    fq_input = input.to(torch.float32)
     fq = _GenericFakeQuantize.apply(
-        fq_input,
+        input,
         block_size,
         scales,
         zero_points,
@@ -138,55 +137,6 @@
     return fq.reshape_as(input).to(input.dtype)


-# TODO: This is copied from torch/ao/quantization/fx/_decomposed.py.
-# The version in pytorch does not have backward support yet so we add
-# it here for now until https://github.com/pytorch/pytorch/pull/123452
-# is landed.
-def _choose_qparams_per_token_asymmetric(
-    input: torch.Tensor,
-    scales_precision: torch.dtype = torch.float32,
-    zero_points_precision: torch.dtype = torch.float32,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Choose quantization parameters for per token quantization. This means for a N dimension Tensor
-    (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize
-    every N elements with the same quantization parameter. The dimension for scales/zero_points
-    will be (M1 * M2 ... * Mn)
-
-    Args:
-        input (torch.Tensor): original float32/float16 Tensor
-        scales_precision (torch.dtype): precision of returned scales
-        zero_points_precision (torch.dtype): precision of returned zero points
-
-    Returns:
-        scales and zero_points, both float32 Tensors
-    """
-    # Based on https://github.com/google/XNNPACK/blob/df156f0cf3db5a4576cc711123eeb54915f82ffc/src/xnnpack/quantization.h#L18
-    qmin, qmax = -128, 127
-    min_val = torch.amin(input, dim=-1, keepdim=True)
-    max_val = torch.amax(input, dim=-1, keepdim=True)
-    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
-    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
-    eps = torch.finfo(torch.float32).eps  # use xnnpack eps?
-
-    # scale
-    scale = (max_val_pos - min_val_neg) / float(qmax - qmin)
-    scale = scale.clamp(min=eps)
-
-    # zero point
-    descaled_min = min_val_neg / scale
-    descaled_max = max_val_pos / scale
-    zero_point_from_min_error = qmin + descaled_min
-    zero_point_from_max_error = qmax + descaled_max
-    zero_point = torch.where(
-        zero_point_from_min_error + zero_point_from_max_error > 0,
-        qmin - descaled_min,
-        qmax - descaled_max,
-    )
-    zero_point = torch.clamp(zero_point, qmin, qmax).round()
-
-    return scale.to(scales_precision), zero_point.to(zero_points_precision)
-
-
 def _get_qmin_qmax(n_bit: int, symmetric: bool = True):
     if symmetric:
         qmin = -(2 ** (n_bit - 1))
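The second discrepancy called out in the summary came from the `input.to(torch.float32)` removed above: dividing by the scale in fp32 can round a borderline value differently than doing the same arithmetic in bf16, which is enough to break bit-exactness against the convert path. A standalone illustration of the effect with made-up values, not code from the commit:

```python
import torch

# Quantize the same bf16 values once natively and once after an fp32 upcast.
# The two paths agree on most inputs but can disagree near rounding boundaries.
torch.manual_seed(0)
x = torch.randn(4, 8, dtype=torch.bfloat16)
scale = torch.full((4, 1), 0.0123, dtype=torch.bfloat16)

q_native = torch.clamp(torch.round(x / scale), -128, 127)
q_upcast = torch.clamp(torch.round(x.float() / scale.float()), -128, 127)

print(torch.equal(q_native.float(), q_upcast))  # not guaranteed to be True
```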
