Skip to content

SDPA fusion cleanup #2352

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Jun 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
5374b3d
Partial fixes to SDPA
gramalingam May 28, 2025
ac2d91d
Use disjunction in SDPA fusion
gramalingam May 28, 2025
e9fefba
Remove rank check
gramalingam May 28, 2025
01f7b21
Remove debug
gramalingam May 28, 2025
151244a
Add type annotations
gramalingam May 28, 2025
d77c4a6
Merge branch 'main' into rama/sdpa
gramalingam May 28, 2025
366e167
Add missing shapes to gqa_test
gramalingam May 29, 2025
f32b823
Cleanup match failure handling
gramalingam May 29, 2025
938e5d0
Add value to match error
gramalingam May 29, 2025
1afaa16
Merge branch 'main' into rama/sdpa
gramalingam May 29, 2025
d0eb6fa
Cleanup duplicated code
gramalingam May 29, 2025
375adff
Merge branch 'rama/sdpa' of https://github.com/microsoft/onnx-script …
gramalingam May 29, 2025
b145017
Remove outdated comment
gramalingam May 29, 2025
c5ae7f4
Fix renaming
gramalingam May 29, 2025
2cd57fb
Update onnxscript/rewriter/_fusion_utils.py
gramalingam May 31, 2025
ae4ff4a
Potential fix for code scanning alert no. 17154: Unused local variable
gramalingam May 31, 2025
ba62323
Fix handling of unknown headsize
gramalingam May 31, 2025
dc751fb
Merge branch 'main' into rama/sdpa
gramalingam May 31, 2025
2a46711
Update onnxscript/rewriter/ort_fusions/sdpa.py
gramalingam Jun 2, 2025
5792af3
Add negative SDPA test case
gramalingam Jun 3, 2025
2b17200
Address PR feedback
gramalingam Jun 3, 2025
556e7ff
Address PR feedback
gramalingam Jun 3, 2025
06d3e9a
Merge branch 'main' into rama/sdpa
gramalingam Jun 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions onnxscript/rewriter/_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,42 @@
import onnxscript.rewriter._rewrite_rule as _rewrite_rule


class MatchFailureInfo:
    """Describes why a pattern match failed and which IR objects are to blame.

    Attributes are plain data:
      * ``reason`` -- free-form explanation of the failure.
      * ``failure_sources`` -- the nodes/values passed as ``failure_source``
        (an empty tuple when none were given).
    """

    def __init__(
        self,
        reason: str = "",
        *failure_source: ir.Node | ir.Value,
    ):
        self.reason = reason
        self.failure_sources: tuple[ir.Node | ir.Value, ...] = failure_source
        # Sanity check (active only when assertions are enabled): every
        # reported source has to be an IR node or an IR value.
        for source in failure_source:
            assert isinstance(source, (ir.Node, ir.Value)), (
                f"All items in failure_source must be ir.Node or ir.Value, got {[type(item) for item in failure_source]}"
            )

    def __str__(self):
        reason = self.reason
        sources = self.failure_sources
        return f"MatchFailureInfo(reason={reason!r}, failure_sources={sources!r})"

Check warning on line 34 in onnxscript/rewriter/_basics.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/_basics.py#L34

Added line #L34 was not covered by tests


class MatchFailureError(MatchFailureInfo, Exception):
    """Match failure reported by raising instead of returning.

    Helper functions used while checking a pattern's condition can raise this
    error rather than threading a failure result back through every caller,
    which keeps such helpers easy to compose.
    """

    def __init__(
        self,
        reason: str = "",
        *failure_source: ir.Node | ir.Value,
    ):
        # Each base class is initialised explicitly: the exception receives
        # only the reason text, while the failure info also records the
        # offending nodes/values.
        Exception.__init__(self, reason)
        MatchFailureInfo.__init__(self, reason, *failure_source)


