import onnxscript.ir as ir
from onnxscript.rewriter import _fusion_utils, pattern

-"""
-The MultiHeadAttention pattern: generate an instance
-   MHA (query, key, value, None, None, mask, past_key, past_value)
-where query has shape (B, S, D), key has shape (B, Skv, D), and value has shape (B, Skv, Dv).
-The next two inputs bias and key_padding_mask are None in this pattern. The mask (attention_bias)
-must be of shape (1 or B, 1 or H, S, St). past_key and past_value are of shape (B, H, Spast, Dh).
-
-We use the following abbreviations for the dimensions:
-B: Batch size
-S: Sequence length
-D: input embedding dimension
-Dv: value hidden size (usually, Dv = D)
-H: number of heads
-Dh: head size or embedding dimension per head (usually, D = H * Dh)
-Skv: key/value sequence length
-St: total sequence length
-
-In the sequel, the suffix "_BHSDh" indicates that the tensor has the shape (B, H, S, Dh).
-The suffix "BH_Skv_Dh" indicates that the tensor has the shape (B*H, Skv, Dh).
-"""
+valid_float_types = [ir.DataType.FLOAT, ir.DataType.FLOAT16]

Dim = Union[int, ir.SymbolicDim]

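Note on this hunk: the module docstring describing the MHA pattern and its dimension abbreviations is removed here, and a module-level `valid_float_types` list is introduced; the `check` method in the next hunk uses it to gate the fusion on the input dtype. A minimal sketch of that gating, assuming the helper name `is_supported_dtype` (hypothetical, not part of this commit):

    import onnxscript.ir as ir

    valid_float_types = [ir.DataType.FLOAT, ir.DataType.FLOAT16]

    def is_supported_dtype(value: ir.Value) -> bool:
        # value.dtype may be None when type information is absent; treat that
        # as unsupported so the fusion is skipped rather than applied blindly.
        return value.dtype is not None and value.dtype in valid_float_types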
|
@@ -102,6 +83,13 @@ def check(
|
        def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
            return not _fusion_utils._check_shape(self.bindings, val, dims)

+        if query_matmul.dtype not in valid_float_types:
+            return check_result.fail("Query is not a float or float16 type.", query_matmul)
+        if key_matmul.dtype not in valid_float_types:
+            return check_result.fail("Key is not a float or float16 type.", key_matmul)
+        if value_matmul.dtype not in valid_float_types:
+            return check_result.fail("Value is not a float or float16 type.", value_matmul)
+
        if no_match(query_matmul, ["B", "S", "D"]):
            return check_result.fail(
                f"Shape mismatch: {query_matmul} does not match expected dimensions ['B', 'S', 'D']",
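Note on this hunk: the new dtype guards run before the shape checks, so the rule bails out early for anything other than float32/float16 inputs, presumably because the fused path (and the dtype-matched zero biases built in the next hunk) only targets those two types. A compressed sketch of the same early-exit pattern, using a hypothetical helper name not present in this commit:

    # Hypothetical helper illustrating the early-exit style of the added checks:
    # validate dtypes first and return a failed match as soon as one is wrong.
    def _check_qkv_dtypes(check_result, query_matmul, key_matmul, value_matmul):
        for name, val in (
            ("Query", query_matmul),
            ("Key", key_matmul),
            ("Value", value_matmul),
        ):
            if val.dtype not in valid_float_types:
                return check_result.fail(f"{name} is not a float or float16 type.", val)
        return check_result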
|
@@ -148,19 +136,26 @@ def rewrite(
|
        num_heads,
        **_,
    ):
-        if self._q_no_bias:
-            q_bias = op.Constant(
-                value=ir.tensor(numpy.zeros((self.Dh_q,), dtype=numpy.float32))
-            )
-        if self._k_no_bias:
-            k_bias = op.Constant(
-                value=ir.tensor(numpy.zeros((self.Dh_k,), dtype=numpy.float32))
-            )
-        if self._v_no_bias:
-            v_bias = op.Constant(
-                value=ir.tensor(numpy.zeros((self.Dh_v,), dtype=numpy.float32))
-            )
-        bias = op.Concat(q_bias, k_bias, v_bias, axis=0)
+        if self._q_no_bias and self._k_no_bias and self._v_no_bias:
+            bias = None
+        else:
+            if self._q_no_bias:
+                q_bias = op.Constant(
+                    value=ir.tensor(
+                        numpy.zeros((self.Dh_q,), dtype=query_matmul.dtype.numpy())
+                    )
+                )
+            if self._k_no_bias:
+                k_bias = op.Constant(
+                    value=ir.tensor(numpy.zeros((self.Dh_k,), dtype=key_matmul.dtype.numpy()))
+                )
+            if self._v_no_bias:
+                v_bias = op.Constant(
+                    value=ir.tensor(
+                        numpy.zeros((self.Dh_v,), dtype=value_matmul.dtype.numpy())
+                    )
+                )
+            bias = op.Concat(q_bias, k_bias, v_bias, axis=0)
        return op.MultiHeadAttention(
            query_matmul,
            key_matmul,
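Note on this hunk: synthesized zero biases now use the dtype of the corresponding MatMul input instead of hard-coded float32, and when none of Q/K/V carries a bias the rewrite passes `bias = None` rather than concatenating an all-zeros vector. A small numpy sketch of what the Concat path produces (illustrative sizes, float16 chosen as an example; only the value projection has a real bias):

    import numpy as np

    # Assume head-projection widths Dh_q = Dh_k = Dh_v = 64 and float16 inputs.
    Dh_q = Dh_k = Dh_v = 64
    q_bias = np.zeros((Dh_q,), dtype=np.float16)        # synthesized, matches input dtype
    k_bias = np.zeros((Dh_k,), dtype=np.float16)        # synthesized, matches input dtype
    v_bias = np.random.rand(Dh_v).astype(np.float16)    # the one real bias in this example
    bias = np.concatenate([q_bias, k_bias, v_bias], axis=0)
    assert bias.shape == (Dh_q + Dh_k + Dh_v,)
    assert bias.dtype == np.float16  # no float32/float16 mismatch at the Concat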
|
|