Fusion extensions to improve GQA fusion #2374
Signed-off-by: Ganesan Ramalingam <[email protected]>
❌ 21 Tests Failed
Pull Request Overview
This PR introduces several fusion extensions to improve GQA fusion, including refactoring of key handling in SDPA/MHA, extended support for cos-sin-cache fusion patterns, and a new MaskedGQA operator with causal mask support.
- Rename key_transposed parameter to key and add a new key_format attribute in SDPA fusion.
- Update tests to verify fusion counts with debug flags, and adjust operator usage in GQA fusion.
- Reorder fusion rule invocations and extend cos-sin-cache handling with optional inv_freq expansion.
Reviewed Changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated no comments.
File | Description |
---|---|
onnxscript/rewriter/ort_fusions/sdpa_via_mha.py | Changed parameter name from key_transposed to key and added a key_format attribute to the SDPA operator. |
onnxscript/rewriter/ort_fusions/sdpa.py | Updated key shape checking using key_format and removed key_transposed references. |
onnxscript/rewriter/ort_fusions/mha_test.py | Modified fuse_sdpa invocation to include a debug flag and updated test assertions accordingly. |
onnxscript/rewriter/ort_fusions/mha.py | Removed redundant transpose operations in rotary embedding handling for non-cross-attention cases. |
onnxscript/rewriter/ort_fusions/gqa_test.py | Updated test assertions to use assertGreater for fusion counts. |
onnxscript/rewriter/ort_fusions/gqa.py | Refactored GQA fusion to use MaskedGroupQueryAttention and added a new causal mask rule. |
onnxscript/rewriter/ort_fusions/cos_sin_cache.py | Extended inv_freq handling with an optional expansion and added a TODO for validating expanded_inv_freq shape. |
onnxscript/rewriter/ort_fusions/_core.py | Adjusted fusion rule invocation order by moving gqa fusion outside of the MHA fusion conditional check. |
Comments suppressed due to low confidence (4)
onnxscript/rewriter/ort_fusions/sdpa.py:115
- Using an assert for unexpected key_format values may lead to runtime crashes. Consider handling unsupported key_format cases more gracefully, e.g. by returning a match failure with a clear error message.
if key_format == "BHSd":
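The suggestion above (report a match failure instead of asserting) can be sketched as follows. This is a minimal, self-contained illustration and not the code in sdpa.py: `CheckResult` and `check_key_shape` are hypothetical names, and the second layout `"BHdS"` is an assumption; only `"BHSd"` appears in the review comment.

```python
from dataclasses import dataclass


@dataclass
class CheckResult:
    """Hypothetical stand-in for a rewriter match result (not an onnxscript class)."""

    ok: bool
    reason: str = ""


def check_key_shape(key_shape, key_format, head_size):
    """Validate the head-size axis of a rank-4 key without asserting.

    key_format "BHSd" is taken from the review comment; "BHdS" (the
    transposed layout) is an assumed second option for illustration.
    """
    if len(key_shape) != 4:
        return CheckResult(False, f"expected rank-4 key, got rank {len(key_shape)}")
    if key_format == "BHSd":
        actual = key_shape[3]  # head size is the last axis
    elif key_format == "BHdS":
        actual = key_shape[2]  # head size is the third axis
    else:
        # Unsupported format: fail the match with a message instead of crashing.
        return CheckResult(False, f"unsupported key_format: {key_format!r}")
    if actual != head_size:
        return CheckResult(False, f"head size mismatch: expected {head_size}, got {actual}")
    return CheckResult(True)
```

With this shape of check, an unexpected `key_format` makes the rule decline to fire and leaves a readable reason for debugging, rather than aborting the whole fusion pass.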
onnxscript/rewriter/ort_fusions/gqa.py:243
- There is an inconsistency between using op.MaskedGroupQueryAttention in the pattern and op.GroupQueryAttention in the rewrite with a different domain (_domain). Ensure that the operator selection and domain usage across GQA fusion rules are consistent and intentional.
return op.MaskedGroupQueryAttention(
onnxscript/rewriter/ort_fusions/_core.py:90
- [nitpick] Moving the gqa fusion invocation outside the MHA fusion conditional may lead to overlapping fusion attempts. Verify that gqa fusion is intended to execute independently when MHA fusion is present.
fusion_count["gqa"] = fuse(fuse_gqa)
onnxscript/rewriter/ort_fusions/cos_sin_cache.py:183
- [nitpick] A TODO note indicates that expanded_inv_freq's shape is not fully validated. It is recommended to add explicit shape checks to ensure that the expanded_inv_freq matches the expected dimensions.
if expanded_inv_freq is not None:
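An explicit shape check of the kind recommended above might look like the sketch below. It is a generic helper and not the code in cos_sin_cache.py: `shape_matches` is a hypothetical name, and the expected shape used in the usage note is an assumption, since the source does not state what shape `expanded_inv_freq` should have.

```python
def shape_matches(actual, expected):
    """Compare a concrete shape against an expected one.

    Both are sequences of ints; None in `expected` is a wildcard for
    that axis. Returns (ok, reason) so a caller can report why a
    pattern did not match instead of raising.
    """
    if actual is None:
        return False, "shape is unknown"
    if len(actual) != len(expected):
        return False, f"rank mismatch: expected {len(expected)}, got {len(actual)}"
    for axis, (a, e) in enumerate(zip(actual, expected)):
        if e is not None and a != e:
            return False, f"axis {axis}: expected {e}, got {a}"
    return True, ""
```

For example, if the expanded tensor were expected to be `(batch, N, 1)` for an `inv_freq` of length `N` (an assumed layout, for illustration only), the check would be `shape_matches(shape, (None, N, 1))`.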
@justinchuby your approval expired :-(
Various extensions to improve GQA fusion.