Skip to content

Commit 7ff5466

Browse files
committed
Implement split_function_and_wrangler | test(torchlib)
ghstack-source-id: 266c0f9 Pull Request resolved: #688
1 parent eb58e24 commit 7ff5466

File tree

1 file changed

+21
-23
lines changed

1 file changed

+21
-23
lines changed

onnxscript/tests/function_libs/torch_lib/ops_test.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from __future__ import annotations
1717

1818
import unittest
19-
from typing import Callable, Optional, Sequence, Tuple
19+
from typing import Any, Callable, Optional, Sequence, Tuple
2020

2121
import numpy as np
2222
import onnx
@@ -56,14 +56,23 @@ def _should_skip_xfail_test_sample(
5656
return None, None
5757

5858

59+
def _split_function_and_wrangler(
60+
onnx_function_and_wrangler: Callable[..., Any]
61+
| tuple[Callable[..., Any], Callable[..., Any]]
62+
) -> tuple[Callable[..., Any], Callable[..., Any] | None]:
63+
"""Splits a function with an optional input wrangler into a function and an input wrangler."""
64+
if isinstance(onnx_function_and_wrangler, tuple):
65+
return onnx_function_and_wrangler
66+
67+
assert callable(onnx_function_and_wrangler)
68+
return onnx_function_and_wrangler, None
69+
70+
5971
class TestFunctionValidity(unittest.TestCase):
6072
def test_all_script_functions_are_onnx_functions(self):
6173
functions = set()
6274
for func_with_wrangler in ops_test_data.OPINFO_FUNCTION_MAPPING_SCRIPTED.values():
63-
if isinstance(func_with_wrangler, tuple):
64-
func = func_with_wrangler[0]
65-
else:
66-
func = func_with_wrangler
75+
func, _ = _split_function_and_wrangler(func_with_wrangler)
6776
functions.add(func)
6877

6978
# TODO(justinchuby): Add from the registry
@@ -77,10 +86,7 @@ def test_all_script_functions_are_onnx_functions(self):
7786

7887
def test_all_trace_only_functions_are_not_onnx_functions(self):
7988
for func_with_wrangler in ops_test_data.OPINFO_FUNCTION_MAPPING_TRACE_ONLY.values():
80-
if isinstance(func_with_wrangler, tuple):
81-
func = func_with_wrangler[0]
82-
else:
83-
func = func_with_wrangler
89+
func, _ = _split_function_and_wrangler(func_with_wrangler)
8490
if isinstance(func, onnxscript.OnnxFunction):
8591
raise AssertionError(
8692
f"'{func.name}' is an OnnxFunction. "
@@ -96,10 +102,7 @@ def test_all_trace_only_functions_are_not_onnx_functions(self):
96102
"Function checker is not available before ONNX 1.14",
97103
)
98104
def test_script_function_passes_checker(self, _, func_with_wrangler):
99-
if isinstance(func_with_wrangler, tuple):
100-
func = func_with_wrangler[0]
101-
else:
102-
func = func_with_wrangler
105+
func, _ = _split_function_and_wrangler(func_with_wrangler)
103106
function_proto = func.to_function_proto()
104107
onnx.checker.check_function(function_proto) # type: ignore[attr-defined]
105108

@@ -128,16 +131,11 @@ def run_test_output_match(
128131
)
129132

130133
onnx_function_and_wrangler = ops_test_data.OPINFO_FUNCTION_MAPPING[op.name]
131-
input_wrangler = None
132-
if isinstance(onnx_function_and_wrangler, tuple):
133-
# Obtain the input_wrangler that manipulates the OpInfo inputs
134-
# to match the aten operator signature
135-
# An example is nn.functional.upsample_nearest2d, which has a different signature
136-
# than the aten operator upsample_nearest2d
137-
onnx_function, input_wrangler = onnx_function_and_wrangler
138-
else:
139-
assert callable(onnx_function_and_wrangler)
140-
onnx_function = onnx_function_and_wrangler
134+
# Obtain the input_wrangler that manipulates the OpInfo inputs
135+
# to match the aten operator signature
136+
# An example is nn.functional.upsample_nearest2d, which has a different signature
137+
# than the aten operator upsample_nearest2d
138+
onnx_function, input_wrangler = _split_function_and_wrangler(onnx_function_and_wrangler)
141139

142140
for i, cpu_sample in enumerate(samples):
143141
inputs = (cpu_sample.input, *cpu_sample.args)

0 commit comments

Comments
 (0)