Skip to content

Commit e49620a

Browse files
gramalingam, Copilot, github-advanced-security[bot], and titaiwangms
authored
SDPA fusion cleanup (#2352)
Remove the need for many different rules for SDPA fusion by (a) Using pattern-disjunction, and (b) Simplifying the handling of scaling factors which can occur in several forms (using either multiplication or division, either separately to query and/or key, or to the product of query and key). Also: simplify the way shapes are checked and error messages are generated. --------- Signed-off-by: Ganesan Ramalingam <[email protected]> Co-authored-by: Copilot <[email protected]> Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> Co-authored-by: Ti-Tai Wang <[email protected]>
1 parent 4e526f7 commit e49620a

File tree

6 files changed

+226
-180
lines changed

6 files changed

+226
-180
lines changed

onnxscript/rewriter/_basics.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,42 @@
1616
import onnxscript.rewriter._rewrite_rule as _rewrite_rule
1717

1818

19+
class MatchFailureInfo:
    """Record describing why a pattern match failed.

    Stores a textual reason together with the IR nodes/values that were
    passed in as the source(s) of the failure.
    """

    def __init__(self, reason: str = "", *failure_source: ir.Node | ir.Value):
        """Store the failure reason and the offending IR objects.

        Args:
            reason: Explanation of the failure.
            *failure_source: Zero or more ``ir.Node`` / ``ir.Value`` objects
                associated with the failure; kept as a tuple in
                ``failure_sources``.
        """
        self.reason = reason
        self.failure_sources: tuple[ir.Node | ir.Value, ...] = failure_source
        # Sanity check only (skipped under ``python -O``): every source must be
        # an IR node or value.
        sources_ok = all(
            isinstance(source, (ir.Node, ir.Value)) for source in failure_source
        )
        assert sources_ok, (
            f"All items in failure_source must be ir.Node or ir.Value, got {[type(item) for item in failure_source]}"
        )

    def __str__(self):
        # The class name is spelled out literally, so subclasses that inherit
        # this method are also rendered as "MatchFailureInfo(...)".
        details = f"reason={self.reason!r}, failure_sources={self.failure_sources!r}"
        return f"MatchFailureInfo({details})"
35+
36+
37+
class MatchFailureError(MatchFailureInfo, Exception):
    """Exception flavour of ``MatchFailureInfo``, raised when a match fails.

    Raising the failure instead of returning it lets helper functions used
    while checking a match (for example in the condition-checking phase) be
    composed freely: callers do not have to test for, and forward, a failure
    result after every helper call.
    """

    def __init__(self, reason: str = "", *failure_source: ir.Node | ir.Value):
        """Initialise both base classes.

        Args:
            reason: Explanation of the failure; also used as the exception
                argument.
            *failure_source: The ``ir.Node`` / ``ir.Value`` objects associated
                with the failure.
        """
        # Each base is initialised explicitly: MatchFailureInfo records the
        # reason and the sources, while Exception is given only the reason.
        MatchFailureInfo.__init__(self, reason, *failure_source)
        Exception.__init__(self, reason)
53+
54+
1955
class MatchResult:
2056
"""The state object used by the pattern-matching algorithm.
2157

onnxscript/rewriter/_fusion_utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import onnxscript.ir as ir
88
import onnxscript.ir.passes.common as common_passes
99
from onnxscript.rewriter import pattern
10+
from onnxscript.rewriter._basics import MatchFailureError
1011

1112
Dim = Union[int, ir.SymbolicDim]
1213

@@ -24,6 +25,24 @@ def _check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str])
2425
return True
2526

2627

28+
def check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]) -> None:
    """Check that ``val`` has the given symbolic shape, raising on mismatch.

    Each entry of ``shape`` is a dimension name. A name not yet present in
    ``bindings`` is bound to the corresponding dimension of ``val``; a name
    that is already bound must equal that dimension.

    Args:
        bindings: Mapping from dimension name to its bound dimension. It is
            updated in place, and may already have been extended with new
            names when a later dimension causes a failure.
        val: The value whose ``shape`` is checked.
        shape: Expected dimension names, one per axis.

    Raises:
        MatchFailureError: If the shape of ``val`` is unknown, its rank differs
            from ``len(shape)``, or a dimension disagrees with an existing
            binding.
    """
    if val.shape is None:
        raise MatchFailureError(f"The shape of {val} is unknown.", val)

    actual_rank = val.shape.rank()
    if actual_rank != len(shape):
        # Fixed: the closing parenthesis after the actual rank was missing.
        raise MatchFailureError(
            f"The rank of {val} ({actual_rank}) does not match the expected rank {len(shape)}.",
            val,
        )

    for i, (actual, expected) in enumerate(zip(val.shape, shape)):
        if expected not in bindings:
            # First occurrence of this dimension name: record its size.
            bindings[expected] = actual  # type: ignore[assignment]
        elif actual != bindings[expected]:
            raise MatchFailureError(
                f"Dimension {i} of {val} ({actual}) does not have expected size ({bindings[expected]}).",
                val,
            )
44+
45+
2746
def apply_fusion_rules(rules: pattern.RewriteRule | pattern.RewriteRuleSet) -> Callable:
2847
"""
2948
Apply the given fusion rules to the model and return the number of fusions applied.

onnxscript/rewriter/_rewrite_rule.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,11 @@ def try_rewrite(
174174
if var.name is not None:
175175
if var.name not in match.bindings:
176176
match.bind(var.name, None)
177-
check_match_result = self._condition_function(context, **match.bindings)
177+
try:
178+
check_match_result = self._condition_function(context, **match.bindings)
179+
except _basics.MatchFailureError as e:
180+
check_match_result = _basics.MatchResult()
181+
check_match_result.fail(e.reason, list(e.failure_sources))
178182
if not check_match_result:
179183
# If check function was provided, but it failed, return the reason for failure to the tracer.
180184
if isinstance(check_match_result, _basics.MatchResult):

onnxscript/rewriter/ort_fusions/gqa_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def __init__(self, *args, **kwargs):
4444
"num_heads must be divisible by kv_num_heads"
4545
)
4646
self.num_groups = self.num_heads // self.kv_num_heads
47+
self.total_seqlen = self.seqlen + self.past_seqlen
4748

4849
# Abbreviations
4950
B = self.batchsize
@@ -311,12 +312,24 @@ def test_fusion(self):
311312
onnx.TensorProto.FLOAT,
312313
["B", self.seqlen, self.kv_num_heads, self.head_size],
313314
)
315+
key_transposed_value_info = onnx.helper.make_tensor_value_info(
316+
"key_transposed",
317+
onnx.TensorProto.FLOAT,
318+
["B", self.num_heads, self.head_size, self.total_seqlen],
319+
)
320+
value_BHSDh_value_info = onnx.helper.make_tensor_value_info(
321+
"value_BHSDh",
322+
onnx.TensorProto.FLOAT,
323+
["B", self.num_heads, self.total_seqlen, self.head_size],
324+
)
314325
source_model.graph.value_info.extend(
315326
[
316327
query_BHSDh_rope_value_info,
317328
key_BHkvSDh_rope_value_info,
318329
query_BSHDh_value_info,
319330
key_BSHkvDh_value_info,
331+
key_transposed_value_info,
332+
value_BHSDh_value_info,
320333
]
321334
)
322335

0 commit comments

Comments
 (0)