Fusion extensions to improve GQA fusion #2374


Merged
gramalingam merged 10 commits into main from rama/gqa
Jun 13, 2025


gramalingam
Collaborator

@gramalingam gramalingam commented Jun 12, 2025

Various extensions to improve GQA fusion.

  • Move key-transpose into SDPA fusion and clean it up
  • Extend cos-sin-cache fusion to handle a new pattern
  • Reorder GQA and MHA rules
  • Introduce MaskedGQA, since in practice many models generate GQA with a mask
  • Simplify MaskedGQA to ORT's GQA afterwards if the mask can be verified to be causal
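
As a rough illustration of the last point, here is a minimal NumPy sketch of what "verifying that a mask is causal" can mean for a concrete mask. It is not the rewriter's actual check, which operates on the ONNX graph and is not shown in this description. The function name `is_causal_mask` and the boolean convention (True means "may attend") are assumptions made for this example.

```python
import numpy as np


def is_causal_mask(mask: np.ndarray) -> bool:
    """Return True if a square boolean mask is exactly the lower-triangular causal mask.

    Assumed convention: mask[i, j] is True when query position i may attend
    to key position j.
    """
    if mask.ndim != 2 or mask.shape[0] != mask.shape[1]:
        return False
    expected = np.tril(np.ones(mask.shape, dtype=bool))
    return bool(np.array_equal(mask.astype(bool), expected))


causal = np.tril(np.ones((4, 4), dtype=bool))
padded = causal.copy()
padded[:, 0] = False  # e.g. a padding mask that hides the first key position

print(is_causal_mask(causal))  # True
print(is_causal_mask(padded))  # False
```

Only in the first case could a masked attention op be replaced by one that applies causal masking implicitly; in the second, the mask carries extra information and has to be kept.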

Signed-off-by: Ganesan Ramalingam <[email protected]>

codecov bot commented Jun 12, 2025

❌ 21 Tests Failed:

| Tests completed | Failed | Passed | Skipped |
| --- | --- | --- | --- |
| 16427 | 21 | 16406 | 2359 |
View the top 3 failed test(s) by shortest run time
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_1161_test_spacetodepth_example
Stack Traces | 0.003s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
C:\hostedtoolcache\windows\Python\3.11.9\x64\Lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_spacetodepth_example'

The above exception was the direct cause of the following exception:
.nox\test_ort_nightly\Lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_spacetodepth_example' (e=No module named 'tests.onnx_backend_test_code.test_spacetodepth_example') (file: 'C:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_spacetodepth_example.py', absolute path: 'C:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_spacetodepth_example.py', current folder: C:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT
E   from onnxscript.onnx_opset import opset13
E   
E   @script()
E   def bck_test_spacetodepth_example(x: FLOAT[1,1,4,6]) -> (FLOAT[1,4,2,3]):
E       y = opset13.SpaceToDepth(x, blocksize=2)
E       return y
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0252_test_clip_splitbounds
Stack Traces | 0.004s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
C:\hostedtoolcache\windows\Python\3.11.9\x64\Lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_clip_splitbounds'

The above exception was the direct cause of the following exception:
.nox\test_ort_nightly\Lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_clip_splitbounds' (e=No module named 'tests.onnx_backend_test_code.test_clip_splitbounds') (file: 'C:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_clip_splitbounds.py', absolute path: 'C:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_clip_splitbounds.py', current folder: C:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT
E   from onnxscript.onnx_opset import opset13
E   
E   @script()
E   def bck_test_clip_splitbounds(x: FLOAT[3], min: FLOAT, max: FLOAT) -> (FLOAT[3]):
E       y = opset13.Clip(x, min, max)
E       return y
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0611_test_maxpool_2d_default
Stack Traces | 0.004s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
C:\hostedtoolcache\windows\Python\3.11.9\x64\Lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_maxpool_2d_default'

The above exception was the direct cause of the following exception:
.nox\test_ort_nightly\Lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_maxpool_2d_default' (e=No module named 'tests.onnx_backend_test_code.test_maxpool_2d_default') (file: 'C:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_maxpool_2d_default.py', absolute path: 'C:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_maxpool_2d_default.py', current folder: C:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT
E   from onnxscript.onnx_opset import opset22
E   
E   @script()
E   def bck_test_maxpool_2d_default(x: FLOAT[1,3,32,32]) -> (FLOAT[1,3,31,31]):
E       y = opset22.MaxPool(x, kernel_shape=[2, 2])
E       return y


@gramalingam gramalingam changed the title from "[DRAFT] Fusion extensions to improve GQA fusion" to "Fusion extensions to improve GQA fusion" Jun 12, 2025
@gramalingam gramalingam marked this pull request as ready for review June 12, 2025 19:25
Contributor

@Copilot Copilot AI left a comment


Pull Request Overview

This PR introduces several fusion extensions to improve GQA fusion, including refactoring of key handling in SDPA/MHA, extended support for cos-sin-cache fusion patterns, and a new MaskedGQA operator with causal mask support.

  • Rename key_transposed parameter to key and add a new key_format attribute in SDPA fusion.
  • Update tests to verify fusion counts with debug flags, and adjust operator usage in GQA fusion.
  • Reorder fusion rule invocations and extend cos-sin-cache handling with optional inv_freq expansion.
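
For context on the cos-sin-cache item, the sketch below shows the rotary-embedding computation as it is commonly written in transformer implementations, in NumPy. It is not the fusion code from this PR. The function name `cos_sin_cache`, the argument shapes, and the `expand_inv_freq` flag are assumptions for illustration. The flag only mimics the optional step where `inv_freq` is explicitly expanded to the batch size before the matrix product.

```python
import numpy as np


def cos_sin_cache(inv_freq, position_ids, expand_inv_freq=True):
    """Compute rotary cos/sin tables.

    inv_freq:     shape (head_dim // 2,)
    position_ids: shape (batch, seq_len)
    returns:      cos, sin, each of shape (batch, seq_len, head_dim // 2)
    """
    pos = position_ids.astype(np.float32)[:, None, :]   # (batch, 1, seq_len)
    freq = inv_freq.astype(np.float32)[None, :, None]   # (1, head_dim // 2, 1)
    if expand_inv_freq:
        # Explicit expansion to the batch size; without it, matmul
        # broadcasting produces the same result.
        freq = np.broadcast_to(freq, (pos.shape[0], freq.shape[1], 1))
    angles = np.matmul(freq, pos)                        # (batch, head_dim // 2, seq_len)
    angles = np.transpose(angles, (0, 2, 1))             # (batch, seq_len, head_dim // 2)
    return np.cos(angles), np.sin(angles)


inv_freq = 1.0 / (10000.0 ** (np.arange(0, 8, 2) / 8))  # head_dim = 8
position_ids = np.arange(5)[None, :]                     # batch = 1, seq_len = 5

cos_a, sin_a = cos_sin_cache(inv_freq, position_ids, expand_inv_freq=True)
cos_b, sin_b = cos_sin_cache(inv_freq, position_ids, expand_inv_freq=False)
print(cos_a.shape)                # (1, 5, 4)
print(np.allclose(cos_a, cos_b))  # True
```

Both variants give the same tables, which is why a pattern matcher can reasonably treat the expansion as optional.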

Reviewed Changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated no comments.

Show a summary per file
| File | Description |
| --- | --- |
| onnxscript/rewriter/ort_fusions/sdpa_via_mha.py | Changed parameter name from key_transposed to key and added a key_format attribute to the SDPA operator. |
| onnxscript/rewriter/ort_fusions/sdpa.py | Updated key shape checking using key_format and removed key_transposed references. |
| onnxscript/rewriter/ort_fusions/mha_test.py | Modified fuse_sdpa invocation to include a debug flag and updated test assertions accordingly. |
| onnxscript/rewriter/ort_fusions/mha.py | Removed redundant transpose operations in rotary embedding handling for non-cross-attention cases. |
| onnxscript/rewriter/ort_fusions/gqa_test.py | Updated test assertions to use assertGreater for fusion counts. |
| onnxscript/rewriter/ort_fusions/gqa.py | Refactored GQA fusion to use MaskedGroupQueryAttention and added a new causal mask rule. |
| onnxscript/rewriter/ort_fusions/cos_sin_cache.py | Extended inv_freq handling with an optional expansion and added a TODO for validating expanded_inv_freq shape. |
| onnxscript/rewriter/ort_fusions/_core.py | Adjusted fusion rule invocation order by moving gqa fusion outside of the MHA fusion conditional check. |
Comments suppressed due to low confidence (4)

onnxscript/rewriter/ort_fusions/sdpa.py:115

  • Using an assert for unexpected key_format values may lead to runtime crashes. Consider handling unsupported key_format cases more gracefully, e.g. by returning a match failure with a clear error message.
if key_format == "BHSd":

onnxscript/rewriter/ort_fusions/gqa.py:243

  • There is an inconsistency between using op.MaskedGroupQueryAttention in the pattern and op.GroupQueryAttention in the rewrite with a different domain (_domain). Ensure that the operator selection and domain usage across GQA fusion rules are consistent and intentional.
return op.MaskedGroupQueryAttention(

onnxscript/rewriter/ort_fusions/_core.py:90

  • [nitpick] Moving the gqa fusion invocation outside the MHA fusion conditional may lead to overlapping fusion attempts. Verify that gqa fusion is intended to execute independently when MHA fusion is present.
fusion_count["gqa"] = fuse(fuse_gqa)

onnxscript/rewriter/ort_fusions/cos_sin_cache.py:183

  • [nitpick] A TODO note indicates that expanded_inv_freq's shape is not fully validated. It is recommended to add explicit shape checks to ensure that the expanded_inv_freq matches the expected dimensions.
if expanded_inv_freq is not None:
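
The first suppressed comment above (sdpa.py:115) suggests reporting an unsupported key_format as a match failure instead of asserting. The sketch below illustrates that idea in plain Python and does not use the onnxscript rewriter API. `MatchFailure`, `key_seq_and_head_dims`, and the second layout name `"BHdS"` are hypothetical names chosen for this example; only `"BHSd"` appears in the review.

```python
from dataclasses import dataclass


@dataclass
class MatchFailure:
    """Hypothetical stand-in for a 'pattern did not match' result."""

    reason: str


def key_seq_and_head_dims(key_shape, key_format):
    """Return (seq_len, head_dim) for a rank-4 key, or a MatchFailure.

    "BHSd" is taken to mean (batch, heads, seq_len, head_dim);
    "BHdS" (an assumed name) is the transposed layout.
    """
    if len(key_shape) != 4:
        return MatchFailure(f"expected a rank-4 key, got rank {len(key_shape)}")
    if key_format == "BHSd":
        return key_shape[2], key_shape[3]
    if key_format == "BHdS":
        return key_shape[3], key_shape[2]
    # No assert here: an unknown format rejects the match with a clear reason.
    return MatchFailure(f"unsupported key_format: {key_format!r}")


print(key_seq_and_head_dims((1, 8, 16, 64), "BHSd"))
print(key_seq_and_head_dims((1, 8, 16, 64), "BSHd"))
```

Returning a failure value leaves the graph untouched when the format is unexpected, whereas an assert would abort the whole fusion pass.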

Signed-off-by: Ganesan Ramalingam <[email protected]>
@gramalingam gramalingam enabled auto-merge (squash) June 13, 2025 19:00
@gramalingam
Collaborator Author

@justinchuby your approval expired :-(

@gramalingam gramalingam merged commit 949bc24 into main Jun 13, 2025
28 of 32 checks passed
@gramalingam gramalingam deleted the rama/gqa branch June 13, 2025 22:56