class MatchResult:
"""The state object used by the pattern-matching algorithm.

Expand Down
19 changes: 19 additions & 0 deletions onnxscript/rewriter/_fusion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import onnxscript.ir as ir
import onnxscript.ir.passes.common as common_passes
from onnxscript.rewriter import pattern
from onnxscript.rewriter._basics import MatchFailureError

Dim = Union[int, ir.SymbolicDim]

Expand All @@ -24,6 +25,24 @@
return True


def check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]) -> None:
    """Check the shape of ``val`` against a list of named dimensions.

    Each entry of ``shape`` is a dimension name. The first time a name is
    seen it is bound, in ``bindings``, to the corresponding dimension of
    ``val``; if the name is already bound, the dimension of ``val`` must be
    equal to the bound one.

    Args:
        bindings: Mapping from dimension names to dimensions. Updated in
            place with names that were not bound yet.
        val: The value whose shape is checked.
        shape: Expected dimension names, one per axis.

    Raises:
        MatchFailureError: If the shape of ``val`` is unknown, its rank
            differs from ``len(shape)``, or a dimension disagrees with the
            dimension already bound to the same name.
    """
    if val.shape is None:
        raise MatchFailureError(f"The shape of {val} is unknown.", val)
    rank = val.shape.rank()
    if rank != len(shape):
        raise MatchFailureError(
            f"The rank of {val} ({rank}) does not match the expected rank {len(shape)}.",
            val,
        )
    for i, (actual, expected) in enumerate(zip(val.shape, shape)):
        if expected not in bindings:
            # First occurrence of this name: remember the dimension it denotes.
            bindings[expected] = actual  # type: ignore[assignment]
        elif actual != bindings[expected]:
            raise MatchFailureError(
                f"Dimension {i} of {val} ({actual}) does not have expected size ({bindings[expected]}).",
                val,
            )


def apply_fusion_rules(rules: pattern.RewriteRule | pattern.RewriteRuleSet) -> Callable:
"""
Apply the given fusion rules to the model and return the number of fusions applied.
Expand Down
6 changes: 5 additions & 1 deletion onnxscript/rewriter/_rewrite_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,11 @@ def try_rewrite(
if var.name is not None:
if var.name not in match.bindings:
match.bind(var.name, None)
check_match_result = self._condition_function(context, **match.bindings)
try:
check_match_result = self._condition_function(context, **match.bindings)
except _basics.MatchFailureError as e:
check_match_result = _basics.MatchResult()
check_match_result.fail(e.reason, list(e.failure_sources))
if not check_match_result:
# If check function was provided, but it failed, return the reason for failure to the tracer.
if isinstance(check_match_result, _basics.MatchResult):
Expand Down
13 changes: 13 additions & 0 deletions onnxscript/rewriter/ort_fusions/gqa_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(self, *args, **kwargs):
"num_heads must be divisible by kv_num_heads"
)
self.num_groups = self.num_heads // self.kv_num_heads
self.total_seqlen = self.seqlen + self.past_seqlen

# Abbreviations
B = self.batchsize
Expand Down Expand Up @@ -311,12 +312,24 @@ def test_fusion(self):
onnx.TensorProto.FLOAT,
["B", self.seqlen, self.kv_num_heads, self.head_size],
)
key_transposed_value_info = onnx.helper.make_tensor_value_info(
"key_transposed",
onnx.TensorProto.FLOAT,
["B", self.num_heads, self.head_size, self.total_seqlen],
)
value_BHSDh_value_info = onnx.helper.make_tensor_value_info(
"value_BHSDh",
onnx.TensorProto.FLOAT,
["B", self.num_heads, self.total_seqlen, self.head_size],
)
source_model.graph.value_info.extend(
[
query_BHSDh_rope_value_info,
key_BHkvSDh_rope_value_info,
query_BSHDh_value_info,
key_BSHkvDh_value_info,
key_transposed_value_info,
value_BHSDh_value_info,
]
)

Expand Down
Loading
Loading