Skip to content

Commit 1a0329a

Browse files
add attention tests
1 parent 7c45032 commit 1a0329a

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

onnxscript/rewriter/ort_fusions/attention_test.py

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,11 +10,13 @@
1010

1111
import onnxscript
1212
import onnxscript.ir as ir
13+
import onnxscript.optimizer
1314
import onnxscript.rewriter.ort_fusions._core as xformers
1415
from onnxscript import FLOAT, script
1516
from onnxscript import opset18 as op
1617
from onnxscript.ir.passes.common import shape_inference
1718
from onnxscript.rewriter.ort_fusions._test_utils import ORT_VERSION, assert_allclose, ort_run
19+
from onnxscript.rewriter.ort_fusions.models._whisper_encoder import whisper_encoder_test
1820

1921
msft_op = onnxscript.values.Opset("com.microsoft", 1)
2022

@@ -155,6 +157,35 @@ def test_model_with_mha(self, name, with_past):
155157
new_outputs = ort_run("optimized", model, inputs)
156158
assert_allclose(new_outputs, original_outputs)
157159

160+
def test_whisper_encoder(self):
161+
# Generate model
162+
whisper_encoder = whisper_encoder_test()
163+
model = whisper_encoder.get_onnx_model()
164+
onnxscript.optimizer.optimize(model)
165+
166+
test_with_ort = packaging.version.Version("1.20") <= ORT_VERSION
167+
if test_with_ort:
168+
# Run model
169+
inputs = whisper_encoder.get_ort_inputs()
170+
original_outputs = ort_run("original", model, inputs)
171+
172+
# Fuse SDPA and MHA
173+
sdpa_count = xformers.fuse_sdpa(model)
174+
self.assertGreater(sdpa_count, 0)
175+
model = shape_inference.infer_shapes(model)
176+
mha_count = xformers.fuse_mha(model)
177+
self.assertGreater(mha_count, 0)
178+
fused_mha_bias_count = xformers.fuse_mha_bias(model)
179+
self.assertGreater(fused_mha_bias_count, 0)
180+
attention_count = xformers.fuse_attention(model)
181+
self.assertGreater(attention_count, 0)
182+
onnxscript.optimizer.optimize(model)
183+
184+
if test_with_ort:
185+
# Run model again
186+
new_outputs = ort_run("optimized", model, inputs)
187+
assert_allclose(new_outputs, original_outputs)
188+
158189

159190
if __name__ == "__main__":
160191
unittest.main()

onnxscript/rewriter/ort_fusions/models/_whisper_encoder.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -147,7 +147,7 @@ def main_graph(
147147
epsilon=9.999999747378752e-06,
148148
axis=-1,
149149
)
150-
return layer_norm_2
150+
return add_170
151151

152152
model = main_graph.to_model_proto()
153153
return model

0 commit comments

Comments
 (0)