Skip to content

GQA Fusion #2161

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 37 commits into from
Apr 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
88adde4
A couple of MHA extensions
gramalingam Mar 13, 2025
c4c1f71
Run lint
gramalingam Mar 13, 2025
0425174
Minor fixes
gramalingam Mar 13, 2025
4b1a68d
Update GQA
gramalingam Mar 14, 2025
4dbb44c
Minor fixes
gramalingam Mar 14, 2025
248a113
Merge with main
gramalingam Mar 15, 2025
4d73c6e
Switch to new GQA
gramalingam Mar 15, 2025
9c79b98
Fix variable naming
gramalingam Mar 15, 2025
0d0e8ae
Add num heads attributes
gramalingam Mar 17, 2025
b588582
Use seqlens and totalseqlen
gramalingam Mar 18, 2025
3febda2
Add cos and sin cache
gramalingam Mar 18, 2025
afcf0a7
Fix int32 type
gramalingam Mar 19, 2025
03c08c7
GQA fusion
gramalingam Mar 20, 2025
9040172
Merge branch 'rama/GQA2' of https://github.com/microsoft/onnx-script …
gramalingam Mar 20, 2025
0bc603f
Basic GQA test
gramalingam Mar 25, 2025
794f0dd
Minor refactoring
gramalingam Mar 27, 2025
a7ba01b
Switch to script
gramalingam Mar 27, 2025
d68b0e7
Add blank line
gramalingam Mar 27, 2025
e535207
Merge branch 'main' into rama/gqa-basic-test
gramalingam Mar 27, 2025
df5c69d
Add test case with past and rotary
gramalingam Mar 28, 2025
045fc6f
Add new test
gramalingam Mar 29, 2025
edf289f
Cleanup test case
gramalingam Mar 29, 2025
457ac32
Added test with past and rotary
gramalingam Mar 29, 2025
1efdb26
Remove debug print
gramalingam Mar 29, 2025
13a71c0
Minor cleanup
gramalingam Mar 30, 2025
e74000e
Merge with main
gramalingam Mar 31, 2025
e3dadc9
Add causal mask pattern
gramalingam Apr 1, 2025
ff72ac3
Merge branch 'rama/gqa-basic-test' into rama/GQA2
gramalingam Apr 2, 2025
b7a1398
Add test case
gramalingam Apr 2, 2025
3bdc0b2
Complete GQA tests
gramalingam Apr 2, 2025
9545e5f
Cleanup
gramalingam Apr 2, 2025
d19367a
Merge branch 'main' into rama/GQA2
gramalingam Apr 2, 2025
97bb1c2
Address copilot fixes
gramalingam Apr 2, 2025
56adee8
Merge branch 'rama/GQA2' of https://github.com/microsoft/onnx-script …
gramalingam Apr 2, 2025
c8bbb02
Add checks
gramalingam Apr 3, 2025
43b3368
Minor cleanup
gramalingam Apr 3, 2025
78f9243
Merge with main and address PR comments
gramalingam Apr 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions onnxscript/optimizer/_constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,17 +405,13 @@ def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
shape = _get_input(node, 1)
if input is None or shape is None:
return None

input_shape = input.shape
if input_shape is None:
return None
# input_shape_dims = list(input_shape.dims)
# if any(isinstance(dim, ir.SymbolicDim) and dim.value is None for dim in input_shape_dims):
# return None
shape_value = state.get_shape_value(shape)
if shape_value is None:

if shape_value is None or input_shape is None:
return None
# target_shape_dims = list(shape_value.dims)
# if input_shape_dims == target_shape_dims:

# No need to check for special values like -1, 0, etc. here
if _same_shape(input_shape, shape_value):
return op.Identity(input)
Expand Down
3 changes: 3 additions & 0 deletions onnxscript/rewriter/ort_fusions/_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,8 @@ def assert_allclose(outputs, expected_outputs, rtol=1e-4, atol=1e-4):
np.testing.assert_equal(baseline_output.shape, optimized_output.shape)
np.testing.assert_allclose(baseline_output, optimized_output, rtol=rtol, atol=atol)
except AssertionError as e:
diff_mask = ~np.isclose(baseline_output, optimized_output, rtol=rtol, atol=atol)
diff = np.where(diff_mask, "X", " ")
print(diff)
print(f"Failed for output {i} with rtol={rtol} and atol={atol}\n{e}")
raise
Loading
Loading