
Commit c634313

[DRAFT] Generalize MHA pattern (#2092)
Generalize the MHA pattern (motivated by the Phi models). Specifically, we remove the initial MatMuls from the pattern, since they are unnecessary for the fusion: Phi uses a packed MatMul, in which Q, K, and V are computed with a single MatMul and then sliced apart. This is not yet sufficient for Phi, however, because Phi also uses partial rotary embedding, which the RotaryEmbedding pattern does not yet support. I will work separately on extending the RotaryEmbedding pattern to handle partial embedding.
1 parent db02e3f commit c634313
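
For context, here is a minimal sketch (in plain NumPy, with made-up sizes, not Phi's actual configuration) of the packed-QKV projection described above: a single MatMul produces the concatenated Q/K/V projection, which is then sliced. Because the generalized pattern no longer anchors on per-projection MatMuls, it can match starting from these (B, S, D) slices.

    import numpy as np

    # Illustrative sizes only; not Phi's actual configuration.
    B, S, D = 2, 16, 64  # batch, sequence length, embedding dimension

    x = np.random.rand(B, S, D).astype(np.float32)        # input of shape (B, S, D)
    W_qkv = np.random.rand(D, 3 * D).astype(np.float32)   # packed projection weight

    # One packed MatMul instead of three separate Q/K/V MatMuls ...
    qkv = x @ W_qkv                                        # (B, S, 3*D)

    # ... followed by slicing into the individual projections.
    query = qkv[..., 0 * D : 1 * D]                        # (B, S, D)
    key = qkv[..., 1 * D : 2 * D]                          # (B, S, D)
    value = qkv[..., 2 * D : 3 * D]                        # (B, S, D)

    assert query.shape == key.shape == value.shape == (B, S, D)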

5 files changed (+176 / -110 lines)

onnxscript/rewriter/_ir_utils.py

Lines changed: 14 additions & 0 deletions
@@ -124,3 +124,17 @@ def has_rank(value: ir.Value | None, rank: int) -> bool:
         return False
     shape = value.shape
     return (shape is not None) and (shape.rank() == rank)
+
+
+def get_dim(value: ir.Value | None, dim: int) -> ir.SymbolicDim | int | None:
+    """Returns the value of the given dimension, or None if it is not statically known."""
+    if value is None:
+        return None
+    shape = value.shape
+    if shape is None:
+        return None
+    if dim < 0:
+        dim += shape.rank()
+    if dim < 0 or dim >= shape.rank():
+        return None
+    return shape[dim]
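
A small usage sketch of the new get_dim helper, assuming the ir.Value/ir.Shape constructors accept the keyword arguments shown (the value and its shape are made up). Symbolic dimensions come back as ir.SymbolicDim rather than int, which is why the MHA rewrite below checks isinstance(num_heads, int) before fusing.

    from onnxscript import ir
    from onnxscript.rewriter import _ir_utils

    # Hypothetical value of shape (B, S, H, Dh) where only H and Dh are static.
    val = ir.Value(
        name="query_BSHDh",
        shape=ir.Shape(["B", "S", 32, 80]),
        type=ir.TensorType(ir.DataType.FLOAT),
    )

    num_heads = _ir_utils.get_dim(val, 2)     # 32 (a plain int)
    batch = _ir_utils.get_dim(val, 0)         # ir.SymbolicDim("B"), not statically known
    out_of_range = _ir_utils.get_dim(val, 5)  # None

    if isinstance(num_heads, int):
        ...  # safe to use, e.g., as the num_heads attribute of a fused op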

onnxscript/rewriter/llama_rule_sets.py

Lines changed: 1 addition & 0 deletions
@@ -304,5 +304,6 @@ def llama_p0_rule_set() -> orp.RewriteRuleSet:
             transpose_identity_rule,
             transpose_transpose_rule,
             unsqueeze_unsqueeze_rule,
+            squeeze_reshape_1d_rule,
         ]
     )

onnxscript/rewriter/ort_fusions/_test_utils.py

Lines changed: 4 additions & 0 deletions
@@ -8,6 +8,7 @@
 import numpy as np
 import onnx
 import onnxruntime
+import packaging.version
 
 import onnxscript.ir as ir
 import onnxscript.ir._io as io
@@ -21,6 +22,9 @@ def _save(model, modelpath):
     io.save(model, modelpath)
 
 
+ORT_VERSION = packaging.version.Version(onnxruntime.__version__)
+
+
 def ort_run(model_name: str, model, inputs):
     providers = ["CPUExecutionProvider"]
     with tempfile.TemporaryDirectory() as temp_dir:
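
Presumably the new ORT_VERSION constant exists so tests can be skipped on older onnxruntime releases. A hedged sketch of that kind of gating (the threshold version and test name are made up):

    import unittest

    import onnxruntime
    import packaging.version

    ORT_VERSION = packaging.version.Version(onnxruntime.__version__)


    class FusionTest(unittest.TestCase):
        @unittest.skipIf(
            ORT_VERSION < packaging.version.Version("1.17"),  # hypothetical minimum version
            "fused op requires a newer onnxruntime",
        )
        def test_fused_model_runs(self):
            ...  # would run the fused model via ort_run(...) here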

onnxscript/rewriter/ort_fusions/mha.py

Lines changed: 145 additions & 103 deletions
@@ -2,37 +2,36 @@
 # Licensed under the MIT License.
 from __future__ import annotations
 
-from typing import Sequence
+from typing import Sequence, Union
 
 import onnxscript.ir as ir
-from onnxscript.rewriter import pattern
+from onnxscript.rewriter import _ir_utils, pattern
 
 """
-The MultiHeadAttention pattern:
+The MultiHeadAttention pattern: generate an instance
+   MHA (query, key, value, None, None, mask, past_key, past_value)
+where query has shape (B, S, D), key has shape (B, Skv, D), and value has shape (B, Skv, Dv).
+The next two inputs bias and key_padding_mask are None in this pattern. The mask (attention_bias)
+must be of shape (1 or B, 1 or H, S, St). past_key and past_value are of shape (B, H, Spast, Dh).
 
+We use the following abbreviations for the dimensions:
 B: Batch size
 S: Sequence length
 D: input embedding dimension
+Dv: value hidden size (usually, Dv = D)
 H: number of heads
-d_h: head size (usually, D = H * d_h)
+Dh: head size or embedding dimension per head (usually, D = H * Dh)
+Skv: key/value sequence length
+St: total sequence length
 
-thus, weights are usually of shape (D, D) and (D, D) and (D, D)
-
-for each of Q, K, and V, we have the following pattern:
-   MatMul (Input, W), producing output of shape (B, S, D)
-   Reshape to produce a matrix of shape (B, S, H, d_h)
-   Transpose middle two axes to produce a matrix of shape (B, H, S, d_h)
-
-This is followed by a RotaryEmbedding pattern for Q and K
-
-The last two axes of the key-embedding are then swapped (using a Reshape/Transpose/Reshape sequence)
-
-The dot-product attention is then computed using SDPA.
-Finally, the output is transposed and reshaped back to (B, S, D) shape
+In the sequel, the suffix "_BHSDh" indicates that the tensor has the shape (B, H, S, Dh).
+The suffix "BH_Skv_Dh" indicates that the tensor has the shape (B*H, Skv, Dh).
 """
 
+Dim = Union[int, ir.SymbolicDim]
 
-def _check_shape(bindings: dict[str, int], val: ir.Value, shape: Sequence[str]) -> bool:
+
+def _check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]) -> bool:
     if val.shape is None:
         return False
     if val.shape.rank() != len(shape):
@@ -46,131 +45,170 @@ def _check_shape(bindings: dict[str, int], val: ir.Value, shape: Sequence[str])
 
 
 class MultiHeadAttention(pattern.RewriteRuleClassBase):
-    def __init__(self, name: str, *, use_2d_matmul: bool):
-        super().__init__(name)
-        self._use_2d_matmul = use_2d_matmul
-
-    def _compute_QKV(self, op, input, weight, reshape_var: str):
-        """Applied to generate each of Q, K, and V from input."""
-        if self._use_2d_matmul:
-            # Convert batched input of shape (B, S, D) to 2D input (B*S, D)
-            input = op.Reshape(input, _allow_other_inputs=True)
-        projected = op.MatMul(input, weight)
-        if self._use_2d_matmul:
-            # Convert 2D output back to batched output of shape (B, S, D)
-            projected = op.Reshape(projected, _allow_other_inputs=True)
-        # Reshape from (B, S, D) to (B, S, H, D/H)
-        reshaped = op.Reshape(
-            projected,
-            _allow_other_inputs=True,
-            _allow_other_attributes=True,
-            _outputs=[reshape_var],
-        )
-        # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
-        transposed = op.Transpose(reshaped, perm=[0, 2, 1, 3])
-        return transposed
+    def __init__(self):
+        super().__init__("MHA")
 
     def pattern(
         self,
         op,
-        input,
-        query_weight,
-        key_weight,
-        value_weight,
-        qkv_weight,
+        query_BSD,
+        key_BSD,
+        value_BSD,
        mask,
-        cos,
-        sin,
        past_key,
        past_value,
        position_ids,
+        cos,
+        sin,
    ):
-        query = self._compute_QKV(op, input, query_weight, "query_mm_reshaped")
-        key = self._compute_QKV(op, input, key_weight, "key_mm_reshaped")
-        value = self._compute_QKV(op, input, value_weight, "value_mm_reshaped")
+        # First, query, key, and value are reshaped+transposed from (B, S, D) to (B, H, S, D/H)
+
+        # Reshape from (B, S, D) to (B, S, H, D/H)
+        query_BSHDh = op.Reshape(
+            query_BSD,
+            _allow_other_inputs=True,
+            _allow_other_attributes=True,
+            _outputs=["query_BSHDh"],
+        )
+        # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
+        query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3])
+
+        # Reshape from (B, S, D) to (B, S, H, D/H)
+        key_BSHDh = op.Reshape(
+            key_BSD,
+            _allow_other_inputs=True,
+            _allow_other_attributes=True,
+            _outputs=["key_BSHDh"],
+        )
+        # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
+        key_BHSDh = op.Transpose(key_BSHDh, perm=[0, 2, 1, 3])
+
+        # Reshape from (B, S, D) to (B, S, H, D/H)
+        value_BSHDh = op.Reshape(
+            value_BSD,
+            _allow_other_inputs=True,
+            _allow_other_attributes=True,
+            _outputs=["value_BSHDh"],
+        )
+        # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
+        value_BHSDh = op.Transpose(value_BSHDh, perm=[0, 2, 1, 3])
+
+        query_BHSDh_rope = op.RotaryEmbedding(
+            query_BHSDh, position_ids, cos, sin, _domain="com.microsoft"
+        )
+        key_BHSDh_rope = op.RotaryEmbedding(
+            key_BHSDh, position_ids, cos, sin, _domain="com.microsoft"
+        )
 
-        query_rope = op.RotaryEmbedding(query, position_ids, cos, sin, _domain="com.microsoft")
+        # Concatenate past_key cache and current key, and transpose to enable
+        # dot-product attention computation.
 
-        key_rope = op.RotaryEmbedding(key, position_ids, cos, sin, _domain="com.microsoft")
-        key_rope = op.Concat(past_key, key_rope, axis=-2)
-        # Transpose last two axes of key_rope to compute dot-product via matmul.
-        key_reshaped = op.Reshape(
-            key_rope, _allow_other_inputs=True, _outputs=["key_reshaped"]
+        key_seq = op.Concat(past_key, key_BHSDh_rope, axis=-2)
+        # Transpose last two axes of key_seq to compute dot-product via matmul.
+        key_seq_BH_Skv_Dh = op.Reshape(
+            key_seq, _allow_other_inputs=True, _outputs=["key_seq_BH_Skv_Dh"]
        )
-        key_reshaped_transposed = op.Transpose(key_reshaped, perm=[0, 2, 1])
-        key_transposed = op.Reshape(
-            key_reshaped_transposed, _allow_other_inputs=True, _outputs=["key_transposed"]
+        key_seq_BH_Dh_Skv = op.Transpose(key_seq_BH_Skv_Dh, perm=[0, 2, 1])
+        key_seq_B_H_Dh_Skv = op.Reshape(
+            key_seq_BH_Dh_Skv, _allow_other_inputs=True, _outputs=["key_seq_B_H_Dh_Skv"]
        )
 
-        value = op.Concat(past_value, value, axis=-2)
+        # Concatenate past_value cache and current value
+        value_seq = op.Concat(past_value, value_BHSDh, axis=-2)
 
        attention = op.SDPA(
-            query_rope, key_transposed, value, mask, _domain="ai.onnxruntime.fusion"
+            query_BHSDh_rope,
+            key_seq_B_H_Dh_Skv,
+            value_seq,
+            mask,
+            _domain="ai.onnxruntime.fusion",
        )
-        # Transpose back to (B, S, H, D/H)
+
+        # Transpose attention back to (B, S, H, D/H)
        attention_transposed = op.Transpose(attention, perm=[0, 2, 1, 3])
        # Reshape back to (B, S, D)
        attention_reshaped = op.Reshape(
            attention_transposed, _allow_other_inputs=True, _outputs=["attention_reshaped"]
        )
-        return attention_reshaped, key_rope, value
+        return attention_reshaped, key_seq, value_seq
 
    def check(
        self,
        op,
-        query_mm_reshaped,
-        key_mm_reshaped,
-        value_mm_reshaped,
-        key_reshaped,
-        key_transposed,
-        attention_reshaped,
+        query_BSD,
+        key_BSD,
+        value_BSD,
+        mask,
+        past_key,
+        past_value,
+        query_BSHDh,
+        key_BSHDh,
+        value_BSHDh,
        **_,
    ):
-        bindings: dict[str, int] = {}
-        status = (
-            _check_shape(bindings, query_mm_reshaped, ["B", "S", "H", "d_h"])
-            and _check_shape(bindings, key_mm_reshaped, ["B", "S", "H", "d_h"])
-            and _check_shape(bindings, value_mm_reshaped, ["B", "S", "H", "d_h"])
-            and _check_shape(bindings, key_reshaped, ["B*H", "KVS", "d_h"])
-            and _check_shape(bindings, key_transposed, ["B", "H", "d_h", "KVS"])
-            and _check_shape(bindings, attention_reshaped, ["B", "S", "H*d_h"])
-        )
-        if not status:
+        bindings: dict[str, Dim] = {}
+
+        def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
+            return not _check_shape(bindings, val, dims)
+
+        if no_match(query_BSD, ["B", "S", "D"]):
+            return False
+        if no_match(key_BSD, ["B", "Skv", "D"]):
+            return False
+        if no_match(value_BSD, ["B", "Skv", "D"]):
            return False
-        # if bindings["B"] * bindings["H"] != bindings["B*H"]:
-        #     return False
-        # if bindings["H"] * bindings["d_h"] != bindings["H*d_h"]:
-        #     return False
+
+        if no_match(past_key, ["B", "H", "Spast", "Dh"]):
+            return False
+        if no_match(past_value, ["B", "H", "Spast", "Dv"]):
+            return False
+        if no_match(query_BSHDh, ["B", "S", "H", "Dh"]):
+            return False
+        if no_match(key_BSHDh, ["B", "S", "H", "Dh"]):
+            return False
+        if no_match(value_BSHDh, ["B", "S", "H", "Dh"]):
+            return False
+        # TODO: mask shape check: ideally, it should be (1 or B, 1 or H, S, St)
+        # But this also, unforunately, depends on ORT version.
+
+        # TODO: verify Reshapes:
+        # eg.: verify bindings["B"] * bindings["H"] == bindings["B*H"]:
+        # and bindings["H"] * bindings["Dh"] == bindings["H*Dh"]:
+        # or check Reshape's shape-input value
        return True
 
    def rewrite(
        self,
        op,
-        input,
-        query_weight,
-        key_weight,
-        value_weight,
+        query_BSD,
+        key_BSD,
+        value_BSD,
        mask,
-        cos,
-        sin,
        past_key,
        past_value,
+        key_BSHDh,
        position_ids,
-        query_mm_reshaped,
+        cos,
+        sin,
        **_,
    ):
-        num_heads = query_mm_reshaped.shape[2]
-        query = op.MatMul(input, query_weight)
-        key = op.MatMul(input, key_weight)
-        value = op.MatMul(input, value_weight)
-
-        query_rope = op.RotaryEmbedding(query, position_ids, cos, sin, _domain="com.microsoft")
-        key_rope = op.RotaryEmbedding(key, position_ids, cos, sin, _domain="com.microsoft")
+        num_heads = _ir_utils.get_dim(key_BSHDh, 2)
+        if not isinstance(num_heads, int):
+            return None
+
+        # Switch to 3D RotaryEmbedding
+        # TODO: forward other attributes
+        query_BSD_rope = op.RotaryEmbedding(
+            query_BSD, position_ids, cos, sin, _domain="com.microsoft"
+        )
+        key_BSD_rope = op.RotaryEmbedding(
+            key_BSD, position_ids, cos, sin, _domain="com.microsoft"
+        )
 
        return op.MultiHeadAttention(
-            query_rope,
-            key_rope,
-            value,
+            query_BSD_rope,
+            key_BSD_rope,
+            value_BSD,
            None,  # bias
            None,  # key padding mask
            mask,  # attention mask/bias
@@ -182,11 +220,15 @@ def rewrite(
        )
 
 
-_rule1 = MultiHeadAttention.rule("MHA_2dmm", use_2d_matmul=False)
+_rule1 = MultiHeadAttention.rule()
 
 mha_rules = pattern.RewriteRuleSet([_rule1])
 
 
-def fuse_mha(model: ir.Model) -> int:
+def fuse_mha(model: ir.Model, *, debug: bool = False) -> int:
    count = mha_rules.apply_to_model(model)
+    if debug and count == 0:
+        tracer = pattern.MatchingTracer()
+        mha_rules.apply_to_model(model, tracer=tracer)
+        tracer.report()
    return count
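
A usage sketch for the new debug flag: when no fusion is applied, fuse_mha re-runs the rules under a MatchingTracer and prints a report of why candidate subgraphs failed to match. The model path is a placeholder, and loading via ir.serde.deserialize_model is an assumption about the onnxscript.ir API; adapt it to however you already obtain an ir.Model.

    import onnx

    import onnxscript.ir as ir
    from onnxscript.rewriter.ort_fusions import mha

    # "phi_model.onnx" is a placeholder path.
    model_proto = onnx.load("phi_model.onnx")
    model = ir.serde.deserialize_model(model_proto)  # convert the proto to the rewriter's IR

    count = mha.fuse_mha(model, debug=True)  # prints a match-failure report if nothing fused
    print(f"Fused {count} MultiHeadAttention subgraph(s)")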
