
Fix numeric mismatches #2085


Merged · 13 commits · Apr 25, 2025
5 changes: 3 additions & 2 deletions test/quantization/test_qat.py
@@ -1474,7 +1474,6 @@ def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype):
@unittest.skipIf(
not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
)
@unittest.skip("Currently failing on sqnr")
def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype):
"""
Test that the prepare and convert steps of Int8DynActInt4QATQuantizer produces
@@ -1493,7 +1492,9 @@ def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype):
torch.manual_seed(seed)
x = m.example_inputs()

quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
quantizer = Int8DynActInt4WeightQATQuantizer(
groupsize=group_size, precision=dtype, scales_precision=dtype
)
prepared = quantizer.prepare(m)
prepared_out = prepared(*x)
converted = quantizer.convert(prepared)
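The re-enabled test above compares the prepared (fake-quantized) model against the converted (actually quantized) model and now threads the model dtype through both precision and scales_precision. Below is a minimal sketch of that prepare/convert comparison pattern; it assumes Int8DynActInt4WeightQATQuantizer is importable from torchao.quantization.qat, uses a toy model rather than the test's example model, and the 30 dB threshold is illustrative, not the value used in the suite.

```python
import torch
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
from torchao.quantization.utils import compute_error

# Toy stand-in for the test's example model (hypothetical, for illustration only).
dtype = torch.bfloat16
model = torch.nn.Sequential(torch.nn.Linear(256, 128, bias=False)).to(dtype)
x = torch.randn(2, 256, dtype=dtype)

# Keep the model precision and the weight-scale precision consistent,
# mirroring the updated test above.
quantizer = Int8DynActInt4WeightQATQuantizer(
    groupsize=32, precision=dtype, scales_precision=dtype
)
prepared = quantizer.prepare(model)
prepared_out = prepared(x)

converted = quantizer.convert(prepared)
converted_out = converted(x)

# Prepare (fake quant) and convert (real quant) should agree closely.
sqnr = compute_error(prepared_out, converted_out).item()
assert sqnr > 30  # illustrative threshold
```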
2 changes: 1 addition & 1 deletion torchao/experimental/quant_passes.py
@@ -86,7 +86,7 @@ def _get_q_dq_linear_patterns_replacements_and_filters(
glbs["a_quant_min"] = None
glbs["a_quant_max"] = None
glbs["a_mapping_type"] = "ASYMMETRIC"
glbs["a_scale_dtype"] = torch.float64
glbs["a_scale_dtype"] = torch.float32
glbs["a_eps"] = None

lcls = {}
34 changes: 21 additions & 13 deletions torchao/experimental/tests/test_embedding_xbit_quantizer.py
@@ -32,6 +32,7 @@
MappingType,
quantize_,
)
from torchao.quantization.utils import compute_error


class TestEmbeddingQuantizer(unittest.TestCase):
@@ -254,7 +255,7 @@ def test_identical_to_IntxWeightOnlyConfig(
for granularity in [PerGroup(32), PerGroup(128), PerAxis(0)]
for mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]
for scale_dtype in [torch.float32, torch.bfloat16, torch.float16]
for model_dtype in [torch.float32, torch.bfloat16]
for model_dtype in [torch.float32, torch.bfloat16, torch.float16]
],
name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
)
@@ -292,7 +293,7 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig(
IntXQuantizationAwareTrainingConfig(weight_config=weight_config),
embedding_filter,
)
expected_out = model(indices)
prepared_out = model(indices)

quantize_(model, FromIntXQuantizationAwareTrainingConfig(), embedding_filter)
quantize_(
@@ -305,8 +306,14 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig(
),
embedding_filter,
)
actual_out = model(indices)
self.assertTrue(torch.allclose(expected_out, actual_out))
converted_out = model(indices)
sqnr = compute_error(prepared_out, converted_out).item()

# For torch.int1, sometimes sqnr is nan because both tensors are all 0
# so we check torch.equal as well
self.assertTrue(
sqnr == float("inf") or torch.equal(prepared_out, converted_out)
)

@parameterized.expand(
[
@@ -317,7 +324,7 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig(
)
for granularity in [PerGroup(32), PerGroup(128), PerAxis(0)]
for scale_dtype in [torch.float32, torch.bfloat16, torch.float16]
for model_dtype in [torch.float32, torch.bfloat16]
for model_dtype in [torch.float32, torch.bfloat16, torch.float16]
],
name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
)
@@ -346,7 +353,8 @@ def test_identical_to_Int4WeightOnlyEmbeddingQATQuantizer(
zero_point_precision=torch.int32,
)
model = qat_quantizer.prepare(model)
expected_out = model(indices)
prepared_model_copy = copy.deepcopy(model)
prepared_out = model(indices)

# Convert model method 1
quantize_(model, FromIntXQuantizationAwareTrainingConfig(), embedding_filter)
@@ -360,15 +368,15 @@ def test_identical_to_Int4WeightOnlyEmbeddingQATQuantizer(
),
embedding_filter,
)
actual_out1 = model(indices)
self.assertTrue(torch.allclose(expected_out, actual_out1))
converted_out1 = model(indices)
sqnr1 = compute_error(prepared_out, converted_out1).item()
self.assertTrue(sqnr1 == float("inf"))

# TODO: method 2 does not work because the converted embedding op
# incorrectly casts output of to indices.dtype
# Convert model method 2
# qat_quantizer.convert(prepared_model_copy)
# actual_out2 = prepared_model_copy(indices)
# self.assertTrue(torch.allclose(expected_out, actual_out2))
qat_quantizer.convert(prepared_model_copy)
converted_out2 = prepared_model_copy(indices)
sqnr2 = compute_error(prepared_out, converted_out2).item()
self.assertTrue(sqnr2 == float("inf"))


if __name__ == "__main__":
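For reference, the tests above replace torch.allclose with an SQNR check via compute_error. The sketch below is a standalone approximation of that metric (written from the standard SQNR definition, not copied from torchao), showing why identical tensors give inf and why two all-zero tensors give nan, which is the torch.int1 corner case the comment mentions.

```python
import torch

def sqnr_db(expected: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB: 20 * log10(||signal|| / ||error||)
    signal = torch.linalg.norm(expected.float())
    noise = torch.linalg.norm((expected - actual).float())
    return 20 * torch.log10(signal / noise)

a = torch.randn(4, 8)
print(sqnr_db(a, a))                            # inf: zero error
print(sqnr_db(torch.zeros(4), torch.zeros(4)))  # nan: 0 / 0, hence the torch.equal fallback
```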
@@ -26,6 +26,7 @@
MappingType,
quantize_,
)
from torchao.quantization.utils import compute_error


class TestInt8DynamicActivationIntxWeight(unittest.TestCase):
@@ -360,7 +361,7 @@ def test_export_QDQLayout(self):
self.assertTrue(torch.allclose(eager_results, exported_results))

expected_lines = [
"torch.ops.torchao.choose_qparams_affine.default(input_1, 'ASYMMETRIC', [1, 512], torch.int8, None, None, None, torch.float64, torch.int64)",
"torch.ops.torchao.choose_qparams_affine.default(input_1, 'ASYMMETRIC', [1, 512], torch.int8, None, None, None, torch.float32, torch.int8)",
"torch.ops.torchao.quantize_affine.default(input_1, [1, 512], getitem, getitem_1, torch.int8)",
"torch.ops.torchao.dequantize_affine.default(quantize_affine, [1, 512], getitem, getitem_1, torch.int8)",
"torch.ops.torchao.dequantize_affine.default",
@@ -475,7 +476,8 @@ def test_identical_to_Int8DynamicActivationInt4WeightConfig(
),
)
with torch.no_grad():
torch.allclose(model(activations), model_copy(activations))
sqnr = compute_error(model(activations), model_copy(activations)).item()
self.assertTrue(sqnr == float("inf"))

@parameterized.expand(
[
@@ -492,7 +494,7 @@ def test_identical_to_Int8DynamicActivationInt4WeightConfig(
for mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]
for act_mapping_type in [MappingType.ASYMMETRIC, MappingType.SYMMETRIC]
for scale_dtype in [torch.float32, torch.bfloat16, torch.float16]
for model_dtype in [torch.float32, torch.bfloat16]
for model_dtype in [torch.float32, torch.bfloat16, torch.float16]
],
name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
)
@@ -510,11 +512,6 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig(
if mapping_type == MappingType.ASYMMETRIC:
return

# TODO: QAT logic for non-float32 models does not match PTQ right now
# QAT's default scale-precision is float32, but PTQ's is None (which defaults to input's dtype)
if model_dtype != torch.float32:
return

assert mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]
assert act_mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]
is_symmetric = mapping_type == MappingType.SYMMETRIC
@@ -550,7 +547,7 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig(
IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
)
try:
expected_out = model(activations)
prepared_out = model(activations)
except NotImplementedError as e:
# QAT does not support act_mapping_type == MappingType.SYMMETRIC yet
if act_mapping_type == MappingType.SYMMETRIC:
@@ -568,8 +565,10 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig(
act_mapping_type=act_mapping_type,
),
)
actual_out = model(activations)
self.assertTrue(torch.allclose(expected_out, actual_out))
converted_out = model(activations)

sqnr = compute_error(prepared_out, converted_out).item()
self.assertTrue(sqnr == float("inf"))

@parameterized.expand(
[
@@ -580,20 +579,13 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig(
)
for group_size in [32, 64, 128]
for scale_dtype in [torch.float32, torch.bfloat16, torch.float16]
for model_dtype in [torch.float32, torch.bfloat16]
for model_dtype in [torch.float32, torch.bfloat16, torch.float16]
],
name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
)
def test_identical_to_Int8DynActInt4WeightQATQuantizer(
self, group_size, scale_dtype, model_dtype
):
# Currently this does not match
# TODO: investigat
if scale_dtype != torch.float32:
return
if model_dtype != torch.float32:
return

k0 = 512
k1 = 256
layers = [
@@ -611,10 +603,10 @@ def test_identical_to_Int8DynActInt4WeightQATQuantizer(
groupsize=group_size, precision=model_dtype, scales_precision=scale_dtype
)
model = qat_quantizer.prepare(model)
expected_out = model(activations)

prepared_model_copy = copy.deepcopy(model)

prepared_out = model(activations)

# Convert model method 1
quantize_(model, FromIntXQuantizationAwareTrainingConfig())
quantize_(
@@ -627,13 +619,15 @@ def test_identical_to_Int8DynActInt4WeightQATQuantizer(
act_mapping_type=MappingType.ASYMMETRIC,
),
)
actual_out1 = model(activations)
self.assertTrue(torch.allclose(expected_out, actual_out1))
converted_out1 = model(activations)
sqnr1 = compute_error(prepared_out, converted_out1).item()
self.assertTrue(sqnr1 == float("inf"))

# Convert model method 2
qat_quantizer.convert(prepared_model_copy)
actual_out2 = prepared_model_copy(activations)
self.assertTrue(torch.allclose(expected_out, actual_out2))
converted_out2 = prepared_model_copy(activations)
sqnr2 = compute_error(prepared_out, converted_out2).item()
self.assertTrue(sqnr2 == float("inf"))


if __name__ == "__main__":
13 changes: 10 additions & 3 deletions torchao/quantization/GPTQ.py
@@ -931,9 +931,16 @@ def linear_forward_8da4w(
zeros,
out_features,
groupsize,
precision,
output_precision,
):
x = per_token_dynamic_quant(x, scale_dtype=precision, zero_point_dtype=precision)
# uses fp32 to match torchao.quantization.quant_api._int8_asymm_per_token_quant
# and activation_scale_dtype in QAT configs
# TODO: in future add ability to specify activation_scale_dtype to PTQ configs
# and enable similar change here
x = per_token_dynamic_quant(
x, scale_dtype=torch.float32, zero_point_dtype=torch.float32
)

Review thread on this hunk:

Contributor: Looks like in _int8_asymm_per_token_quant you set the zero point to torch.int8, so this is different? A bit confused here.

Contributor Author: That difference doesn't actually matter, because the zero point domain is int, so the zero points will be int8 numbers (which can be represented in fp32). But maybe it makes sense to just use quant_api._int8_asymm_per_token_quant directly here, rather than per_token_dynamic_quant?

Contributor: Yeah, I would prefer us to be consistent in scale and zero point dtype. In fact, making the scale fp32 and the zero point int32 would be fine; the zero point should really have the same dtype as the quantized values.

Contributor: Agreed that we should just use torch.int8 here for zero points to be consistent, even if it doesn't change the numerics in practice. We probably shouldn't use _int8_asymm_per_token_quant here, though, because that would introduce subclasses into a module-swap-only flow.

Contributor Author: I meant using _int8_asymm_per_token_quant followed by dequantize, e.g., _int8_asymm_per_token_quant(x).dequantize().

Contributor: That would still introduce AQT into this path, right? I would prefer to keep these separate; just calling the same primitives, backed by unit tests, seems enough.

Contributor Author: It also creates a circular import. Changing back.

# TODO: verify and remove following reshape code
# origin_x_size = x.size()
# x = x.reshape(-1, origin_x_size[-1])
@@ -953,7 +960,7 @@ def linear_forward_8da4w(
torch.int8,
quant_min,
quant_max,
output_dtype=precision,
output_dtype=output_precision,
)

# x = x.to(torch.float16)
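The review thread above turns on the observation that zero points live in the integer domain, so storing them as float32 versus int8 does not change the dequantized values. The pure-torch sketch below (not torchao's per_token_dynamic_quant or _int8_asymm_per_token_quant, just an illustration of the same arithmetic) demonstrates this for asymmetric per-token int8 quantization.

```python
import torch

# Asymmetric per-token int8 quantization with an integer-valued zero point.
x = torch.randn(4, 16)
qmin, qmax = -128, 127

x_min = x.amin(dim=-1, keepdim=True).clamp(max=0.0)
x_max = x.amax(dim=-1, keepdim=True).clamp(min=0.0)
scale = (x_max - x_min) / (qmax - qmin)          # assumes a nonzero range per token
zero_point = torch.round(qmin - x_min / scale)   # integer-valued, stored in float32

q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax).to(torch.int8)

# Dequantize with the zero point kept in float32 vs. cast to int8:
dq_fp32_zp = (q.float() - zero_point) * scale
dq_int8_zp = (q.float() - zero_point.to(torch.int8).float()) * scale
assert torch.equal(dq_fp32_zp, dq_int8_zp)  # identical: the zero point's storage dtype is cosmetic
```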
26 changes: 19 additions & 7 deletions torchao/quantization/qat/embedding.py
@@ -177,6 +177,7 @@ def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
scale_precision=self.scale_precision,
zero_point_precision=self.zero_point_precision,
device=child.weight.device,
dtype=child.weight.dtype,
)
# In distributed training, the model may be instantiated
# on the meta device, in which case there is no need to
@@ -227,14 +228,19 @@ def _convert_helper(self, module: torch.nn.Module):
scale_precision=scale_precision,
zero_point_precision=zero_point_precision,
device=child.weight.device,
output_dtype=child.weight.dtype,
)
setattr(module, name, quantized_embedding)

# Load weights and qparams into quantized embedding
(qmin, qmax) = _get_qmin_qmax(self.bit_width)
(s, zp) = get_group_qparams_symmetric(
child.weight, self.bit_width, group_size
child.weight,
self.bit_width,
group_size,
precision=scale_precision,
)

Review thread on this hunk:

Contributor: We probably need to set the zero point precision separately to be more correct? I think the function doesn't let you set them separately, so maybe we need an extra cast after calling it here, like zp.to(zero_point_precision)?

Contributor Author: For numeric purposes it doesn't matter, because the zero point is an int8 number, which can be represented as a floating point number. But I'll add the extra cast.

Contributor: Yeah, but it would be better to be consistent and have the zero point dtype separately specified. For example, if you were to quantize bias to int32, then you would want the zero point in int32, although usually it is just zero. Besides, if this is for affine quantization, maybe the function name should reflect that.
zp = zp.to(zero_point_precision)
q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
child.weight,
s,
@@ -324,6 +330,7 @@ def __init__(
scale_precision: torch.dtype = torch.float32,
zero_point_precision: torch.dtype = torch.int32,
device: torch.device = None,
output_dtype: torch.dtype = torch.float32,
):
super().__init__()

@@ -341,6 +348,7 @@ def __init__(
self.group_size = group_size
self.scale_precision = scale_precision
self.zero_point_precision = zero_point_precision
self.output_dtype = output_dtype

# currently storing unpacked int8 weights
self.register_buffer(
@@ -367,20 +375,24 @@ def __init__(
)

def forward(self, x):
from torchao._executorch_ops import (
_quantized_decomposed_dequantize_per_channel_group_wrapper,
from torchao.quantization.quant_primitives import (
dequantize_affine,
)

qmin, qmax = _get_qmin_qmax(self.bit_width)
w_dq = _quantized_decomposed_dequantize_per_channel_group_wrapper(

# dequantize_affine casts to output_dtype before scaling
# dequantize_per_channel_group scales and then casts to output_dtype
# The two do not agree when dtype != torch.float32
w_dq = dequantize_affine(
self.weight,
[1, self.group_size],
self.scale,
self.zero_point,
torch.int8,
qmin,
qmax,
torch.int8,
self.group_size,
x.dtype,
output_dtype=self.output_dtype,
)

Review thread on this hunk:

Contributor: Should output dtype be the same as x.dtype? Then we wouldn't need a separate arg.

Contributor Author: No, this is an embedding op, so x.dtype is an int (the index), and the weights should not be cast to integers. This was one of the existing numerical issues in the embedding convert step.

Contributor: Oh I see, thanks for catching that!
return F.embedding(
x,
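The forward() change above notes that dequantize_affine casts to output_dtype before scaling, while the old per-channel-group dequant scales first and then casts, and that the two disagree when the dtype is not float32. Below is a pure-torch illustration of that ordering effect (not the torchao kernels themselves).

```python
import torch

torch.manual_seed(0)
q = torch.randint(-8, 8, (1, 32), dtype=torch.int8)   # 4-bit values stored in int8
scale = torch.rand(1, 1, dtype=torch.float32) * 0.01

# Order A: cast to the output dtype first, then scale in that dtype.
cast_then_scale = q.to(torch.bfloat16) * scale.to(torch.bfloat16)

# Order B: scale in float32, then cast the result to the output dtype.
scale_then_cast = (q.to(torch.float32) * scale).to(torch.bfloat16)

# The two are close but generally not bit-identical in bfloat16, which is
# enough to break exact prepare-vs-convert comparisons.
print(torch.equal(cast_then_scale, scale_then_cast))
print((cast_then_scale.float() - scale_then_cast.float()).abs().max())
```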
14 changes: 12 additions & 2 deletions torchao/quantization/qat/linear.py
@@ -219,8 +219,12 @@ def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
n_bit = 4
(qmin, qmax) = _get_qmin_qmax(n_bit)
(s, zp) = get_group_qparams_symmetric(
child.weight, n_bit, config.group_size
child.weight,
n_bit,
config.group_size,
precision=config.scale_precision,
)
zp = zp.to(config.zero_point_precision)
from torchao._executorch_ops import (
_quantized_decomposed_quantize_per_channel_group_wrapper,
)
@@ -258,6 +262,10 @@ class Int8DynActInt4WeightQATLinear(FakeQuantizedLinear):
groupsize: the number of elements in each quantized group for weights
precision: precision of weights
scales_precision: precision of per group scales and zero points

Note: we hardcode activation scales to use torch.fp32, but allow users to specify the weight scales (defaults to torch.fp32).
To get an exact numerical match with Int8DynamicActivationInt4WeightConfig, users must use the same dtype for both the weights
and the scales. Here scales_precision refers specifically to the weight scales only, not the activation scales.
"""

def __init__(
@@ -270,7 +278,9 @@ def __init__(
precision: torch.dtype = torch.float32,
scales_precision: torch.dtype = torch.float32,
) -> None:
activation_config = _get_8da4w_activation_config(scales_precision)
# Use torch.float32 to match torchao.quantization.quant_api._int8_asymm_per_token_quant,
# which is used in PTQ routines
activation_config = _get_8da4w_activation_config(torch.float32)
Review thread on this hunk:

Contributor: Is this change needed? Looks like scales_precision already defaults to torch.float32?

Contributor Author: In the PTQ routines, the scale precision for the dynamic activation quantization is always FP64 (see _int8_asymm_per_token_quant, https://fburl.com/uf90caon; in this diff I change it to FP32), even if the scale_dtype on the weight or the input dtype are set to something else.

Contributor: Why is that?

Contributor: I see, makes sense. We should definitely update the docstring to mention this then, something like: "Note: We hardcode activation scales to use torch.fp32, but allow users to specify the weight scales (defaults to torch.fp32). To get an exact numerical match with Int8DynamicActivationInt4WeightConfig, users must use the same dtype for both the weights and the scales. Here scales_precision refers specifically to the weight scales only, not the activation scales."

Contributor Author: I don't know if there is a great reason for it; it's just how LinearActivationQuantizedTensor works (https://github.com/pytorch/ao/blob/main/torchao/quantization/linear_activation_quantized_tensor.py#L24). Unlike the weights, torchao does not serialize parameters related to activation quantization right now. Instead it serializes a function that does the activation quantization (which is opaque); to support serialization, the function is added to torch.serialization.add_safe_globals. To support specifying information about the activation quantization, we would have to serialize it in LinearActivationQuantizedTensor, or register many variants of _int8_asymm_per_token_quant with different dtypes via torch.serialization.add_safe_globals. The former probably makes more sense. cc @jerryzh168

Contributor (jerryzh168, Apr 25, 2025): Is the question about the args being hardcoded with LinearActivationQuantizedTensor? We can use quant_kwargs to support passing args around, I think; it allows passing some args to the input quant function:

quantized_tensor = input_quant_func(input_tensor, **quant_kwargs)

weight_config = _get_8da4w_weight_config(groupsize, scales_precision)
super().__init__(
in_features,
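The thread above concludes that scales_precision on Int8DynActInt4WeightQATLinear only controls the weight scales, while activation scales are hardcoded to float32 to match the PTQ path. Below is a construction sketch; it is based on the signature visible in this diff, so keyword names other than groupsize, precision, and scales_precision are assumptions.

```python
import torch
from torchao.quantization.qat.linear import Int8DynActInt4WeightQATLinear

dtype = torch.bfloat16
qat_linear = Int8DynActInt4WeightQATLinear(
    in_features=256,
    out_features=128,
    bias=False,              # assumed keyword
    groupsize=32,
    precision=dtype,         # dtype of the fake-quantized weights
    scales_precision=dtype,  # weight scales/zero points only; activation scales stay float32
)
x = torch.randn(2, 256, dtype=dtype)
y = qat_linear(x)            # fake-quantized forward in `dtype`
```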