16
16
from __future__ import annotations
17
17
18
18
import unittest
19
- from typing import Callable , Optional , Sequence , Tuple
19
+ from typing import Any , Callable , Optional , Sequence , Tuple
20
20
21
21
import numpy as np
22
22
import onnx
@@ -56,14 +56,23 @@ def _should_skip_xfail_test_sample(
56
56
return None , None
57
57
58
58
59
+ def _split_function_and_wrangler (
60
+ onnx_function_and_wrangler : Callable [..., Any ]
61
+ | tuple [Callable [..., Any ], Callable [..., Any ]]
62
+ ) -> tuple [Callable [..., Any ], Callable [..., Any ] | None ]:
63
+ """Splits a function with an optional input wrangler into a function and an input wrangler."""
64
+ if isinstance (onnx_function_and_wrangler , tuple ):
65
+ return onnx_function_and_wrangler
66
+
67
+ assert callable (onnx_function_and_wrangler )
68
+ return onnx_function_and_wrangler , None
69
+
70
+
59
71
class TestFunctionValidity (unittest .TestCase ):
60
72
def test_all_script_functions_are_onnx_functions (self ):
61
73
functions = set ()
62
74
for func_with_wrangler in ops_test_data .OPINFO_FUNCTION_MAPPING_SCRIPTED .values ():
63
- if isinstance (func_with_wrangler , tuple ):
64
- func = func_with_wrangler [0 ]
65
- else :
66
- func = func_with_wrangler
75
+ func , _ = _split_function_and_wrangler (func_with_wrangler )
67
76
functions .add (func )
68
77
69
78
# TODO(justinchuby): Add from the registry
@@ -77,10 +86,7 @@ def test_all_script_functions_are_onnx_functions(self):
77
86
78
87
def test_all_trace_only_functions_are_not_onnx_functions (self ):
79
88
for func_with_wrangler in ops_test_data .OPINFO_FUNCTION_MAPPING_TRACE_ONLY .values ():
80
- if isinstance (func_with_wrangler , tuple ):
81
- func = func_with_wrangler [0 ]
82
- else :
83
- func = func_with_wrangler
89
+ func , _ = _split_function_and_wrangler (func_with_wrangler )
84
90
if isinstance (func , onnxscript .OnnxFunction ):
85
91
raise AssertionError (
86
92
f"'{ func .name } ' is an OnnxFunction. "
@@ -96,10 +102,7 @@ def test_all_trace_only_functions_are_not_onnx_functions(self):
96
102
"Function checker is not available before ONNX 1.14" ,
97
103
)
98
104
def test_script_function_passes_checker (self , _ , func_with_wrangler ):
99
- if isinstance (func_with_wrangler , tuple ):
100
- func = func_with_wrangler [0 ]
101
- else :
102
- func = func_with_wrangler
105
+ func , _ = _split_function_and_wrangler (func_with_wrangler )
103
106
function_proto = func .to_function_proto ()
104
107
onnx .checker .check_function (function_proto ) # type: ignore[attr-defined]
105
108
@@ -128,16 +131,11 @@ def run_test_output_match(
128
131
)
129
132
130
133
onnx_function_and_wrangler = ops_test_data .OPINFO_FUNCTION_MAPPING [op .name ]
131
- input_wrangler = None
132
- if isinstance (onnx_function_and_wrangler , tuple ):
133
- # Obtain the input_wrangler that manipulates the OpInfo inputs
134
- # to match the aten operator signature
135
- # An example is nn.functional.upsample_nearest2d, which has a different signature
136
- # than the aten operator upsample_nearest2d
137
- onnx_function , input_wrangler = onnx_function_and_wrangler
138
- else :
139
- assert callable (onnx_function_and_wrangler )
140
- onnx_function = onnx_function_and_wrangler
134
+ # Obtain the input_wrangler that manipulates the OpInfo inputs
135
+ # to match the aten operator signature
136
+ # An example is nn.functional.upsample_nearest2d, which has a different signature
137
+ # than the aten operator upsample_nearest2d
138
+ onnx_function , input_wrangler = _split_function_and_wrangler (onnx_function_and_wrangler )
141
139
142
140
for i , cpu_sample in enumerate (samples ):
143
141
inputs = (cpu_sample .input , * cpu_sample .args )
0 commit comments