18
18
19
19
20
20
@script ()
21
- def _test_script_onnx_default (x : FLOAT [10 ], y : FLOAT [10 ]) -> FLOAT [10 ]:
21
+ def _test_script_onnx_default (x : FLOAT [10 , 10 ], y : FLOAT [10 ]) -> FLOAT [10 ]:
22
22
gelu_add = op .Add (x , y )
23
23
return op .Gelu (gelu_add )
24
24
25
25
26
26
@script ()
27
- def _test_script_onnx_none (x : FLOAT [10 ], y : FLOAT [10 ]) -> FLOAT [10 ]:
27
+ def _test_script_onnx_none (x : FLOAT [10 , 10 ], y : FLOAT [10 ]) -> FLOAT [10 ]:
28
28
gelu_add = op .Add (x , y )
29
29
return op .Gelu (gelu_add , approximate = "none" )
30
30
31
31
32
32
@script ()
33
- def _test_script_onnx_unsupported (x : FLOAT [10 ], y : FLOAT [10 ]) -> FLOAT [10 ]:
33
+ def _test_script_msft_op (x : FLOAT [10 , 10 ], y : FLOAT [10 ]) -> FLOAT [10 ]:
34
34
gelu_add = op .Add (x , y )
35
- return op .Gelu (gelu_add , approximate = "tanh" )
35
+ return msft_op .Gelu (gelu_add )
36
+
37
+
38
+ @script ()
39
+ def _test_script_reversed_order (x : FLOAT [10 , 10 ], y : FLOAT [10 ]) -> FLOAT [10 ]:
40
+ gelu_add = op .Add (y , x )
41
+ return op .Gelu (gelu_add )
36
42
37
43
38
44
@script ()
39
- def _test_script_msft_op (x : FLOAT [10 ], y : FLOAT [10 ]) -> FLOAT [10 ]:
45
+ def _test_script_onnx_unsupported (x : FLOAT [10 , 10 ], y : FLOAT [10 ]) -> FLOAT [10 ]:
40
46
gelu_add = op .Add (x , y )
41
- return msft_op .Gelu (gelu_add )
47
+ return op .Gelu (gelu_add , approximate = "tanh" )
48
+
49
+
50
+ @script ()
51
+ def _test_script_shape_unsupported (x : FLOAT [10 , 10 ], y : FLOAT [10 ]) -> FLOAT [10 ]:
52
+ gelu_add = op .Add (x , x )
53
+ return op .Gelu (gelu_add )
42
54
43
55
44
56
class BiasGeluFusionTest (unittest .TestCase ):
@@ -54,7 +66,7 @@ def _check(
54
66
optimize (model )
55
67
56
68
input = {
57
- "x" : np .random .randn (10 ).astype (np .float32 ),
69
+ "x" : np .random .randn (10 , 10 ).astype (np .float32 ),
58
70
"y" : np .random .randn (10 ).astype (np .float32 ),
59
71
}
60
72
original_output = test_utils .ort_run ("Original" , model , input )
@@ -73,6 +85,7 @@ def _check(
73
85
("with_onnx_op_default" , _test_script_onnx_default , 1 , "BiasGelu" ),
74
86
("with_onnx_op_none" , _test_script_onnx_none , 1 , "BiasGelu" ),
75
87
("with_contrib_op" , _test_script_msft_op , 1 , "BiasGelu" ),
88
+ ("reversed_order" , _test_script_reversed_order , 1 , "BiasGelu" ),
76
89
]
77
90
)
78
91
def test_bias_gelu_fusion (
@@ -87,6 +100,7 @@ def test_bias_gelu_fusion(
87
100
@parameterized .parameterized .expand (
88
101
[
89
102
("approximate_tanh" , _test_script_onnx_unsupported , 2 , "Add" ),
103
+ ("unsupported_shape" , _test_script_shape_unsupported , 2 , "Add" ),
90
104
]
91
105
)
92
106
def test_bias_gelu_fusion_unsupported_attr (
0 commit comments