Match QAT prepare and convert numerics exactly for bf16 and fp16 #2060
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2060
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 0267d18 with merge base 31f119e.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from 4c2da01 to 8bd3c69 (compare)
@andrewor14 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
test/quantization/test_qat.py (Outdated)

@unittest.skipIf(
    not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
)
def test_fake_quantize_per_token_vs_convert_bf16(self):
nit: do we have float16 as well? also can probably use parametrization for these
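Picking up on that suggestion, here is one possible shape for a parametrized version of the test — a sketch, not the PR's actual code. The class name and test body are placeholders, and `parameterized.expand` is just one option (PyTorch's internal parametrization helpers would be another):

```python
import unittest

import torch
from parameterized import parameterized  # third-party helper, chosen here for brevity


class TestFakeQuantizeParametrized(unittest.TestCase):
    @parameterized.expand([(torch.bfloat16,), (torch.float16,)])
    def test_fake_quantize_per_token_vs_convert(self, dtype):
        # Placeholder body: build an input in the target dtype, run the
        # prepare-path fake quantize and the convert-path quantize/dequantize,
        # then assert exact equality between the two outputs.
        x = torch.randn(4, 16).to(dtype)
        self.assertEqual(x.dtype, dtype)  # stand-in for the real comparison


if __name__ == "__main__":
    unittest.main()
```

A single parametrized test keeps the bf16 and fp16 variants from drifting apart as the numerics logic changes.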
Force-pushed from 8bd3c69 to 26df223 (compare)
@andrewor14 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
**Summary:** The previous PR #1964 got this to match for fp32, but there were three additional sources of numerical discrepancies with bf16:

1. QAT asymmetric per token choose qparams diverged from `choose_qparams_affine`, which had simpler logic
2. QAT per token fake quantize cast the input to fp32 before fake quantizing it
3. QAT symmetric per group choose qparams used a hardcoded eps value that did not match `choose_qparams_affine`

These are all resolved in this commit: (1) QAT now uses `choose_qparams_affine` instead of the custom function for asymmetric per token, which is now deleted, (2) QAT no longer casts the input to fp32, and (3) QAT now uses an eps value that corresponds to the input dtype. The result is an exact match in numerics between the prepare and convert steps for fp32, bf16, and fp16.

**Test Plan:**
python test/quantization/test_qat.py -k test_fake_quantize_per_token_vs_convert
python test/quantization/test_qat.py -k test_qat_8da4w_prepare_vs_convert
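To make the "dtype-matched eps, no fp32 cast" point concrete, here is a minimal, torch-only sketch of the idea — not torchao's actual implementation. The helper name, the int8 range, and the exact qparams formula are illustrative assumptions:

```python
import torch

def fake_quantize_per_token(x: torch.Tensor, quant_min: int = -128, quant_max: int = 127) -> torch.Tensor:
    # Asymmetric per-token qparams computed directly in the input dtype
    # (no cast to fp32), with the scale floored by an eps that matches
    # the input dtype instead of a hardcoded fp32 eps.
    eps = torch.finfo(x.dtype).eps
    min_val = torch.minimum(x.amin(dim=-1, keepdim=True), torch.zeros(1, dtype=x.dtype))
    max_val = torch.maximum(x.amax(dim=-1, keepdim=True), torch.zeros(1, dtype=x.dtype))
    scale = ((max_val - min_val) / (quant_max - quant_min)).clamp(min=eps)
    zero_point = torch.clamp(quant_min - torch.round(min_val / scale), quant_min, quant_max)
    # Fake quantize (quantize then dequantize), staying in the original dtype
    q = torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max)
    return (q - zero_point) * scale
```

Because every intermediate stays in `x.dtype` and the scale floor tracks `torch.finfo(x.dtype).eps`, the same arithmetic can be replayed at convert time without drift.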
Force-pushed from 26df223 to 0267d18 (compare)
@andrewor14 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
seems like some relevant tests are failing in CI? https://github.com/pytorch/ao/actions/runs/14568001515/job/40860205831
I am seeing
Sorry, let me revert this for now
Summary: The previous PR #1964 got this to match for fp32, but there were three additional sources of numerical discrepancies with bf16:

1. QAT asymmetric per token choose qparams diverged from `choose_qparams_affine`, which had simpler logic
2. QAT per token fake quantize cast the input to fp32 before fake quantizing it
3. QAT symmetric per group choose qparams used a hardcoded eps value that did not match `choose_qparams_affine`

These are all resolved in this commit: (1) QAT now uses `choose_qparams_affine` instead of the custom function for asymmetric per token, which is now deleted, (2) QAT no longer casts the input to fp32, and (3) QAT now uses an eps value that corresponds to the input dtype. The result is an exact match in numerics between the prepare and convert steps for fp32, bf16, and fp16.

Test Plan:
python test/quantization/test_qat.py -k test_fake_quantize_per_token_vs_convert
python test/quantization/test_qat.py -k test_qat_8da4w_prepare_vs_convert
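For illustration, a hedged sketch of the exact-match property these tests check: the convert path round-trips through a real int8 tensor while the prepare path stays in the float dtype, and with shared qparams and no fp32 cast the two outputs compare equal under `torch.equal`, not merely `allclose`. The function names and qparams below are made up for the example:

```python
import torch

def prepare_path(x, scale, zero_point, quant_min=-128, quant_max=127):
    # Fake quantize: quantize + dequantize arithmetic, staying in x.dtype throughout
    q = torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max)
    return (q - zero_point) * scale

def convert_path(x, scale, zero_point, quant_min=-128, quant_max=127):
    # Real quantize to int8, then dequantize back to x.dtype
    q = torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max).to(torch.int8)
    return (q.to(x.dtype) - zero_point) * scale

for dtype in (torch.float32, torch.bfloat16, torch.float16):
    x = torch.randn(4, 16).to(dtype)
    scale = torch.full((4, 1), 0.05, dtype=dtype)
    zero_point = torch.zeros(4, 1, dtype=dtype)
    # Exact bitwise agreement between the prepare and convert paths
    assert torch.equal(prepare_path(x, scale, zero_point),
                       convert_path(x, scale, zero_point))
```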