Commit a98a10e

Add fusion rules for com.microsoft.Attention (#2148)
TODO:
- Find a model and create a test case to test this rewrite rule.
- Add RotaryEmbedding to the pattern, incorporating do_rotary.
1 parent 2962a09 commit a98a10e

File tree

5 files changed: +460 −15 lines

onnxscript/rewriter/_fusion_utils.py

Lines changed: 22 additions & 0 deletions (new file)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

from typing import Sequence, Union

from onnxscript import ir

Dim = Union[int, ir.SymbolicDim]


def _check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]) -> bool:
    if val.shape is None:
        return False
    if val.shape.rank() != len(shape):
        return False
    for actual, expected in zip(val.shape, shape):
        if expected not in bindings:
            bindings[expected] = actual  # type: ignore[assignment]
        elif actual != bindings[expected]:
            return False
    return True
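
The helper accumulates bindings for symbolic dimension names: the first time a name such as "D" is seen it is bound to the actual dimension, and any later occurrence must resolve to the same value. A minimal usage sketch, assuming ir.Value and ir.Shape accept the keyword arguments shown here (the concrete sizes are invented for illustration only):

from onnxscript import ir
from onnxscript.rewriter import _fusion_utils

# Hypothetical values, only to illustrate how bindings accumulate.
x = ir.Value(name="x", shape=ir.Shape(["B", "S", 512]))
w = ir.Value(name="w", shape=ir.Shape([512, 1536]))

bindings: dict[str, _fusion_utils.Dim] = {}
assert _fusion_utils._check_shape(bindings, x, ["B", "S", "D"])  # binds B, S, and D=512
assert _fusion_utils._check_shape(bindings, w, ["D", "Dh"])  # D already bound to 512; binds Dh=1536
assert not _fusion_utils._check_shape(bindings, w, ["D", "D"])  # second dim 1536 conflicts with D=512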

onnxscript/rewriter/ort_fusions/_core.py

Lines changed: 2 additions & 0 deletions

@@ -12,6 +12,7 @@
     instance_to_group_normalization,
     softmax,
 )
+from onnxscript.rewriter.ort_fusions.attention import fuse_attention
 from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache
 from onnxscript.rewriter.ort_fusions.gelu import fuse_gelu
 from onnxscript.rewriter.ort_fusions.mha import fuse_mha
@@ -53,6 +54,7 @@ def fuse_xformers(model: ir.Model) -> ir.Model:
     fuse_cos_sin_cache(model)
     fuse_sdpa(model)
     fuse_mha(model)
+    fuse_attention(model)
     fuse_gelu(model)
     # Finally: inline any intermediate fusion functions introduced that were not
     # consumed by other fusions, and eliminate any remaining unused nodes.
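
The new fuse_attention pass is registered after fuse_mha because the Attention pattern (in attention.py below) matches the com.microsoft MultiHeadAttention node that the MHA fusion produces. A minimal usage sketch of the updated pipeline, assuming a serialized ONNX model on disk ("model.onnx" and "model_fused.onnx" are placeholder paths):

import onnxscript.ir as ir
from onnxscript.rewriter.ort_fusions._core import fuse_xformers

model = ir.load("model.onnx")  # placeholder input path
model = fuse_xformers(model)  # runs fuse_mha and then fuse_attention, among the other fusions
ir.save(model, "model_fused.onnx")  # placeholder output path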
onnxscript/rewriter/ort_fusions/attention.py

Lines changed: 277 additions & 0 deletions (new file)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

from typing import Sequence, Union

import onnxscript.ir as ir
from onnxscript.rewriter import _fusion_utils, pattern

Dim = Union[int, ir.SymbolicDim]


# TODO: Maybe add this check to utilities


class AttentionFusion(pattern.RewriteRuleClassBase):
    def __init__(self, name, *, has_input_bias: bool, has_past: bool = False):
        super().__init__(name)
        # TODO: We can just pass bias to MultiHeadAttention
        # and let it handle the bias addition, once that pattern is added to MHA
        self._has_input_bias = has_input_bias
        self._has_past = has_past

    def pattern(
        self,
        op,
        input,
        qkv_weight,
        qkv_bias,
        # mask_index,
        past,
        # attention_bias,
        num_heads,
        # scale,
    ):
        projected = op.MatMul(input, qkv_weight)
        # Add bias if present
        if self._has_input_bias:
            projected = op.Add(projected, qkv_bias)

        # Slice packed Matmul QKV into Q, K, and V
        # Q, K, and V are of shape (B, S, D)
        query_BSD = op.Slice(
            projected,
            _allow_other_inputs=True,
            _outputs=["query_mm_sliced"],
        )
        key_BSD = op.Slice(
            projected,
            _allow_other_inputs=True,
            _outputs=["key_mm_sliced"],
        )
        value_BSD = op.Slice(
            projected,
            _allow_other_inputs=True,
            _outputs=["value_mm_sliced"],
        )

        # TODO: Add other attributes

        if self._has_past:
            # Split past into past_key and past_value
            # past_key and past_value are of shape (B, H, S, D/H)
            past_key = op.Slice(
                past,
                _allow_other_inputs=True,
                _outputs=["past_key_sliced"],
            )
            past_key = op.Squeeze(past_key, [0])
            past_value = op.Slice(
                past,
                _allow_other_inputs=True,
                _outputs=["past_value_sliced"],
            )
            past_value = op.Squeeze(past_value, [0])

            attention, present_key, present_value = op.MultiHeadAttention(
                query_BSD,
                key_BSD,
                value_BSD,
                None,  # bias
                None,  # key_padding_mask
                None,  # attention_bias,
                past_key,
                past_value,
                num_heads=num_heads,
                # scale=scale,
                _domain="com.microsoft",
                _outputs=3,
            )
            # Concat present_key and present_value to form present
            present_key = op.Unsqueeze(present_key, [0])
            present_value = op.Unsqueeze(present_value, [0])
            present = op.Concat(present_key, present_value, axis=0)
            # Return present output first as it captures the complete pattern graph
            return present, attention
        else:
            attention = op.MultiHeadAttention(
                query_BSD,
                key_BSD,
                value_BSD,
                # bias
                # key_padding_mask
                # attention_bias,
                # past_key
                # past_value
                num_heads=num_heads,
                # scale=scale,
                _domain="com.microsoft",
                _outputs=1,
            )
            return attention

    def check(
        self,
        op,
        input,
        qkv_weight,
        qkv_bias,
        query_mm_sliced,
        key_mm_sliced,
        value_mm_sliced,
        **_,
    ):
        check_result = pattern.MatchResult()
        self.bindings: dict[str, Dim] = {}

        def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
            return not _fusion_utils._check_shape(self.bindings, val, dims)

        if no_match(input, ["B", "S", "D"]):
            return check_result.fail(
                f"Shape mismatch: {input} does not match expected dimensions ['B', 'S', 'D']",
                input,
            )
        if no_match(qkv_weight, ["D", "Dh"]):
            return check_result.fail(
                f"Shape mismatch: {qkv_weight} does not match expected dimensions ['D', 'Dh']",
                qkv_weight,
            )
        if no_match(qkv_bias, ["Dh"]):
            return check_result.fail(
                f"Shape mismatch: {qkv_bias} does not match expected dimensions ['Dh']",
                qkv_bias,
            )
        if no_match(query_mm_sliced, ["B", "S", "Dh_q"]):
            return check_result.fail(
                f"Shape mismatch: {query_mm_sliced} does not match expected dimensions ['B', 'S', 'Dh_q']",
                query_mm_sliced,
            )
        if no_match(key_mm_sliced, ["B", "S", "Dh_k"]):
            return check_result.fail(
                f"Shape mismatch: {key_mm_sliced} does not match expected dimensions ['B', 'S', 'Dh_k']",
                key_mm_sliced,
            )
        if no_match(value_mm_sliced, ["B", "S", "Dh_v"]):
            return check_result.fail(
                f"Shape mismatch: {value_mm_sliced} does not match expected dimensions ['B', 'S', 'Dh_v']",
                value_mm_sliced,
            )

        # Ensure Dh = Dh_q + Dh_k + Dh_v
        Dh = self.bindings.get("Dh")
        Dh_q = self.bindings.get("Dh_q")
        Dh_k = self.bindings.get("Dh_k")
        Dh_v = self.bindings.get("Dh_v")

        if (
            not isinstance(Dh, int)
            or not isinstance(Dh_q, int)
            or not isinstance(Dh_k, int)
            or not isinstance(Dh_v, int)
        ):
            return check_result.fail(
                "Could not determine the hidden sizes of query, key, and value.",
            )

        if Dh != Dh_q + Dh_k + Dh_v:  # type: ignore[operator]
            return check_result.fail(
                f"Hidden size of query, key and value do not add up to hidden size: {Dh} != {Dh_q} + {Dh_k} + {Dh_v}",
            )

        # TODO: Add mask check once mask is added to the pattern
        return check_result

    def rewrite(
        self,
        op,
        input,
        qkv_weight,
        qkv_bias,
        # mask_index,
        past,
        # attention_bias,
        num_heads,
        # scale,
        **_,
    ):
        # Use bindings to get the values of Dh_q, Dh_k, and Dh_v
        # and construct qkv_hidden_sizes
        Dh_q = self.bindings.get("Dh_q")
        Dh_k = self.bindings.get("Dh_k")
        Dh_v = self.bindings.get("Dh_v")
        qkv_hidden_sizes = [Dh_q, Dh_k, Dh_v]

        if self._has_past:
            attention, present = op.Attention(
                input,
                qkv_weight,
                qkv_bias,
                None,  # mask_index
                past,
                # attention_bias,
                # past_sequence_length
                num_heads=num_heads,
                qkv_hidden_sizes=qkv_hidden_sizes,
                # scale=scale,
                _domain="com.microsoft",
                _outputs=2,
            )
            # Use same output ordering as in pattern
            return present, attention
        else:
            return op.Attention(
                input,
                qkv_weight,
                qkv_bias,
                # mask_index
                # past
                # attention_bias,
                # past_sequence_length
                num_heads=num_heads,
                qkv_hidden_sizes=qkv_hidden_sizes,
                # scale=scale,
                _domain="com.microsoft",
                _outputs=1,
            )


attention = AttentionFusion.rule(
    "attention",
    has_input_bias=False,
    has_past=False,
)
attention_with_bias = AttentionFusion.rule(
    "attention_with_bias",
    has_input_bias=True,
    has_past=False,
)
attention_with_past = AttentionFusion.rule(
    "attention_with_past",
    has_input_bias=False,
    has_past=True,
)
attention_with_bias_and_past = AttentionFusion.rule(
    "attention_with_bias_and_past",
    has_input_bias=True,
    has_past=True,
)

attention_rules = pattern.RewriteRuleSet(
    [
        attention,
        attention_with_bias,
        attention_with_past,
        attention_with_bias_and_past,
    ]
)


def fuse_attention(model: ir.Model, *, debug: bool = False) -> int:
    count = attention_rules.apply_to_model(model)
    if debug and count == 0:
        tracer = pattern.MatchingTracer()
        attention_rules.apply_to_model(model, tracer=tracer)
        tracer.report()
    return count
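
The commit message leaves finding a model and writing a test as a TODO. A hedged sketch of what such a test could look like, assuming a hypothetical test asset whose graph already contains the packed-QKV MatMul, the three Slice nodes, and a com.microsoft MultiHeadAttention node (the file name and the exact assertions are placeholders, not part of this commit):

import unittest

import onnxscript.ir as ir
from onnxscript.rewriter.ort_fusions.attention import fuse_attention


class AttentionFusionTest(unittest.TestCase):
    def test_attention_fusion(self):
        # Hypothetical test asset containing the pattern matched above.
        model = ir.load("attention_test_model.onnx")
        count = fuse_attention(model, debug=True)
        self.assertGreater(count, 0)
        op_types = [node.op_type for node in model.graph]
        self.assertIn("Attention", op_types)
        self.assertNotIn("MultiHeadAttention", op_types)


if __name__ == "__main__":
    unittest.main()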
