Fix numeric mismatches #2085
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2085. Note: links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 5eab1cf with merge base cdced21. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hi @metascroy, these should already be fixed in #2060. Can you verify?
I think the bugs in the embedding quantizer (casting output to int instead of float) are still there, but let me verify the other one.
I tried rebasing on your fixes; they don't resolve the embedding issue, and there is still an issue with linear. The problem is inconsistency in scale_dtype for activations: in most places it is fp32, but in some places it is set to the scale_dtype from the weights.
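To make the inconsistency concrete, here is a minimal, hedged sketch (illustrative names only, not torchao internals) of how computing the activation scale in fp32 versus the weight's scale dtype can quantize the same tensor differently:

```python
import torch

# Hedged sketch, not torchao code: the same asymmetric scale computed in fp32
# vs. bf16 (e.g. when it is tied to the weight scale_dtype) can round
# differently, so prepare and convert produce different int8 values.
torch.manual_seed(0)
x = torch.randn(1, 8) * 3.7
qmin, qmax = -128, 127

def asymm_scale(t: torch.Tensor, scale_dtype: torch.dtype) -> torch.Tensor:
    # simple min/max asymmetric scale, cast to the requested dtype
    return ((t.max() - t.min()) / (qmax - qmin)).to(scale_dtype)

scale_fp32 = asymm_scale(x, torch.float32)
scale_bf16 = asymm_scale(x, torch.bfloat16)

q_fp32 = torch.clamp(torch.round(x / scale_fp32), qmin, qmax).to(torch.int8)
q_bf16 = torch.clamp(torch.round(x / scale_bf16.to(torch.float32)), qmin, qmax).to(torch.int8)
print((q_fp32 != q_bf16).any())  # may be True: the quantized values themselves differ
```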
@@ -568,14 +568,16 @@ def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
    """This is defined here instead of local function to support serialization"""
    mapping_type = MappingType.ASYMMETRIC
    target_dtype = torch.int8
    scale_dtype = torch.float32
If these ops are backed by ops in ET, then this change may be non-trivial. In general I think the change makes sense, but it needs to be coordinated for BC.
Second, does this change align with the existing PT2E APIs in pytorch core? Otherwise we again have divergences.
Third: while we are making this change, I want to open up the possibility of zero_point_dtype being tied to the quantization type for affine quantization. Again, a BC-breaking change.
This is not used in PT2E code.
On the ET side, XNNPACK does not look at the scale dtype when detecting a lowering pattern, although we need to make sure it still succeeds on the ET pin bump: https://github.com/pytorch/executorch/blob/main/backends/xnnpack/utils/quant_utils.py#L195
I can make zero_point_dtype = torch.int8. The danger there is that if people don't use torchao's dequant API, they can overflow by subtracting two int8s.
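As a minimal, hedged illustration of that overflow risk (plain torch, not torchao's dequant path):

```python
import torch

# Dequant sketch: q - zp should be computed in a wider dtype. Subtracting two
# int8 tensors directly wraps around once the true difference exceeds 127.
q = torch.tensor([120], dtype=torch.int8)    # quantized value
zp = torch.tensor([-100], dtype=torch.int8)  # int8 zero point

naive = q - zp                                 # true value is 220, wraps to -36
safe = q.to(torch.int32) - zp.to(torch.int32)  # 220, as intended before scaling
print(naive.item(), safe.item())
```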
Changed zero point to int8
> This is not used in PT2E code.
> On the ET side, XNNPACK does not look at the scale dtype when detecting a lowering pattern, although we need to make sure it still succeeds on the ET pin bump: https://github.com/pytorch/executorch/blob/main/backends/xnnpack/utils/quant_utils.py#L195
> I can make zero_point_dtype = torch.int8. The danger there is that if people don't use torchao's dequant API, they can overflow by subtracting two int8s.
Will the flow ever produce an ET model with ops that belong here: https://github.com/pytorch/executorch/blob/main/kernels/quantized/quantized.yaml? If not, that's probably fine.
Separately, I imagine we will have to add support for these ops for the quantized KV cache as well. Right now it is using the older ops.
No, those ops are all the old quant primitives.
This will show up in ET as a new quant primitive in the torchao namespace. Currently only XNNPACK can recognize it.
@metascroy has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
torch.int8,
self.group_size,
x.dtype,
output_dtype=self.output_dtype,
Should the output dtype be the same as x.dtype? Then we don't need a separate arg?
No, this is an embedding op, so x.dtype is an int (the index). But the weights should not be cast to integers.
This was one of the existing numerical issues in the embedding convert step
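A small, hedged illustration of that bug (generic torch code, not the actual quantizer): tying the output dtype to x.dtype for an embedding would cast the dequantized weights to the integer index dtype.

```python
import torch
import torch.nn.functional as F

# For embeddings, x is a tensor of integer indices, not activations.
indices = torch.tensor([0, 2], dtype=torch.int64)
weight = torch.randn(4, 8, dtype=torch.bfloat16)  # (dequantized) embedding table

buggy = F.embedding(indices, weight).to(indices.dtype)  # truncates embeddings to int64
fixed = F.embedding(indices, weight).to(weight.dtype)   # keeps the floating-point values
```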
Oh I see, thanks for catching that!
child.weight,
self.bit_width,
group_size,
precision=scale_precision,
We probably need to set the zero point precision separately to be more correct? I think the function doesn't let you set them separately, so maybe we need to do an extra cast after calling it here, like zp.to(zero_point_precision)?
For numeric purposes it doesn't matter, because the zero point is an int8 value, which can be represented exactly as a floating point number.
But I'll add the extra cast.
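A quick hedged check of that claim: every int8 zero point round-trips exactly through fp32 or bf16, so the extra cast is about dtype consistency rather than numerics.

```python
import torch

# All 256 int8 values are exactly representable in both float32 and bfloat16
# (bf16 represents integers up to 256 exactly), so casting the zero points to a
# floating-point scale dtype and back loses nothing.
zp = torch.arange(-128, 128).to(torch.int8)
assert torch.equal(zp.to(torch.float32).to(torch.int8), zp)
assert torch.equal(zp.to(torch.bfloat16).to(torch.int8), zp)
```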
Yeah, but it would be better to be consistent and have the zero point dtype separately specified. For example, if you were to quantize the bias to int32, then you would want the zero point in int32, although usually it is just zero.
Besides, if this is for affine quantization, then maybe the function name should reflect that.
torchao/quantization/GPTQ.py (outdated)
@@ -933,7 +933,11 @@ def linear_forward_8da4w(
    groupsize,
    precision,
Should we rename this to output_precision then?
We could
@@ -270,7 +273,7 @@ def __init__(
        precision: torch.dtype = torch.float32,
        scales_precision: torch.dtype = torch.float32,
    ) -> None:
-       activation_config = _get_8da4w_activation_config(scales_precision)
+       activation_config = _get_8da4w_activation_config(torch.float32)
Is this change needed? Looks like scales_precision already defaults to torch.float32?
In the PTQ routines, scale_precision on the dynamic activation quantization is always FP64 (see _int8_asymm_per_token_quant, https://fburl.com/uf90caon; in this diff I change it to FP32, though), even if the scale_dtype on the weight or the input_dtype is set to something else.
> In the PTQ routines, scale_precision on the dynamic activation quantization is always FP64 (see _int8_asymm_per_token_quant, https://fburl.com/uf90caon; in this diff I change it to FP32, though), even if the scale_dtype on the weight or the input_dtype is set to something else.

Why is that?
I see, makes sense. We should definitely update the docstring to mention this then, something like:
"Note: We hardcode activation scales to use torch.fp32, but allow users to specify the weight scales (defaults to torch.fp32). To get an exact numerical match with Int8DynamicActivationInt4WeightConfig, users must use the same dtype for both the weights and the scales. Here scales_precision refers specifically to the weight scales only, not the activation scales."
> In the PTQ routines, scale_precision on the dynamic activation quantization is always FP64 (see _int8_asymm_per_token_quant, https://fburl.com/uf90caon; in this diff I change it to FP32, though), even if the scale_dtype on the weight or the input_dtype is set to something else.
> Why is that?

I don't know if there is a great reason for that; it's just how LinearActivationQuantizedTensor works (https://github.com/pytorch/ao/blob/main/torchao/quantization/linear_activation_quantized_tensor.py#L24).
Unlike the weights, torchao does not serialize parameters related to activation quantization right now. Instead it serializes a function that does the activation quantization (which is opaque). To support serialization, the function is added to torch.serialization.add_safe_globals.
I think to support specifying information about the activation quantization, we would have to serialize this information in LinearActivationQuantizedTensor (https://github.com/pytorch/ao/blob/main/torchao/quantization/linear_activation_quantized_tensor.py#L24), or register many variants of _int8_asymm_per_token_quant with different dtypes to torch.serialization.add_safe_globals. The former probably makes more sense.
cc @jerryzh168
Is the question about how we hardcoded the args in LinearActivationQuantizedTensor? We can use quant_kwargs: Dict[str, Any], which allows passing some args to input_quant_func:
quantized_tensor = input_quant_func(input_tensor, **quant_kwargs)
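A hedged sketch of that pattern (illustrative function, not the actual LinearActivationQuantizedTensor API): the dtype knobs live in a serializable kwargs dict instead of being baked into many registered function variants.

```python
from typing import Any, Callable, Dict

import torch

def per_token_int8_quant(x: torch.Tensor, *,
                         scale_dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # Illustrative per-token symmetric int8 quantization; the real torchao
    # helper differs. This only shows how kwargs can carry the scale dtype.
    scale = (x.abs().amax(dim=-1, keepdim=True) / 127).clamp(min=1e-5).to(scale_dtype)
    return torch.clamp(torch.round(x / scale.to(torch.float32)), -128, 127).to(torch.int8)

input_quant_func: Callable[..., torch.Tensor] = per_token_int8_quant
quant_kwargs: Dict[str, Any] = {"scale_dtype": torch.bfloat16}

quantized_tensor = input_quant_func(torch.randn(2, 16), **quant_kwargs)
```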
Following @metascroy's investigation in #2085, we can unskip this test, which was caused by activation scales having different precisions between prepare and convert.
**Test Plan:** python test/quantization/test_qat.py -k test_qat_8da4w_prepare_vs_convert
Hi @metascroy, another thing is I think we can unskip the failing test.
Added your changes to this PR.
Looks great, thanks for fixing!
@andrewor14 @jerryzh168 I cannot import this to Phabricator (D73691281). I've rebased on main multiple times, and always run into the same issue. Do you know what could be going on?
Most likely it's because of some inconsistencies between the OSS and Phabricator code; you can try again after we land the recent diff train diffs, I think.
I guess I'll land in OSS and go on top of the diff train then.
If this is referring to matching numerics e2e with QAT numerics post convert, then my mental model is that the user should not have to do anything, no?
**Summary:** This commit does two things: (1) Allow users to set eps in `FakeQuantizeConfig`; (2) For other parts of the QAT flow, set eps to `torch.finfo(torch.float32).eps` for input linear activations to match the existing hardcoded input activation scale dtype (which is fp32).

The motivation is to enable users who wish to lower their models to XNNPACK. This would require them to use the following combination of dtypes during training for an end-to-end numerical match:
- input activations: bf16
- input activation scales: fp32
- input activation eps: `torch.finfo(torch.float32).eps`
- weight: bf16
- weight scales: bf16
- weight eps: `torch.finfo(torch.bfloat16).eps`

However, today there is no way to specify the above in any of the QAT flows. For the recommended `FakeQuantizeConfig` flow, we always use `torch.finfo(x.dtype).eps`, where x is bf16 in this case, and there is no way for users to configure this. This is resolved by (1). For the legacy `Int8DynActInt4QATQuantizer` flow, we hardcode input activation scales to always use fp32 in #2085, but did not set the corresponding eps. Today, this also uses `torch.finfo(x.dtype).eps` by default, where x is bf16, and so we use the wrong eps value. This is resolved by (2).

**Test Plan:**
python test/quantization/test_qat.py -k test_fake_quantize_config_eps
python test/quantization/test_qat.py -k test_qat_8da4w_eps
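A hedged sketch (plain torch, not the QAT flow itself) of why the eps must track the scale dtype: clamping a tiny scale with the bf16 eps instead of the fp32 eps changes the scale by orders of magnitude, so prepare and convert would disagree.

```python
import torch

eps_fp32 = torch.finfo(torch.float32).eps   # ~1.19e-07
eps_bf16 = torch.finfo(torch.bfloat16).eps  # ~7.81e-03

# Near-zero activations make the eps clamp kick in.
x = torch.full((1, 8), 1e-4)
raw_scale = x.amax() / 127                  # ~7.9e-07

scale_with_fp32_eps = raw_scale.clamp(min=eps_fp32)  # stays ~7.9e-07
scale_with_bf16_eps = raw_scale.clamp(min=eps_bf16)  # clamped up to ~7.8e-03
print(scale_with_fp32_eps.item(), scale_with_bf16_eps.item())
```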
For embedding, these changes ensure:
- match for various model/scale dtypes (tested in torchao/experimental/tests/test_embedding_xbit_quantizer.py).
For linear, these changes ensure:
- match for various model/scale dtypes (tested in torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py).
The exception is that the PTQ configs are very different from the QAT configs for asymmetric weights (not the default setting). Making these match will require discussion.
Dynamic activation quantization has several similar, but slightly different, functions used throughout torchao (they generally agree when the dtype is torch.float32), but they should probably be unified: