4
4
import unittest
5
5
6
6
import numpy as np
7
+ import parameterized
7
8
8
9
import onnxscript
9
10
import onnxscript .ir as ir
10
11
import onnxscript .rewriter .ort_fusions ._test_utils as test_utils
11
- from onnxscript import FLOAT , script
12
- from onnxscript import opset18 as op
12
+ from onnxscript import FLOAT , OnnxFunction , script
13
+ from onnxscript import opset20 as op
13
14
from onnxscript .optimizer import optimize , remove_unused_nodes
14
15
from onnxscript .rewriter .ort_fusions .bias_gelu import fuse_bias_gelu
15
16
16
17
# Opset handle for ONNX Runtime's "com.microsoft" contrib-op domain
# (used below to emit the contrib Gelu that the fusion also recognizes).
msft_op = onnxscript .values .Opset ("com.microsoft" , 1 )
17
18
18
19
20
@script()
def _test_script_onnx_default(x: FLOAT[10], y: FLOAT[10]) -> FLOAT[10]:
    # Bias-add followed by standard ONNX Gelu (default approximation).
    return op.Gelu(op.Add(x, y))
24
+
25
+
26
@script()
def _test_script_onnx_none(x: FLOAT[10], y: FLOAT[10]) -> FLOAT[10]:
    # Same pattern as the default case, but with the approximation
    # attribute spelled out explicitly.
    biased = op.Add(x, y)
    return op.Gelu(biased, approximate="none")
30
+
31
+
32
@script()
def _test_script_onnx_unsupported(x: FLOAT[10], y: FLOAT[10]) -> FLOAT[10]:
    # Tanh-approximated Gelu; the test below expects this variant
    # NOT to be fused into BiasGelu.
    return op.Gelu(op.Add(x, y), approximate="tanh")
36
+
37
+
38
@script()
def _test_script_msft_op(x: FLOAT[10], y: FLOAT[10]) -> FLOAT[10]:
    # Bias-add followed by the com.microsoft contrib Gelu.
    return msft_op.Gelu(op.Add(x, y))
42
+
43
+
19
44
class BiasGeluFusionTest (unittest .TestCase ):
20
- def test_bias_gelu_fusion (self ):
21
- @script ()
22
- def bias_gelu_model (x , y ):
23
- gelu_add = op .Add (x , y )
24
- gelu = msft_op .Gelu (gelu_add )
25
- return gelu
26
-
27
- model_proto = bias_gelu_model .to_model_proto (
28
- input_types = [FLOAT [10 ], FLOAT [10 ]],
29
- output_types = [FLOAT [10 ]],
30
- ir_version = 10 ,
31
- )
45
+ def _check (
46
+ self ,
47
+ test_data_constructor : OnnxFunction ,
48
+ expected_graph_len : int ,
49
+ expected_op_type : str ,
50
+ ):
51
+ """Helper method to run a fusion test scenario."""
52
+ model_proto = test_data_constructor .to_model_proto ()
32
53
model = ir .serde .deserialize_model (model_proto )
33
54
optimize (model )
34
55
@@ -41,12 +62,42 @@ def bias_gelu_model(x, y):
41
62
fuse_bias_gelu (model )
42
63
remove_unused_nodes (model )
43
64
44
- self .assertEqual (len (model .graph ), 1 )
45
- self .assertEqual (model .graph .node (0 ).op_type , "BiasGelu" )
65
+ self .assertEqual (len (model .graph ), expected_graph_len )
66
+ self .assertEqual (model .graph .node (0 ).op_type , expected_op_type )
46
67
47
68
optimized_output = test_utils .ort_run ("Optimized" , model , input )
48
69
test_utils .assert_allclose (original_output , optimized_output )
49
70
71
+ @parameterized .parameterized .expand (
72
+ [
73
+ ("with_onnx_op_default" , _test_script_onnx_default , 1 , "BiasGelu" ),
74
+ ("with_onnx_op_none" , _test_script_onnx_none , 1 , "BiasGelu" ),
75
+ ("with_contrib_op" , _test_script_msft_op , 1 , "BiasGelu" ),
76
+ ]
77
+ )
78
+ def test_bias_gelu_fusion (
79
+ self ,
80
+ _ ,
81
+ test_data_constructor : OnnxFunction ,
82
+ expected_graph_len : int ,
83
+ expected_op_type : str ,
84
+ ):
85
+ self ._check (test_data_constructor , expected_graph_len , expected_op_type )
86
+
87
+ @parameterized .parameterized .expand (
88
+ [
89
+ ("approximate_tanh" , _test_script_onnx_unsupported , 2 , "Add" ),
90
+ ]
91
+ )
92
+ def test_bias_gelu_fusion_unsupported_attr (
93
+ self ,
94
+ _ ,
95
+ test_data_constructor : OnnxFunction ,
96
+ expected_graph_len : int ,
97
+ expected_op_type : str ,
98
+ ):
99
+ self ._check (test_data_constructor , expected_graph_len , expected_op_type )
100
+
50
101
51
102
# Allow running this test module directly (python <file>) in addition to
# test-runner discovery.
if __name__ == "__main__" :
    unittest .main ()
0 commit comments