[torchlib] Precompute the constant for gelu to avoid precision loss #2179
Conversation
Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.
Comments suppressed due to low confidence (1)
tests/function_libs/torch_lib/ops_test_data.py:1793
- The removal of the explicit float16 tolerance for the GELU operator may change how sensitively the tests track precision under float16. Ensure that the tests still adequately capture potential precision differences under float16.
TorchLibOpInfo("nn.functional.gelu", nn_ops.aten_gelu),
❌ 4 Tests Failed.
It's important to keep the op order so that it still matches the optimization, e.g. the Gelu fusion pattern: https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/fusion_gelu.py. As long as it still matches, it's fine.
We have tests, right? I can update the pattern.
OK, this should be fine, I think. The patterns only concern approximate="none", not "tanh". Can you confirm?
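For context, a minimal scalar sketch of the two forms being discussed; these are the standard GELU formulas, written here only for illustration and not code from this PR or from fusion_gelu.py:

```python
import math

# approximate="none": the exact, erf-based GELU that the linked fusion pattern targets.
def gelu_exact(x: float) -> float:
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

# approximate="tanh": the approximation whose sqrt(2/pi) constant this PR precomputes.
def gelu_tanh(x: float) -> float:
    c = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + math.tanh(c * (x + 0.044715 * x**3)))

print(gelu_exact(1.0), gelu_tanh(1.0))  # ~0.8413 vs ~0.8412
```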
Do you know what the source of the precision improvement is? Trying to understand this: is it because some computation happens at higher precision in Python vs. using float32 in ONNX? Which one? It would help in other cases if we understood this.
The computation happens at higher precision than if we did it in ONNX at float16.
tanh
OK, thanks. I looked at the patterns. This should still work, because the computation is already done by the constant folder before the pattern matching runs. Here we just changed the implementation so that the pre-optimized model is accurate.
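To make the rounding argument concrete, a minimal sketch, assuming the old graph effectively cast 2/pi to the input dtype before taking the Sqrt (the exact old op sequence is an assumption here):

```python
import math

import numpy as np

# Old shape (assumed): the 2/pi constant is cast to the tensor dtype before the Sqrt,
# so both the division result and the square root are rounded at float16.
stepwise = np.sqrt(np.float16(2.0 / math.pi))

# New shape: the whole constant is evaluated in Python at double precision and
# rounded to float16 only once, where it is used.
precomputed = np.float16(math.sqrt(2.0 / math.pi))

print(stepwise, precomputed)
```

Whether the two paths produce different float16 bits depends on the particular constant; the point is that the precomputed value is rounded once from the exact result, so it can never be less accurate than the step-by-step version.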
I see. Because there is a CastLike before the Sqrt, and it should have been after the Sqrt. Got it. But, agreed, I prefer that the computation happen ahead of time. WRT the outdated comment about preferring explicit graph computation over precomputed constants, there is a way of having the cake and eating it too here: we express the pre-computation explicitly in Python, but outside the script-function, clarifying both the computation that happens and when it happens, like below:

    sqrt_two_over_pi = math.sqrt(2.0 / math.pi)

    def _aten_gelu(...):
        ... sqrt_two_over_pi ...
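Filled out a little, a hedged sketch of that shape using a plain NumPy stand-in for the function body; the names `_SQRT_TWO_OVER_PI` and `_aten_gelu_tanh_reference` are illustrative, and the real torchlib function emits ONNX ops and would route the constant through a CastLike rather than a NumPy cast:

```python
import math

import numpy as np

# Computed once in Python at double precision, outside the (script-)function body.
_SQRT_TWO_OVER_PI = math.sqrt(2.0 / math.pi)  # illustrative name

def _aten_gelu_tanh_reference(x: np.ndarray) -> np.ndarray:
    """NumPy stand-in for GELU with approximate="tanh" (illustrative only)."""
    dt = x.dtype.type
    # The precomputed constant is rounded to the input dtype once, at the point of use.
    c = dt(_SQRT_TWO_OVER_PI)
    inner = c * (x + dt(0.044715) * x * x * x)
    return dt(0.5) * x * (dt(1.0) + np.tanh(inner))

print(_aten_gelu_tanh_reference(np.linspace(-2, 2, 5, dtype=np.float16)))
```

The property this preserves is that the math stays visible in the source; it is just evaluated at import time in Python instead of inside the exported graph.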
That's a great idea. Thanks!
I think this improves accuracy for gelu under float16.