[torchlib] Precompute the constant for gelu to avoid precision loss #2179

Merged

justinchuby merged 2 commits into main from justinchu/gelu-approx on Apr 10, 2025

Conversation

justinchuby
Collaborator

I think this improves accuracy for gelu under float16.
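For context, the tanh approximation of GELU uses the constant sqrt(2/pi), which this PR precomputes in Python rather than building out of graph ops. A minimal pure-Python reference of the formula (illustrative names, not the actual diff):

```python
import math

# Computed once in Python at float64 precision; this value, rather than a
# Sqrt subgraph, is what ends up embedded in the exported model.
_SQRT_TWO_OVER_PI = math.sqrt(2.0 / math.pi)  # 0.7978845608028654

def gelu_tanh_reference(x: float) -> float:
    """Tanh approximation of GELU:
    0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))."""
    inner = _SQRT_TWO_OVER_PI * (x + 0.044715 * x**3)
    return 0.5 * x * (1.0 + math.tanh(inner))
```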

@justinchuby justinchuby added the module: torchlib Related to the torch/aten function lib in development label Apr 10, 2025
Contributor

@Copilot left a comment


Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

Comments suppressed due to low confidence (1)

tests/function_libs/torch_lib/ops_test_data.py:1793

  • The removal of the explicit float16 tolerance for the GELU operator may reduce the test sensitivity to precision improvements. Ensure that the tests still adequately capture potential precision differences under float16.
TorchLibOpInfo("nn.functional.gelu", nn_ops.aten_gelu),


codecov bot commented Apr 10, 2025

❌ 4 Tests Failed:

| Tests completed | Failed | Passed | Skipped |
| --- | --- | --- | --- |
| 14101 | 4 | 14097 | 1699 |
View the top 3 failed test(s) by shortest run time
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0735_test_not_3d
Stack Traces | 0.005s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
C:\hostedtoolcache\windows\Python\3.11.9\x64\Lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_not_3d'

The above exception was the direct cause of the following exception:
.nox\test\Lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_not_3d' (e=No module named 'tests.onnx_backend_test_code.test_not_3d') (file: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_not_3d.py', absolute path: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_not_3d.py', current folder: D:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import BOOL
E   from onnxscript.onnx_opset import opset1
E   
E   @script()
E   def bck_test_not_3d(x: BOOL[3,4,5]) -> (BOOL[3,4,5]):
E       r_not = opset1.Not(x)
E       return r_not
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0898_test_reduce_prod_keepdims_example
Stack Traces | 0.005s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
C:\hostedtoolcache\windows\Python\3.11.9\x64\Lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_reduce_prod_keepdims_example'

The above exception was the direct cause of the following exception:
.nox\test\Lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_reduce_prod_keepdims_example' (e=No module named 'tests.onnx_backend_test_code.test_reduce_prod_keepdims_example') (file: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_reduce_prod_keepdims_example.py', absolute path: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_reduce_prod_keepdims_example.py', current folder: D:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT, INT64
E   from onnxscript.onnx_opset import opset18
E   
E   @script()
E   def bck_test_reduce_prod_keepdims_example(data: FLOAT[3,2,2], axes: INT64[1]) -> (FLOAT[3,1,2]):
E       reduced = opset18.ReduceProd(data, axes, keepdims=1)
E       return reduced
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0656_test_min_two_inputs
Stack Traces | 0.006s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
C:\hostedtoolcache\windows\Python\3.11.9\x64\Lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_min_two_inputs'

The above exception was the direct cause of the following exception:
.nox\test\Lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_min_two_inputs' (e=No module named 'tests.onnx_backend_test_code.test_min_two_inputs') (file: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_min_two_inputs.py', absolute path: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_min_two_inputs.py', current folder: D:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT
E   from onnxscript.onnx_opset import opset13
E   
E   @script()
E   def bck_test_min_two_inputs(data_0: FLOAT[3], data_1: FLOAT[3]) -> (FLOAT[3]):
E       result = opset13.Min(data_0, data_1)
E       return result


Contributor

@titaiwangms left a comment


It's important to keep the op order so it matches the optimization fusion pattern. As long as it still matches, it's fine: https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/fusion_gelu.py

@justinchuby
Collaborator Author

We have tests, right? I can update the pattern.

@justinchuby
Collaborator Author

OK, this should be fine, I think. The patterns only concern approximate="none", not "tanh". Can you confirm?

@justinchuby enabled auto-merge (squash) on April 10, 2025 at 18:05
@gramalingam
Collaborator

Do you know what the source of the precision improvement is? Trying to understand this ... is it because some computation happens at higher precision in Python vs. using float32 in ONNX? Which one? It would help in other cases if we understood this.

@justinchuby
Collaborator Author

The computation happens at higher precision than if we did it in ONNX at float16.
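Schematically, the change moves where the rounding to the input dtype happens (gramalingam pins down the cause below); a hedged before/after sketch, not the literal diff:

```python
import math

# Before (schematic): the constant was cast to the input dtype *before* the
# Sqrt, so the square root itself ran at float16:
#     scale = Sqrt(CastLike(2.0 / pi, x))    # intermediate rounded to f16
#
# After: the expression is evaluated once in Python at float64, and only the
# final value is rounded when it is cast to the input dtype:
SQRT_TWO_OVER_PI = math.sqrt(2.0 / math.pi)  # exact to float64
#     scale = CastLike(SQRT_TWO_OVER_PI, x)  # a single rounding, at the end
```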

@titaiwangms
Contributor

> OK, this should be fine, I think. The patterns only concern approximate="none", not "tanh". Can you confirm?

There is also a fusion pattern for "tanh":

https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/fusion_fastgelu.py

@justinchuby
Collaborator Author

OK, thanks. I looked at the patterns. This should still work, because that computation was previously done by the constant folder before pattern matching. Here we just changed the implementation so that the pre-optimized model is accurate.
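For intuition, here is what the constant folder effectively does (a numpy schematic, not the actual optimizer code): an op whose inputs are all constants is evaluated ahead of time and replaced by a single initializer, so the fusion pattern matches one folded constant either way.

```python
import numpy as np

two_over_pi = np.float32(2.0 / np.pi)  # Constant node in the graph
folded = np.sqrt(two_over_pi)          # Sqrt(Constant) folds to a constant
# After folding, only `folded` remains in the graph; precomputing the value
# in Python just means the pre-optimized model already contains it.
```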

@justinchuby merged commit 005568a into main on Apr 10, 2025
23 of 27 checks passed
@justinchuby deleted the justinchu/gelu-approx branch on April 10, 2025 at 21:19
@gramalingam
Collaborator

> The computation happens at higher precision than if we did it in ONNX at float16.

I see. Because there is a CastLike before the Sqrt, and it should have been after the Sqrt. Got it.

But, agreed, I prefer that the computation happen ahead of time. Regarding my earlier (now outdated) comment about preferring explicit graph computation over precomputed constants: there is a way of having the cake and eating it too here. We express the precomputation explicitly in Python, but outside the script function, clarifying both what computation happens and when it happens, like below:

```python
sqrt_two_over_pi = math.sqrt(2.0 / math.pi)
def _aten_gelu(...):
    ... sqrt_two_over_pi ...
```
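Fleshing this sketch out slightly, in the style of the generated test scripts earlier in this thread — a hypothetical sketch (the function name, signature, and explicit CastLike calls are illustrative assumptions, not the actual torchlib code):

```python
import math
from onnxscript import script
from onnxscript.onnx_opset import opset18
from onnxscript.onnx_types import FLOAT

# Module level: computed once, at import time, in Python float64 precision.
_SQRT_TWO_OVER_PI = math.sqrt(2.0 / math.pi)

@script()
def gelu_tanh_sketch(x: FLOAT[3]) -> FLOAT[3]:
    # The precomputed constant enters the graph as a literal; it is cast to
    # x's dtype once, instead of taking Sqrt of an already-downcast value.
    scale = opset18.CastLike(_SQRT_TWO_OVER_PI, x)
    coeff = opset18.CastLike(0.044715, x)
    x_cubed = opset18.Mul(opset18.Mul(x, x), x)
    inner = opset18.Mul(scale, opset18.Add(x, opset18.Mul(coeff, x_cubed)))
    one = opset18.CastLike(1.0, x)
    half = opset18.CastLike(0.5, x)
    return opset18.Mul(half, opset18.Mul(x, opset18.Add(one, opset18.Tanh(inner))))
```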

@justinchuby
Collaborator Author

> The computation happens at higher precision than if we did it in ONNX at float16.
>
> I see. Because there is a CastLike before the Sqrt, and it should have been after the Sqrt. Got it.
>
> But, agreed, I prefer that the computation happen ahead of time. ... We express the precomputation explicitly in Python, but outside the script function, clarifying both what computation happens and when it happens, like below:
>
> ```python
> sqrt_two_over_pi = math.sqrt(2.0 / math.pi)
> def _aten_gelu(...):
>     ... sqrt_two_over_pi ...
> ```

That's a great idea. Thanks!
