|
10 | 10 |
|
11 | 11 | import onnxscript
|
12 | 12 | import onnxscript.ir as ir
|
| 13 | +import onnxscript.optimizer |
13 | 14 | import onnxscript.rewriter.ort_fusions._core as xformers
|
14 | 15 | from onnxscript import FLOAT, script
|
15 | 16 | from onnxscript import opset18 as op
|
16 | 17 | from onnxscript.ir.passes.common import shape_inference
|
17 | 18 | from onnxscript.rewriter.ort_fusions._test_utils import ORT_VERSION, assert_allclose, ort_run
|
| 19 | +from onnxscript.rewriter.ort_fusions.models._whisper_encoder import whisper_encoder_test |
18 | 20 |
|
19 | 21 | msft_op = onnxscript.values.Opset("com.microsoft", 1)
|
20 | 22 |
|
@@ -155,6 +157,35 @@ def test_model_with_mha(self, name, with_past):
|
155 | 157 | new_outputs = ort_run("optimized", model, inputs)
|
156 | 158 | assert_allclose(new_outputs, original_outputs)
|
157 | 159 |
|
| 160 | + def test_whisper_encoder(self): |
| 161 | + # Generate model |
| 162 | + whisper_encoder = whisper_encoder_test() |
| 163 | + model = whisper_encoder.get_onnx_model() |
| 164 | + onnxscript.optimizer.optimize(model) |
| 165 | + |
| 166 | + test_with_ort = packaging.version.Version("1.20") <= ORT_VERSION |
| 167 | + if test_with_ort: |
| 168 | + # Run model |
| 169 | + inputs = whisper_encoder.get_ort_inputs() |
| 170 | + original_outputs = ort_run("original", model, inputs) |
| 171 | + |
| 172 | + # Fuse SDPA and MHA |
| 173 | + sdpa_count = xformers.fuse_sdpa(model) |
| 174 | + self.assertGreater(sdpa_count, 0) |
| 175 | + model = shape_inference.infer_shapes(model) |
| 176 | + mha_count = xformers.fuse_mha(model) |
| 177 | + self.assertGreater(mha_count, 0) |
| 178 | + fused_mha_bias_count = xformers.fuse_mha_bias(model) |
| 179 | + self.assertGreater(fused_mha_bias_count, 0) |
| 180 | + attention_count = xformers.fuse_attention(model) |
| 181 | + self.assertGreater(attention_count, 0) |
| 182 | + onnxscript.optimizer.optimize(model) |
| 183 | + |
| 184 | + if test_with_ort: |
| 185 | + # Run model again |
| 186 | + new_outputs = ort_run("optimized", model, inputs) |
| 187 | + assert_allclose(new_outputs, original_outputs) |
| 188 | + |
158 | 189 |
|
159 | 190 | if __name__ == "__main__":
|
160 | 191 | unittest.main()
|
0 commit comments