Skip to content

Optimize causal mask shape #2325

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 22, 2025
Merged

Optimize causal mask shape #2325

merged 3 commits into from
May 22, 2025

Conversation

gramalingam
Copy link
Collaborator

The generation of the causal mask's shape (produced by the translation of scalar_dot_product_attention) interferes with the subsequent fusion optimizations (because it makes use of the shape of the intermediate matmul value).

This PR introduces a very specific fusion/rewrite to eliminate this redundant computation of the "sequence length" dimension.

Signed-off-by: Ganesan Ramalingam <[email protected]>
Signed-off-by: Ganesan Ramalingam <[email protected]>
Copy link

codecov bot commented May 22, 2025

❌ 8 Tests Failed:

Tests completed Failed Passed Skipped
15997 8 15989 1883
View the top 3 failed test(s) by shortest run time
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0271_test_concat_3d_axis_2
Stack Traces | 0.004s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
C:\hostedtoolcache\windows\Python\3.10.11\x64\lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_concat_3d_axis_2'

The above exception was the direct cause of the following exception:
.nox\test\lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_concat_3d_axis_2' (e=No module named 'tests.onnx_backend_test_code.test_concat_3d_axis_2') (file: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_concat_3d_axis_2.py', absolute path: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_concat_3d_axis_2.py', current folder: D:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT
E   from onnxscript.onnx_opset import opset13
E   
E   @script()
E   def bck_test_concat_3d_axis_2(value0: FLOAT[2,2,2], value1: FLOAT[2,2,2]) -> (FLOAT[2,2,4]):
E       output = opset13.Concat(value0, value1, axis=2)
E       return output
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_1274_test_unsqueeze_two_axes
Stack Traces | 0.004s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
C:\hostedtoolcache\windows\Python\3.10.11\x64\lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_unsqueeze_two_axes'

The above exception was the direct cause of the following exception:
.nox\test\lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_unsqueeze_two_axes' (e=No module named 'tests.onnx_backend_test_code.test_unsqueeze_two_axes') (file: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_unsqueeze_two_axes.py', absolute path: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_unsqueeze_two_axes.py', current folder: D:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT, INT64
E   from onnxscript.onnx_opset import opset21
E   
E   @script()
E   def bck_test_unsqueeze_two_axes(x: FLOAT[3,4,5], axes: INT64[2]) -> (FLOAT[3,1,4,5,1]):
E       y = opset21.Unsqueeze(x, axes)
E       return y
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_1139_test_softmax_axis_2_expanded_ver18
Stack Traces | 0.005s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
C:\hostedtoolcache\windows\Python\3.10.11\x64\lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_softmax_axis_2_expanded_ver18'

The above exception was the direct cause of the following exception:
.nox\test\lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_softmax_axis_2_expanded_ver18' (e=No module named 'tests.onnx_backend_test_code.test_softmax_axis_2_expanded_ver18') (file: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_softmax_axis_2_expanded_ver18.py', absolute path: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_softmax_axis_2_expanded_ver18.py', current folder: D:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT
E   from onnxscript.onnx_opset import opset18
E   
E   @script()
E   def bck_test_softmax_axis_2_expanded_ver18(x: FLOAT[3,4,5]) -> (FLOAT[3,4,5]):
E       Softmax_test_softmax_axis_2_expanded_function_axes = opset18.Constant(value=make_tensor("value", 7, dims=[1], vals=[2]))
E       Softmax_test_softmax_axis_2_expanded_function_X_ReduceMax = opset18.ReduceMax(x, Softmax_test_softmax_axis_2_expanded_function_axes, keepdims=1)
E       Softmax_test_softmax_axis_2_expanded_function_X_Sub = opset18.Sub(x, Softmax_test_softmax_axis_2_expanded_function_X_ReduceMax)
E       Softmax_test_softmax_axis_2_expanded_function_X_Exp = opset18.Exp(Softmax_test_softmax_axis_2_expanded_function_X_Sub)
E       Softmax_test_softmax_axis_2_expanded_function_X_ReduceSum = opset18.ReduceSum(Softmax_test_softmax_axis_2_expanded_function_X_Exp, Softmax_test_softmax_axis_2_expanded_function_axes, keepdims=1)
E       y = opset18.Div(Softmax_test_softmax_axis_2_expanded_function_X_Exp, Softmax_test_softmax_axis_2_expanded_function_X_ReduceSum)
E       return y

To view more test analytics, go to the Test Analytics Dashboard
📋 Got 3 mins? Take this short survey to help us improve Test Analytics.

justinchuby pushed a commit that referenced this pull request May 22, 2025
The MHA-Bias rules can be simplified using pattern-disjunction.

(This _may_ help with Whisper ... that was my original motivation, but
not sure, after I fixed another issue in PR #2325, which may be the
primary issue ). But the cleanup is useful anyway, and it makes fusion
more efficient.)

Signed-off-by: Ganesan Ramalingam <[email protected]>
Copy link
Collaborator

@justinchuby justinchuby left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So constant folding doesn’t get this properly?

bmehta001 pushed a commit to bmehta001/onnxscript that referenced this pull request May 22, 2025
The MHA-Bias rules can be simplified using pattern-disjunction.

(This _may_ help with Whisper ... that was my original motivation, but
not sure, after I fixed another issue in PR microsoft#2325, which may be the
primary issue ). But the cleanup is useful anyway, and it makes fusion
more efficient.)

Signed-off-by: Ganesan Ramalingam <[email protected]>
@gramalingam
Copy link
Collaborator Author

So constant folding doesn’t get this properly?

Good question (though it is the "optimizer", though we call it constant-folding, since it goes beyond pure constant folding). I think not. It does the necessary analysis for shape-inference. May be worth checking. I thought we might need a more generic optimization pass, but perhaps not.

@gramalingam gramalingam enabled auto-merge (squash) May 22, 2025 18:45
@gramalingam gramalingam merged commit b34cd9c into main May 22, 2025
25 of 29 checks passed
@gramalingam gramalingam deleted the rama/causal-mask-shape-opt branch May 22, 2025 19:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Development

Successfully merging this pull request may close these issues.

2 participants