Commit 1d3bc5b

fuse mha bias modifs

1 parent: 1a0329a
onnxscript/rewriter/ort_fusions/fuse_mha_bias.py

Lines changed: 28 additions & 33 deletions
@@ -9,26 +9,7 @@
 import onnxscript.ir as ir
 from onnxscript.rewriter import _fusion_utils, pattern
 
-"""
-The MultiHeadAttention pattern: generate an instance
-MHA (query, key, value, None, None, mask, past_key, past_value)
-where query has shape (B, S, D), key has shape (B, Skv, D), and value has shape (B, Skv, Dv).
-The next two inputs bias and key_padding_mask are None in this pattern. The mask (attention_bias)
-must be of shape (1 or B, 1 or H, S, St). past_key and past_value are of shape (B, H, Spast, Dh).
-
-We use the following abbreviations for the dimensions:
-B: Batch size
-S: Sequence length
-D: input embedding dimension
-Dv: value hidden size (usually, Dv = D)
-H: number of heads
-Dh: head size or embedding dimension per head (usually, D = H * Dh)
-Skv: key/value sequence length
-St: total sequence length
-
-In the sequel, the suffix "_BHSDh" indicates that the tensor has the shape (B, H, S, Dh).
-The suffix "BH_Skv_Dh" indicates that the tensor has the shape (B*H, Skv, Dh).
-"""
+valid_float_types = [ir.DataType.FLOAT, ir.DataType.FLOAT16]
 
 Dim = Union[int, ir.SymbolicDim]
 
@@ -102,6 +83,13 @@ def check(
         def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
             return not _fusion_utils._check_shape(self.bindings, val, dims)
 
+        if query_matmul.dtype not in valid_float_types:
+            return check_result.fail("Query is not a float or float16 type.", query_matmul)
+        if key_matmul.dtype not in valid_float_types:
+            return check_result.fail("Key is not a float or float16 type.", key_matmul)
+        if value_matmul.dtype not in valid_float_types:
+            return check_result.fail("Value is not a float or float16 type.", value_matmul)
+
         if no_match(query_matmul, ["B", "S", "D"]):
             return check_result.fail(
                 f"Shape mismatch: {query_matmul} does not match expected dimensions ['B', 'S', 'D']",
@@ -148,19 +136,26 @@ def rewrite(
         num_heads,
         **_,
     ):
-        if self._q_no_bias:
-            q_bias = op.Constant(
-                value=ir.tensor(numpy.zeros((self.Dh_q,), dtype=numpy.float32))
-            )
-        if self._k_no_bias:
-            k_bias = op.Constant(
-                value=ir.tensor(numpy.zeros((self.Dh_k,), dtype=numpy.float32))
-            )
-        if self._v_no_bias:
-            v_bias = op.Constant(
-                value=ir.tensor(numpy.zeros((self.Dh_v,), dtype=numpy.float32))
-            )
-        bias = op.Concat(q_bias, k_bias, v_bias, axis=0)
+        if self._q_no_bias and self._k_no_bias and self._v_no_bias:
+            bias = None
+        else:
+            if self._q_no_bias:
+                q_bias = op.Constant(
+                    value=ir.tensor(
+                        numpy.zeros((self.Dh_q,), dtype=query_matmul.dtype.numpy())
+                    )
+                )
+            if self._k_no_bias:
+                k_bias = op.Constant(
+                    value=ir.tensor(numpy.zeros((self.Dh_k,), dtype=key_matmul.dtype.numpy()))
+                )
+            if self._v_no_bias:
+                v_bias = op.Constant(
+                    value=ir.tensor(
+                        numpy.zeros((self.Dh_v,), dtype=value_matmul.dtype.numpy())
+                    )
+                )
+            bias = op.Concat(q_bias, k_bias, v_bias, axis=0)
         return op.MultiHeadAttention(
             query_matmul,
             key_matmul,
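
Note: the rewrite now materialises any missing bias in the dtype of the corresponding MatMul input rather than hard-coded float32, and skips the Concat entirely when all three biases are absent. A rough illustration of the DataType.numpy() conversion this relies on, with a hypothetical head dimension of 64 standing in for self.Dh_q:

    import numpy
    import onnxscript.ir as ir

    head_dim = 64  # hypothetical stand-in for self.Dh_q; not a value from the commit
    # DataType.FLOAT16.numpy() yields the matching numpy dtype (numpy.float16),
    # so the zero bias keeps the same precision as the query/key/value MatMul input.
    zero_bias = ir.tensor(numpy.zeros((head_dim,), dtype=ir.DataType.FLOAT16.numpy()))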
