
Fix numeric mismatches #2085

Merged
merged 13 commits into main from fix-numeric-issues on Apr 25, 2025

Conversation

metascroy
Contributor

For embedding, these changes ensure:

  • IntXQuantizationAwareTrainingConfig
  • IntxWeightOnlyConfig
  • Int4WeightOnlyEmbeddingQATQuantizer (prepare / convert)

match for various model/scale dtypes (tested in torchao/experimental/tests/test_embedding_xbit_quantizer.py).

For linear, these changes ensure:

  • Int8DynamicActivationIntxWeightConfig
  • Int8DynamicActivationInt4WeightConfig
  • IntXQuantizationAwareTrainingConfig
  • Int8DynActInt4WeightQATQuantizer (prepare / convert)

match for various model/scale dtypes (tested in torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py); a rough sketch of this kind of end-to-end check is shown below.
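A minimal sketch of the kind of prepare/convert vs. PTQ comparison described above (not the actual test code; constructor argument names such as groupsize, precision, scales_precision, and group_size are assumed from common torchao usage and may differ between versions):

```python
import copy

import torch
import torch.nn as nn

from torchao.quantization import Int8DynamicActivationInt4WeightConfig, quantize_
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(64, 32)).to(torch.bfloat16)
x = torch.randn(2, 64, dtype=torch.bfloat16)

# QAT flow: prepare() inserts fake quantization, convert() swaps in quantized linears.
# Weight dtype and weight-scale dtype are kept the same (bf16) so the numerics can match.
qat_quantizer = Int8DynActInt4WeightQATQuantizer(
    groupsize=32, precision=torch.bfloat16, scales_precision=torch.bfloat16
)
prepared = qat_quantizer.prepare(copy.deepcopy(model))
prepared_out = prepared(x)
converted = qat_quantizer.convert(prepared)
converted_out = converted(x)

# PTQ flow: apply the config directly to a fresh copy of the float model.
ptq_model = copy.deepcopy(model)
quantize_(ptq_model, Int8DynamicActivationInt4WeightConfig(group_size=32))
ptq_out = ptq_model(x)

# After this PR, prepare, convert, and the PTQ config should agree numerically.
assert torch.allclose(prepared_out, converted_out, atol=1e-2)
assert torch.allclose(converted_out, ptq_out, atol=1e-2)
```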

The exception is that the PTQ configs are quite different from the QAT configs for asymmetric weights (not the default setting). Making these match will require discussion.

Dynamic activation quantization has several similar but slightly different functions used throughout torchao (they generally agree when the dtype is torch.float32); they should probably be unified (see the sketch after this list):

  • torchao.utils.per_token_dynamic_quant (used in Int8DynActInt4WeightQATQuantizer's convert)
  • torchao.quantization.qat.utils._choose_qparams_per_token_asymmetric (used in QAT configs like IntXQuantizationAwareTrainingConfig)
  • torchao.quantization.quant_api._int8_asymm_per_token_quant (used in PTQ configs like Int8DynamicActivationInt4WeightConfig and Int8DynamicActivationIntxWeightConfig)
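For reference, the shared core of those three functions is roughly the following (an illustrative reimplementation, not the torchao source; the real variants differ in eps handling, clamping details, and the scale/zero-point dtypes, which is exactly where the mismatches come from):

```python
import torch


def per_token_asymm_int8_quant_reference(x: torch.Tensor, scale_dtype=torch.float32):
    """Illustrative per-token asymmetric int8 quantization.

    Each token (row over the last dimension) gets its own scale and zero point.
    """
    qmin, qmax = -128, 127
    eps = torch.finfo(torch.float32).eps

    # Per-token min/max, widened so that zero is always exactly representable.
    min_val = torch.clamp(x.amin(dim=-1, keepdim=True), max=0.0)
    max_val = torch.clamp(x.amax(dim=-1, keepdim=True), min=0.0)

    scale = ((max_val - min_val).to(scale_dtype) / (qmax - qmin)).clamp_min(eps)
    zero_point = qmin - torch.round(min_val.to(scale_dtype) / scale)

    q = torch.clamp(torch.round(x.to(scale_dtype) / scale) + zero_point, qmin, qmax)
    return q.to(torch.int8), scale, zero_point
```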


pytorch-bot bot commented Apr 21, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2085

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 5eab1cf with merge base cdced21:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@metascroy metascroy requested a review from andrewor14 April 21, 2025 04:24
@facebook-github-bot facebook-github-bot added the CLA Signed label Apr 21, 2025
@andrewor14
Contributor

Hi @metascroy, these should already be fixed in #2060. Can you verify?

@metascroy
Contributor Author

Hi @metascroy, these should already be fixed in #2060. Can you verify?

I think the bugs in the embedding quantizer (casting output to int instead of float) are still there, but let me verify the other one.

@metascroy metascroy force-pushed the fix-numeric-issues branch from 8b82926 to 3015548 Compare April 21, 2025 16:58
@metascroy
Contributor Author

Hi @metascroy, these should already be fixed in #2060. Can you verify?

I think the bugs in the embedding quantizer (casting output to int instead of float) are still there, but let me verify the other one.

I tried rebasing on your fixes, and they don't resolve the embedding issue, and there is still an issue with linear. The problem is an inconsistency in the scale_dtype for activations: in most places it is fp32, but in some places it is set to the scale_dtype from the weights.

@metascroy metascroy added the topic: not user facing label Apr 21, 2025
@@ -568,14 +568,16 @@ def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
"""This is defined here instead of local function to support serialization"""
mapping_type = MappingType.ASYMMETRIC
target_dtype = torch.int8
scale_dtype = torch.float32
Contributor

If these ops are backed by ops in ET, then this change may be non-trivial. In general I think the change makes sense, but it needs to be coordinated for BC.

Second, does this change align with existing PT2E APIs in PyTorch core? Otherwise we again have divergences.

Third: while we are making this change, I want to open up the possibility of zero_point_dtype being tied to the quantization type for affine quantization. Again, a BC-breaking change.

Contributor Author

This is not used in PT2E code.

On the ET side, XNNPACK does not look at the scale dtype when detecting a lowering pattern, although we need to make sure it still succeeds on the ET pin bump: https://github.com/pytorch/executorch/blob/main/backends/xnnpack/utils/quant_utils.py#L195

I can make the zero_point_dtype = torch.int8. The danger there is that if people don't use torchao's dequant API, they can overflow by subtracting two int8s.
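A quick illustration of that overflow concern (PyTorch integer arithmetic wraps around silently):

```python
import torch

q = torch.tensor([127], dtype=torch.int8)
zp = torch.tensor([-128], dtype=torch.int8)

# Naive dequant step: q - zp should be 255, but int8 wraps around to -1.
print(q - zp)  # tensor([-1], dtype=torch.int8)

# Upcasting before subtracting, as a proper dequant routine would, gives the intended value.
print(q.to(torch.int32) - zp.to(torch.int32))  # tensor([255], dtype=torch.int32)
```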

Contributor Author

Changed zero point to int8

Contributor

This is not used in PT2E code.

On the ET side, XNNPACK does not look at the scale dtype when detecting a lowering pattern, although we need to make sure it still succeeds on the ET pin bump: https://github.com/pytorch/executorch/blob/main/backends/xnnpack/utils/quant_utils.py#L195

I can make the zero_point_dtype = torch.int8. The danger there is that if people don't use torchao's dequant API, they can overflow by subtracting two int8s.

Will the flow ever produce an ET model with ops which belong here: https://github.com/pytorch/executorch/blob/main/kernels/quantized/quantized.yaml? If not, that's probably fine.

Separately, I imagine we will have to add support for these ops for the quantized KV cache as well. Right now it is using the older ops.

Contributor Author

No, those ops are all the old quant primitives.

This will show up in ET as a new quant primitive in the torchao namespace. Currently only XNNPACK can recognize it.

@metascroy metascroy force-pushed the fix-numeric-issues branch from da642dd to 00858b6 Compare April 21, 2025 20:51
@facebook-github-bot
Contributor

@metascroy has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@metascroy metascroy force-pushed the fix-numeric-issues branch from 00858b6 to 60eca6e Compare April 21, 2025 22:35
@facebook-github-bot
Contributor

@metascroy has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@metascroy metascroy force-pushed the fix-numeric-issues branch from 60eca6e to 850be5a Compare April 22, 2025 19:10
@facebook-github-bot
Contributor

@metascroy has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@metascroy metascroy force-pushed the fix-numeric-issues branch from 850be5a to 9a6b432 Compare April 23, 2025 03:41
@facebook-github-bot
Contributor

@metascroy has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@metascroy metascroy force-pushed the fix-numeric-issues branch from 8620147 to f5cb0ce Compare April 23, 2025 18:21
@facebook-github-bot
Contributor

@metascroy has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@metascroy metascroy force-pushed the fix-numeric-issues branch from f5cb0ce to b11f92b Compare April 24, 2025 20:27
@facebook-github-bot
Contributor

@metascroy has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

torch.int8,
self.group_size,
x.dtype,
output_dtype=self.output_dtype,
Contributor

Should output_dtype be the same as x.dtype? Then we don't need a separate arg?

Contributor Author

No, this is an embedding op, so x.dtype is an int (the index). But the weights should not be cast to integers.

This was one of the existing numerical issues in the embedding convert step
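A small illustration of the point, with hypothetical tensors rather than the actual op: for embeddings, x is an index tensor, so its (integer) dtype must not determine the output dtype.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([0, 2, 1], dtype=torch.int64)  # embedding indices, integer dtype

# Hypothetical quantized embedding table: int8 codes with a per-row scale.
qweight = torch.randint(-8, 8, (4, 8), dtype=torch.int8)
scale = torch.rand(4, 1, dtype=torch.bfloat16)
weight = qweight.to(torch.bfloat16) * scale  # dequantize in floating point

# Correct: the lookup result stays in the floating-point output dtype.
out = F.embedding(x, weight)
assert out.dtype == torch.bfloat16

# The bug being fixed: casting the output to x.dtype (an int) truncates the values.
bad = F.embedding(x, weight).to(x.dtype)
```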

Contributor

Oh I see, thanks for catching that!

child.weight,
self.bit_width,
group_size,
precision=scale_precision,
Contributor

we probably need to set the zero point precision separately to be more correct? I think the function doesn't let you set them separately, so maybe we need to do an extra cast after calling it here, like zp.to(zero_point_precision)?

Contributor Author

For numeric purposes, it doesn't matter, because the zero point is an int8 number, which you can represent with a floating point number.

But I'll add the extra cast.

Contributor

Yeah, but it would be better to be consistent and have our zero point dtype separately specified. For example, if you were to quantize bias to int32, then you would want the zero point in int32, although usually it is just zero.

Besides, if this is for affine quantization, then maybe have the function name reflect that.

@@ -933,7 +933,11 @@ def linear_forward_8da4w(
groupsize,
precision,
Contributor

should we rename this to output_precision then?

Contributor Author

We could

@@ -270,7 +273,7 @@ def __init__(
precision: torch.dtype = torch.float32,
scales_precision: torch.dtype = torch.float32,
) -> None:
activation_config = _get_8da4w_activation_config(scales_precision)
activation_config = _get_8da4w_activation_config(torch.float32)
Contributor

Is this change needed? Looks like scales_precision already defaults to torch.float32?

Contributor Author

In the PTQ routines, scale_precision on the dynamic activation quantization is always FP64 (see _int8_asymm_per_token_quant, https://fburl.com/uf90caon; in this diff I change it to FP32, though), even if the scale_dtype on the weight or the input_dtype is set to something else.

Contributor

In the PTQ routines, scale_precision on the dynamic activation quantization is always FP64 (see _int8_asymm_per_token_quant, https://fburl.com/uf90caon; in this diff I change it to FP32, though), even if the scale_dtype on the weight or the input_dtype is set to something else.

why is that?

Contributor

I see, makes sense. We should definitely update the docstring to mention this then, something like

"Note: We hardcode activation scales to use torch.fp32, but allow users to specify the weight scales (defaults to torch.fp32). To get exact numerical match with Int8DynamicActivationInt4WeightConfig, users must use the same dtype for both the weights and the scales. Here scales_precision refers specifically to the weight scales only, not the activation scales."

Contributor Author

In the PTQ routines, scale_precision on the dynamic activation quantization is always FP64 (see _int8_asymm_per_token_quant, https://fburl.com/uf90caon; in this diff I change it to FP32, though), even if the scale_dtype on the weight or the input_dtype is set to something else.

why is that?

I don't know if there is a great reason for that; it's just how LinearActivationQuantizedTensor works (https://github.com/pytorch/ao/blob/main/torchao/quantization/linear_activation_quantized_tensor.py#L24).

Unlike the weights, torchao does not serialize parameters related to activation quantization right now. Instead it serializes a function that does the activation quantization (which is opaque). To support serialization, the function is added to torch.serialization.add_safe_globals.

I think to support specifying information about the activation quantization, we would have to serialize this information in the LinearActivationQuantizedTensor (https://github.com/pytorch/ao/blob/main/torchao/quantization/linear_activation_quantized_tensor.py#L24), or register many variants of _int8_asymm_per_token_quant with different dtypes to torch.serialization.add_safe_globals. The former probably makes more sense.

cc @jerryzh168

@jerryzh168
Contributor

jerryzh168 commented Apr 25, 2025

Is the question about how we hard-coded the args with the linear activation quantized tensor? I think we can use the quant_kwargs mechanism to support passing around args.

It allows passing around some args for the input_quant_func:

quantized_tensor = input_quant_func(input_tensor, **quant_kwargs)
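A hedged sketch of that mechanism (illustrative class and names only, not torchao's actual LinearActivationQuantizedTensor): the kwargs are stored on the tensor subclass and forwarded to the activation quant function, so settings like the activation scale_dtype could be carried and serialized instead of being hardcoded.

```python
import torch


class LinearActivationQuantizedTensorSketch:
    """Illustrative only: shows the quant_kwargs pattern, not the real implementation."""

    def __init__(self, weight, input_quant_func, quant_kwargs=None):
        self.weight = weight
        self.input_quant_func = input_quant_func
        # Stored kwargs travel with the weight and can be serialized alongside it.
        self.quant_kwargs = quant_kwargs or {}

    def quantize_input(self, input_tensor):
        return self.input_quant_func(input_tensor, **self.quant_kwargs)


# Usage, assuming an activation quant function that accepts a scale_dtype keyword:
# wrapped = LinearActivationQuantizedTensorSketch(
#     weight, _int8_asymm_per_token_quant, quant_kwargs={"scale_dtype": torch.float32}
# )
```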

andrewor14 added a commit that referenced this pull request Apr 25, 2025
Following @metascroy's investigation in #2085, we can unskip this
test; the failure was caused by activation scales having different
precisions between prepare and convert.

**Test Plan:**
python test/quantization/test_qat.py -k test_qat_8da4w_prepare_vs_convert
@metascroy metascroy force-pushed the fix-numeric-issues branch from b11f92b to 718e650 Compare April 25, 2025 21:07
@andrewor14
Contributor

Hi @metascroy, another thing: I think we can unskip the failing test test_qat_8da4w_prepare_vs_convert after this PR. I tried this out here: #2131. Do you want to incorporate the changes there into this PR?

@metascroy
Contributor Author

Hi @metascroy, another thing: I think we can unskip the failing test test_qat_8da4w_prepare_vs_convert after this PR. I tried this out here: #2131. Do you want to incorporate the changes there into this PR?

Added your changes to this PR

@facebook-github-bot
Contributor

@metascroy has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

andrewor14 added two more commits that referenced this pull request Apr 25, 2025, with the same message as above.
@andrewor14
Contributor

Looks great, thanks for fixing!

@metascroy
Contributor Author

metascroy commented Apr 25, 2025

@andrewor14 @jerryzh168 I cannot import this to Phabricator (D73691281). I've rebased on main multiple times, and always run into the same issue.

Do you know what could be going on?

@jerryzh168
Contributor

Most likely it's because of some inconsistencies between the OSS and Phabricator code. You can try again after we land the recent diff train diffs, I think.

@metascroy
Contributor Author

Most likely it's because of some inconsistencies between the OSS and Phabricator code. You can try again after we land the recent diff train diffs, I think.

I guess I'll land in OSS and go on top of the diff train, then.

@metascroy metascroy merged commit 58502e3 into main Apr 25, 2025
18 of 19 checks passed
@kimishpatel
Contributor

I think what we can do to make it clearer is add additional documentation (maybe to the docstring of Int8DynActInt4WeightQATLinear) to clarify what is needed to make the numerics match end-to-end.

If this is referring to matching numerics e2e with QAT numerics post convert, then my mental model is that the user should not have to do anything, no?

andrewor14 added a commit that referenced this pull request May 7, 2025
**Summary:** This commit does two things:

(1) Allow users to set eps in `FakeQuantizeConfig`
(2) For other parts of the QAT flow, set eps to
    `torch.finfo(torch.float32).eps` for input linear
    activations to match the existing hardcoded input
    activation scale dtype (which is fp32)

The motivation is to enable users who wish to lower their
models to XNNPACK. This would require them to use the following
combination of dtypes during training for end-to-end numerical
match:

- input activations: bf16
- input activation scales: fp32
- input activation eps: `torch.finfo(torch.float32).eps`
- weight: bf16
- weight scales: bf16
- weight eps: `torch.finfo(torch.bfloat16).eps`

However, today there is no way to specify the above in any
of the QAT flows. For the recommended `FakeQuantizeConfig`
flow, we always use `torch.finfo(x.dtype).eps`, where x is
bf16 in this case, and there is no way for users to configure
this. This is resolved by (1).

For the legacy `Int8DynActInt4QATQuantizer` flow, we hardcoded
input activation scales to always use fp32 in
#2085, but did not set the
corresponding eps. Today, this also uses `torch.finfo(x.dtype).eps`
by default, where x is bf16, and so we use the wrong eps value.
This is resolved by (2).

**Test Plan:**
python test/quantization/test_qat.py -k test_fake_quantize_config_eps
python test/quantization/test_qat.py -k test_qat_8da4w_eps
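A hedged sketch of what that dtype combination might look like in the FakeQuantizeConfig flow, assuming the eps argument introduced by (1); import paths and exact signatures may differ between torchao versions:

```python
import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    IntXQuantizationAwareTrainingConfig,
)

model = torch.nn.Sequential(torch.nn.Linear(64, 32)).to(torch.bfloat16)

# Input activations: int8 per-token asymmetric, with an fp32 eps to match the
# hardcoded fp32 activation scale dtype used at convert time.
act_config = FakeQuantizeConfig(
    torch.int8, "per_token", is_symmetric=False,
    eps=torch.finfo(torch.float32).eps,
)
# Weights: int4 per-group; weights and weight scales stay bf16, so eps follows bf16.
weight_config = FakeQuantizeConfig(
    torch.int4, group_size=32,
    eps=torch.finfo(torch.bfloat16).eps,
)
quantize_(model, IntXQuantizationAwareTrainingConfig(act_config, weight_config))
```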
andrewor14 added a commit that referenced this pull request May 9, 2025
* Set eps in end-to-end QAT flow (same summary and test plan as the commit above)

Co-authored-by: Scott Roy <[email protected]>