Skip to content

Implement split_function_and_wrangler | test(torchlib) #686

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 21 additions & 23 deletions onnxscript/tests/function_libs/torch_lib/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import unittest
import warnings
from typing import Callable, Optional, Sequence
from typing import Any, Callable, Optional, Sequence

import numpy as np
import onnx
Expand Down Expand Up @@ -55,14 +55,23 @@ def _should_skip_test_sample(op_name: str, sample) -> Optional[str]:
return None


def _split_function_and_wrangler(
onnx_function_and_wrangler: Callable[..., Any]
| tuple[Callable[..., Any], Callable[..., Any]]
) -> tuple[Callable[..., Any], Callable[..., Any] | None]:
"""Splits a function with an optional input wrangler into a function and an input wrangler."""
if isinstance(onnx_function_and_wrangler, tuple):
return onnx_function_and_wrangler

assert callable(onnx_function_and_wrangler)
return onnx_function_and_wrangler, None


class TestFunctionValidity(unittest.TestCase):
def test_all_script_functions_are_onnx_functions(self):
functions = set()
for func_with_wrangler in ops_test_data.OPINFO_FUNCTION_MAPPING_SCRIPTED.values():
if isinstance(func_with_wrangler, tuple):
func = func_with_wrangler[0]
else:
func = func_with_wrangler
func, _ = _split_function_and_wrangler(func_with_wrangler)
functions.add(func)

# TODO(justinchuby): Add from the registry
Expand All @@ -76,10 +85,7 @@ def test_all_script_functions_are_onnx_functions(self):

def test_all_trace_only_functions_are_not_onnx_functions(self):
for func_with_wrangler in ops_test_data.OPINFO_FUNCTION_MAPPING_TRACE_ONLY.values():
if isinstance(func_with_wrangler, tuple):
func = func_with_wrangler[0]
else:
func = func_with_wrangler
func, _ = _split_function_and_wrangler(func_with_wrangler)
if isinstance(func, onnxscript.OnnxFunction):
raise AssertionError(
f"'{func.name}' is an OnnxFunction. "
Expand All @@ -95,10 +101,7 @@ def test_all_trace_only_functions_are_not_onnx_functions(self):
"Function checker is not available before ONNX 1.14",
)
def test_script_function_passes_checker(self, _, func_with_wrangler):
if isinstance(func_with_wrangler, tuple):
func = func_with_wrangler[0]
else:
func = func_with_wrangler
func, _ = _split_function_and_wrangler(func_with_wrangler)
function_proto = func.to_function_proto()
onnx.checker.check_function(function_proto) # type: ignore[attr-defined]

Expand Down Expand Up @@ -127,16 +130,11 @@ def run_test_output_match(
)

onnx_function_and_wrangler = ops_test_data.OPINFO_FUNCTION_MAPPING[op.name]
input_wrangler = None
if isinstance(onnx_function_and_wrangler, tuple):
# Obtain the input_wrangler that manipulates the OpInfo inputs
# to match the aten operator signature
# An example is nn.functional.upsample_nearest2d, which has a different signature
# than the aten operator upsample_nearest2d
onnx_function, input_wrangler = onnx_function_and_wrangler
else:
assert callable(onnx_function_and_wrangler)
onnx_function = onnx_function_and_wrangler
# Obtain the input_wrangler that manipulates the OpInfo inputs
# to match the aten operator signature
# An example is nn.functional.upsample_nearest2d, which has a different signature
# than the aten operator upsample_nearest2d
onnx_function, input_wrangler = _split_function_and_wrangler(onnx_function_and_wrangler)

for i, cpu_sample in enumerate(samples):
inputs = (cpu_sample.input, *cpu_sample.args)
Expand Down