Commit 016625c

Unskip test_qat_8da4w_prepare_vs_convert
Following @metascroy's investigation in #2085, we can unskip this test. The failure was caused by activation scales having different precisions between the prepare and convert steps.

**Test Plan:**
python test/quantization/test_qat.py -k test_qat_8da4w_prepare_vs_convert
1 parent cdced21 commit 016625c
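
For context, a minimal sketch (not part of this commit) of the failure mode the message describes: computing per-token activation scales in two different dtypes yields slightly different values, so the prepare (fake-quantized) and convert (truly quantized) paths drift apart and the SQNR check fails. The max-abs symmetric scale and the `per_token_scale` helper below are illustrative assumptions, not torchao's implementation.

```python
import torch

def per_token_scale(t: torch.Tensor, scale_dtype: torch.dtype) -> torch.Tensor:
    # Illustrative symmetric int8 scale per token (row): max-abs / 127.
    return t.abs().amax(dim=-1, keepdim=True).to(scale_dtype) / 127

x = torch.randn(4, 64, dtype=torch.bfloat16)
s_low = per_token_scale(x, torch.bfloat16)  # low-precision scales (old prepare path)
s_f32 = per_token_scale(x, torch.float32)   # fp32 scales (convert path)

# The scales disagree slightly, and the disagreement compounds through
# quantize/dequantize, degrading prepare-vs-convert SQNR.
print((s_low.float() - s_f32).abs().max())
```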

3 files changed (+13, −5 lines)

test/quantization/test_qat.py

Lines changed: 5 additions & 2 deletions

@@ -1474,7 +1474,6 @@ def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype):
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
-    @unittest.skip("Currently failing on sqnr")
     def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype):
         """
         Test that the prepare and convert steps of Int8DynActInt4QATQuantizer produces
@@ -1493,7 +1492,11 @@ def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype):
         torch.manual_seed(seed)
         x = m.example_inputs()
 
-        quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
+        quantizer = Int8DynActInt4WeightQATQuantizer(
+            groupsize=group_size,
+            precision=dtype,
+            scales_precision=dtype,
+        )
         prepared = quantizer.prepare(m)
         prepared_out = prepared(*x)
         converted = quantizer.convert(prepared)
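
A hedged, self-contained sketch of what the re-enabled test exercises; the toy model `M`, the tensor sizes, and the inline SQNR computation are assumptions, and the import path follows the `torchao/quantization/qat/linear.py` file touched below:

```python
import torch
from torchao.quantization.qat.linear import Int8DynActInt4WeightQATQuantizer

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(64, 32, bias=False, dtype=torch.bfloat16)

    def forward(self, x):
        return self.linear(x)

m = M()
x = torch.randn(8, 64, dtype=torch.bfloat16)

# Pass matching precisions through the quantizer, as the updated test does.
quantizer = Int8DynActInt4WeightQATQuantizer(
    groupsize=32,
    precision=torch.bfloat16,
    scales_precision=torch.bfloat16,
)
prepared = quantizer.prepare(m)
prepared_out = prepared(x)
converted = quantizer.convert(prepared)
converted_out = converted(x)

# Inline SQNR in dB; this is the comparison that previously failed.
err = (prepared_out.float() - converted_out.float()).pow(2).mean()
sqnr = 10 * torch.log10(prepared_out.float().pow(2).mean() / err)
print(sqnr)
```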

torchao/quantization/GPTQ.py

Lines changed: 5 additions & 1 deletion

@@ -933,7 +933,11 @@ def linear_forward_8da4w(
     groupsize,
     precision,
 ):
-    x = per_token_dynamic_quant(x, scale_dtype=precision, zero_point_dtype=precision)
+    x = per_token_dynamic_quant(
+        x,
+        scale_dtype=torch.float32,
+        zero_point_dtype=torch.int8,
+    )
     # TODO: verify and remove following reshape code
     # origin_x_size = x.size()
     # x = x.reshape(-1, origin_x_size[-1])
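
Pinning these dtypes keeps the dynamic activation quantization numerically identical regardless of the model's `precision` argument. A hedged, self-contained sketch of per-token dynamic int8 quantization with fp32 scales and int8 zero points; this mirrors, but is not, torchao's `per_token_dynamic_quant`, and the helper name is illustrative:

```python
import torch

def per_token_dynamic_quant_sketch(x: torch.Tensor) -> torch.Tensor:
    qmin, qmax = -128, 127  # int8 range
    x_f32 = x.float()
    min_val = x_f32.amin(dim=-1, keepdim=True)
    max_val = x_f32.amax(dim=-1, keepdim=True)
    # fp32 scales: computing these in a lower precision is what made
    # prepare and convert disagree.
    scale = ((max_val - min_val) / (qmax - qmin)).clamp(min=1e-8)
    zero_point = (qmin - min_val / scale).round().clamp(qmin, qmax).to(torch.int8)
    q = (x_f32 / scale + zero_point.float()).round().clamp(qmin, qmax).to(torch.int8)
    # Dequantize back to the activation dtype for the int4 weight matmul.
    return ((q.float() - zero_point.float()) * scale).to(x.dtype)

x = torch.randn(4, 64, dtype=torch.bfloat16)
print(per_token_dynamic_quant_sketch(x).dtype)  # torch.bfloat16
```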

torchao/quantization/qat/linear.py

Lines changed: 3 additions & 2 deletions

@@ -219,8 +219,9 @@ def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
                 n_bit = 4
                 (qmin, qmax) = _get_qmin_qmax(n_bit)
                 (s, zp) = get_group_qparams_symmetric(
-                    child.weight, n_bit, config.group_size
+                    child.weight, n_bit, config.group_size, config.scale_precision,
                 )
+                zp = zp.to(config.zero_point_precision)
                 from torchao._executorch_ops import (
                     _quantized_decomposed_quantize_per_channel_group_wrapper,
                 )
@@ -270,7 +271,7 @@ def __init__(
         precision: torch.dtype = torch.float32,
         scales_precision: torch.dtype = torch.float32,
     ) -> None:
-        activation_config = _get_8da4w_activation_config(scales_precision)
+        activation_config = _get_8da4w_activation_config(torch.float32)
         weight_config = _get_8da4w_weight_config(groupsize, scales_precision)
         super().__init__(
             in_features,
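
On the convert side, weight scales now honor `config.scale_precision` while zero points are cast separately to `config.zero_point_precision`. A hedged sketch of that flow, assuming standard symmetric max-abs group quantization; the `_sketch` helper is illustrative, not `get_group_qparams_symmetric` itself:

```python
import torch

def group_qparams_symmetric_sketch(w, n_bit, group_size, scale_dtype):
    qmax = 2 ** (n_bit - 1) - 1                 # 7 for int4
    wg = w.reshape(w.shape[0], -1, group_size)  # [out_features, n_groups, group_size]
    scales = (wg.abs().amax(dim=-1) / qmax).to(scale_dtype)
    zeros = torch.zeros_like(scales)            # symmetric: zero point is 0
    return scales, zeros

w = torch.randn(32, 64, dtype=torch.bfloat16)
s, zp = group_qparams_symmetric_sketch(w, n_bit=4, group_size=32,
                                       scale_dtype=torch.bfloat16)
# The convert path then casts zero points on their own, as in the diff:
zp = zp.to(torch.int8)  # stand-in for config.zero_point_precision
```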
