Skip to content

Commit 489e6b7

Browse files
A couple of extensions to MHA fusion (#2106)
A couple of extensions to MHA fusion: * One deals with variations in positions-ids. The challenge is to verify that the position-ids used in the two RotaryEmbedding are the same. In some models, they are the same value (by reference). In some models, there is some duplication of the code in computing the 2D position-id from 1D position-id. If we had a common-sub-expression identification/elimination, that would help. For now, just handling it in the pattern itself. * The second deals with variations in how the last two axes of key are transposed. Some models reshape the input tensor to 3D and do the transpose, while some directly transpose a 4D tensor. --------- Co-authored-by: Justin Chu <[email protected]>
1 parent f6efc7c commit 489e6b7

File tree

1 file changed

+36
-13
lines changed
  • onnxscript/rewriter/ort_fusions

1 file changed

+36
-13
lines changed

onnxscript/rewriter/ort_fusions/mha.py

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,9 @@ def _check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str])
4545

4646

4747
class MultiHeadAttention(pattern.RewriteRuleClassBase):
48-
def __init__(self):
49-
super().__init__("MHA")
48+
def __init__(self, name, *, transpose_4d: bool):
49+
super().__init__(name)
50+
self._transpose_4d = transpose_4d
5051

5152
def pattern(
5253
self,
@@ -93,25 +94,42 @@ def pattern(
9394
# Transpose from (B, S, H, D/H) to (B, H, S, D/H)
9495
value_BHSDh = op.Transpose(value_BSHDh, perm=[0, 2, 1, 3])
9596

97+
# This is workaround for examples where there is a duplication of Unsqueeze op
98+
# to generate a 2D positions-ids from a 1D position-ids. This can be eliminated
99+
# if we have CSE-optimization to eliminate the duplicate Unsqueeze ops.
100+
# For now, same flag (transpose_4d) controls this variation. A different flag
101+
# can be added if we see instances that mix the two.
102+
if self._transpose_4d:
103+
position_ids_q = op.Unsqueeze(position_ids, [0])
104+
position_ids_k = op.Unsqueeze(position_ids, [0])
105+
else:
106+
position_ids_q = position_ids
107+
position_ids_k = position_ids
108+
96109
query_BHSDh_rope = op.RotaryEmbedding(
97-
query_BHSDh, position_ids, cos, sin, _domain="com.microsoft"
110+
query_BHSDh, position_ids_q, cos, sin, _domain="com.microsoft"
98111
)
112+
99113
key_BHSDh_rope = op.RotaryEmbedding(
100-
key_BHSDh, position_ids, cos, sin, _domain="com.microsoft"
114+
key_BHSDh, position_ids_k, cos, sin, _domain="com.microsoft"
101115
)
102116

103117
# Concatenate past_key cache and current key, and transpose to enable
104118
# dot-product attention computation.
105119

106120
key_seq = op.Concat(past_key, key_BHSDh_rope, axis=-2)
107121
# Transpose last two axes of key_seq to compute dot-product via matmul.
108-
key_seq_BH_Skv_Dh = op.Reshape(
109-
key_seq, _allow_other_inputs=True, _outputs=["key_seq_BH_Skv_Dh"]
110-
)
111-
key_seq_BH_Dh_Skv = op.Transpose(key_seq_BH_Skv_Dh, perm=[0, 2, 1])
112-
key_seq_B_H_Dh_Skv = op.Reshape(
113-
key_seq_BH_Dh_Skv, _allow_other_inputs=True, _outputs=["key_seq_B_H_Dh_Skv"]
114-
)
122+
if self._transpose_4d:
123+
key_seq_B_H_Dh_Skv = op.Transpose(key_seq, perm=[0, 1, 3, 2])
124+
else:
125+
# Transpose after converting to 3D
126+
key_seq_BH_Skv_Dh = op.Reshape(
127+
key_seq, _allow_other_inputs=True, _outputs=["key_seq_BH_Skv_Dh"]
128+
)
129+
key_seq_BH_Dh_Skv = op.Transpose(key_seq_BH_Skv_Dh, perm=[0, 2, 1])
130+
key_seq_B_H_Dh_Skv = op.Reshape(
131+
key_seq_BH_Dh_Skv, _allow_other_inputs=True, _outputs=["key_seq_B_H_Dh_Skv"]
132+
)
115133

116134
# Concatenate past_value cache and current value
117135
value_seq = op.Concat(past_value, value_BHSDh, axis=-2)
@@ -198,6 +216,10 @@ def rewrite(
198216

199217
# Switch to 3D RotaryEmbedding
200218
# TODO: forward other attributes
219+
220+
if self._transpose_4d:
221+
zero_1d = op.Constant(value_ints=[0])
222+
position_ids = op.Unsqueeze(position_ids, zero_1d)
201223
query_BSD_rope = op.RotaryEmbedding(
202224
query_BSD, position_ids, cos, sin, _domain="com.microsoft"
203225
)
@@ -220,9 +242,10 @@ def rewrite(
220242
)
221243

222244

223-
_rule1 = MultiHeadAttention.rule()
245+
_mha_4d_transpose = MultiHeadAttention.rule("MHA_4D_Transpose", transpose_4d=True)
246+
_mha_3d_transpose = MultiHeadAttention.rule("MHA_3D_Transpose", transpose_4d=False)
224247

225-
mha_rules = pattern.RewriteRuleSet([_rule1])
248+
mha_rules = pattern.RewriteRuleSet([_mha_4d_transpose, _mha_3d_transpose])
226249

227250

228251
def fuse_mha(model: ir.Model, *, debug: bool = False) -> int:

0 commit comments

Comments
 (0)