17
17
18
18
import unittest
19
19
import warnings
20
- from typing import Callable , Optional , Sequence
20
+ from typing import Any , Callable , Optional , Sequence
21
21
22
22
import numpy as np
23
23
import onnx
@@ -55,14 +55,23 @@ def _should_skip_test_sample(op_name: str, sample) -> Optional[str]:
55
55
return None
56
56
57
57
58
+ def _split_function_and_wrangler (
59
+ onnx_function_and_wrangler : Callable [..., Any ]
60
+ | tuple [Callable [..., Any ], Callable [..., Any ]]
61
+ ) -> tuple [Callable [..., Any ], Callable [..., Any ] | None ]:
62
+ """Splits a function with an optional input wrangler into a function and an input wrangler."""
63
+ if isinstance (onnx_function_and_wrangler , tuple ):
64
+ return onnx_function_and_wrangler
65
+
66
+ assert callable (onnx_function_and_wrangler )
67
+ return onnx_function_and_wrangler , None
68
+
69
+
58
70
class TestFunctionValidity (unittest .TestCase ):
59
71
def test_all_script_functions_are_onnx_functions (self ):
60
72
functions = set ()
61
73
for func_with_wrangler in ops_test_data .OPINFO_FUNCTION_MAPPING_SCRIPTED .values ():
62
- if isinstance (func_with_wrangler , tuple ):
63
- func = func_with_wrangler [0 ]
64
- else :
65
- func = func_with_wrangler
74
+ func , _ = _split_function_and_wrangler (func_with_wrangler )
66
75
functions .add (func )
67
76
68
77
# TODO(justinchuby): Add from the registry
@@ -76,10 +85,7 @@ def test_all_script_functions_are_onnx_functions(self):
76
85
77
86
def test_all_trace_only_functions_are_not_onnx_functions (self ):
78
87
for func_with_wrangler in ops_test_data .OPINFO_FUNCTION_MAPPING_TRACE_ONLY .values ():
79
- if isinstance (func_with_wrangler , tuple ):
80
- func = func_with_wrangler [0 ]
81
- else :
82
- func = func_with_wrangler
88
+ func , _ = _split_function_and_wrangler (func_with_wrangler )
83
89
if isinstance (func , onnxscript .OnnxFunction ):
84
90
raise AssertionError (
85
91
f"'{ func .name } ' is an OnnxFunction. "
@@ -95,10 +101,7 @@ def test_all_trace_only_functions_are_not_onnx_functions(self):
95
101
"Function checker is not available before ONNX 1.14" ,
96
102
)
97
103
def test_script_function_passes_checker (self , _ , func_with_wrangler ):
98
- if isinstance (func_with_wrangler , tuple ):
99
- func = func_with_wrangler [0 ]
100
- else :
101
- func = func_with_wrangler
104
+ func , _ = _split_function_and_wrangler (func_with_wrangler )
102
105
function_proto = func .to_function_proto ()
103
106
onnx .checker .check_function (function_proto ) # type: ignore[attr-defined]
104
107
@@ -127,16 +130,11 @@ def run_test_output_match(
127
130
)
128
131
129
132
onnx_function_and_wrangler = ops_test_data .OPINFO_FUNCTION_MAPPING [op .name ]
130
- input_wrangler = None
131
- if isinstance (onnx_function_and_wrangler , tuple ):
132
- # Obtain the input_wrangler that manipulates the OpInfo inputs
133
- # to match the aten operator signature
134
- # An example is nn.functional.upsample_nearest2d, which has a different signature
135
- # than the aten operator upsample_nearest2d
136
- onnx_function , input_wrangler = onnx_function_and_wrangler
137
- else :
138
- assert callable (onnx_function_and_wrangler )
139
- onnx_function = onnx_function_and_wrangler
133
+ # Obtain the input_wrangler that manipulates the OpInfo inputs
134
+ # to match the aten operator signature
135
+ # An example is nn.functional.upsample_nearest2d, which has a different signature
136
+ # than the aten operator upsample_nearest2d
137
+ onnx_function , input_wrangler = _split_function_and_wrangler (onnx_function_and_wrangler )
140
138
141
139
for i , cpu_sample in enumerate (samples ):
142
140
inputs = (cpu_sample .input , * cpu_sample .args )
0 commit comments