Skip to content

Commit 8a7738a

Browse files
committed
Implement split_function_and_wrangler | test(torchlib)
ghstack-source-id: c428e4d Pull Request resolved: microsoft/onnxscript#686
1 parent 1e2518b commit 8a7738a

File tree

1 file changed

+21
-23
lines changed

1 file changed

+21
-23
lines changed

onnxscript/tests/function_libs/torch_lib/ops_test.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import unittest
1919
import warnings
20-
from typing import Callable, Optional, Sequence
20+
from typing import Any, Callable, Optional, Sequence
2121

2222
import numpy as np
2323
import onnx
@@ -55,14 +55,23 @@ def _should_skip_test_sample(op_name: str, sample) -> Optional[str]:
5555
return None
5656

5757

58+
def _split_function_and_wrangler(
59+
onnx_function_and_wrangler: Callable[..., Any]
60+
| tuple[Callable[..., Any], Callable[..., Any]]
61+
) -> tuple[Callable[..., Any], Callable[..., Any] | None]:
62+
"""Splits a function with an optional input wrangler into a function and an input wrangler."""
63+
if isinstance(onnx_function_and_wrangler, tuple):
64+
return onnx_function_and_wrangler
65+
66+
assert callable(onnx_function_and_wrangler)
67+
return onnx_function_and_wrangler, None
68+
69+
5870
class TestFunctionValidity(unittest.TestCase):
5971
def test_all_script_functions_are_onnx_functions(self):
6072
functions = set()
6173
for func_with_wrangler in ops_test_data.OPINFO_FUNCTION_MAPPING_SCRIPTED.values():
62-
if isinstance(func_with_wrangler, tuple):
63-
func = func_with_wrangler[0]
64-
else:
65-
func = func_with_wrangler
74+
func, _ = _split_function_and_wrangler(func_with_wrangler)
6675
functions.add(func)
6776

6877
# TODO(justinchuby): Add from the registry
@@ -76,10 +85,7 @@ def test_all_script_functions_are_onnx_functions(self):
7685

7786
def test_all_trace_only_functions_are_not_onnx_functions(self):
7887
for func_with_wrangler in ops_test_data.OPINFO_FUNCTION_MAPPING_TRACE_ONLY.values():
79-
if isinstance(func_with_wrangler, tuple):
80-
func = func_with_wrangler[0]
81-
else:
82-
func = func_with_wrangler
88+
func, _ = _split_function_and_wrangler(func_with_wrangler)
8389
if isinstance(func, onnxscript.OnnxFunction):
8490
raise AssertionError(
8591
f"'{func.name}' is an OnnxFunction. "
@@ -95,10 +101,7 @@ def test_all_trace_only_functions_are_not_onnx_functions(self):
95101
"Function checker is not available before ONNX 1.14",
96102
)
97103
def test_script_function_passes_checker(self, _, func_with_wrangler):
98-
if isinstance(func_with_wrangler, tuple):
99-
func = func_with_wrangler[0]
100-
else:
101-
func = func_with_wrangler
104+
func, _ = _split_function_and_wrangler(func_with_wrangler)
102105
function_proto = func.to_function_proto()
103106
onnx.checker.check_function(function_proto) # type: ignore[attr-defined]
104107

@@ -127,16 +130,11 @@ def run_test_output_match(
127130
)
128131

129132
onnx_function_and_wrangler = ops_test_data.OPINFO_FUNCTION_MAPPING[op.name]
130-
input_wrangler = None
131-
if isinstance(onnx_function_and_wrangler, tuple):
132-
# Obtain the input_wrangler that manipulates the OpInfo inputs
133-
# to match the aten operator signature
134-
# An example is nn.functional.upsample_nearest2d, which has a different signature
135-
# than the aten operator upsample_nearest2d
136-
onnx_function, input_wrangler = onnx_function_and_wrangler
137-
else:
138-
assert callable(onnx_function_and_wrangler)
139-
onnx_function = onnx_function_and_wrangler
133+
# Obtain the input_wrangler that manipulates the OpInfo inputs
134+
# to match the aten operator signature
135+
# An example is nn.functional.upsample_nearest2d, which has a different signature
136+
# than the aten operator upsample_nearest2d
137+
onnx_function, input_wrangler = _split_function_and_wrangler(onnx_function_and_wrangler)
140138

141139
for i, cpu_sample in enumerate(samples):
142140
inputs = (cpu_sample.input, *cpu_sample.args)

0 commit comments

Comments
 (0)