|
9 | 9 | import unittest
|
10 | 10 |
|
11 | 11 | import numpy
|
12 |
| -from parameterized import parameterized |
| 12 | +import parameterized |
13 | 13 |
|
14 | 14 | import onnxscript.ir as ir
|
15 | 15 | import onnxscript.optimizer
|
|
28 | 28 | SQRT_MUL_SCALE_FACTOR = math.sqrt(MUL_SCALE_FACTOR)
|
29 | 29 |
|
30 | 30 |
|
| 31 | +@script() |
| 32 | +def _unmasked_pre_div_sdpa_script(query, key, value, mask): |
| 33 | + key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) |
| 34 | + divisor = op.Constant(value_float=SQRT_SCALE_FACTOR) |
| 35 | + scaled_query = op.Div(query, divisor) |
| 36 | + scaled_key = op.Div(key_transposed, divisor) |
| 37 | + attn_score = op.MatMul(scaled_query, scaled_key) |
| 38 | + attn_weight = op.Softmax(attn_score, axis=-1) |
| 39 | + attn_output = op.MatMul(attn_weight, value) |
| 40 | + return attn_output |
| 41 | + |
| 42 | + |
| 43 | +@script() |
| 44 | +def _unmasked_pre_mul_sdpa_script(query, key, value, mask): |
| 45 | + key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) |
| 46 | + multiplier = op.Constant(value_float=SQRT_MUL_SCALE_FACTOR) |
| 47 | + scaled_query = op.Mul(query, multiplier) |
| 48 | + scaled_key = op.Mul(key_transposed, multiplier) |
| 49 | + attn_score = op.MatMul(scaled_query, scaled_key) |
| 50 | + attn_weight = op.Softmax(attn_score, axis=-1) |
| 51 | + attn_output = op.MatMul(attn_weight, value) |
| 52 | + return attn_output |
| 53 | + |
| 54 | + |
| 55 | +@script() |
| 56 | +def _unmasked_post_div_sdpa_script(query, key, value, mask): |
| 57 | + key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) |
| 58 | + divisor = op.Constant(value_float=SCALE_FACTOR) |
| 59 | + attn_score = op.MatMul(query, key_transposed) |
| 60 | + scaled_attn_score = op.Div(attn_score, divisor) |
| 61 | + attn_weight = op.Softmax(scaled_attn_score, axis=-1) |
| 62 | + attn_output = op.MatMul(attn_weight, value) |
| 63 | + return attn_output |
| 64 | + |
| 65 | + |
| 66 | +@script() |
| 67 | +def _unmasked_post_mul_sdpa_script(query, key, value, mask): |
| 68 | + key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) |
| 69 | + multiplier = op.Constant(value_float=MUL_SCALE_FACTOR) |
| 70 | + attn_score = op.MatMul(query, key_transposed) |
| 71 | + scaled_attn_score = op.Mul(attn_score, multiplier) |
| 72 | + attn_weight = op.Softmax(scaled_attn_score, axis=-1) |
| 73 | + attn_output = op.MatMul(attn_weight, value) |
| 74 | + return attn_output |
| 75 | + |
| 76 | + |
31 | 77 | @script()
|
32 | 78 | def _masked_pre_div_sdpa_script(query, key, value, mask):
|
33 | 79 | key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
|
@@ -105,8 +151,12 @@ def get_ort_inputs(self):
|
105 | 151 |
|
106 | 152 |
|
107 | 153 | class TestSDPAFusion(unittest.TestCase):
|
108 |
| - @parameterized.expand( |
| 154 | + @parameterized.parameterized.expand( |
109 | 155 | [
|
| 156 | + ("unmasked_pre_div", _unmasked_pre_div_sdpa_script), |
| 157 | + ("unmasked_pre_mul", _unmasked_pre_mul_sdpa_script), |
| 158 | + ("unmasked_post_div", _unmasked_post_div_sdpa_script), |
| 159 | + ("unmasked_post_mul", _unmasked_post_mul_sdpa_script), |
110 | 160 | ("pre_div", _masked_pre_div_sdpa_script),
|
111 | 161 | ("pre_mul", _masked_pre_mul_sdpa_script),
|
112 | 162 | ("post_div", _masked_post_div_sdpa_script),
|
|
0 commit comments