From 02ec5377adcfddaf8ccf5095a292ee1f4f930c9b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 27 Apr 2023 03:58:08 +0000 Subject: [PATCH] Implement split_function_and_wrangler | test(torchlib) [ghstack-poisoned] --- .../tests/function_libs/torch_lib/ops_test.py | 44 +++++++++---------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test.py b/onnxscript/tests/function_libs/torch_lib/ops_test.py index 98e89d509b..140f2118cc 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test.py @@ -17,7 +17,7 @@ import unittest import warnings -from typing import Callable, Optional, Sequence +from typing import Any, Callable, Optional, Sequence import numpy as np import onnx @@ -55,14 +55,23 @@ def _should_skip_test_sample(op_name: str, sample) -> Optional[str]: return None +def _split_function_and_wrangler( + onnx_function_and_wrangler: Callable[..., Any] + | tuple[Callable[..., Any], Callable[..., Any]] +) -> tuple[Callable[..., Any], Callable[..., Any] | None]: + """Splits a function with an optional input wrangler into a function and an input wrangler.""" + if isinstance(onnx_function_and_wrangler, tuple): + return onnx_function_and_wrangler + + assert callable(onnx_function_and_wrangler) + return onnx_function_and_wrangler, None + + class TestFunctionValidity(unittest.TestCase): def test_all_script_functions_are_onnx_functions(self): functions = set() for func_with_wrangler in ops_test_data.OPINFO_FUNCTION_MAPPING_SCRIPTED.values(): - if isinstance(func_with_wrangler, tuple): - func = func_with_wrangler[0] - else: - func = func_with_wrangler + func, _ = _split_function_and_wrangler(func_with_wrangler) functions.add(func) # TODO(justinchuby): Add from the registry @@ -76,10 +85,7 @@ def test_all_script_functions_are_onnx_functions(self): def test_all_trace_only_functions_are_not_onnx_functions(self): for func_with_wrangler in ops_test_data.OPINFO_FUNCTION_MAPPING_TRACE_ONLY.values(): - if isinstance(func_with_wrangler, tuple): - func = func_with_wrangler[0] - else: - func = func_with_wrangler + func, _ = _split_function_and_wrangler(func_with_wrangler) if isinstance(func, onnxscript.OnnxFunction): raise AssertionError( f"'{func.name}' is an OnnxFunction. " @@ -95,10 +101,7 @@ def test_all_trace_only_functions_are_not_onnx_functions(self): "Function checker is not available before ONNX 1.14", ) def test_script_function_passes_checker(self, _, func_with_wrangler): - if isinstance(func_with_wrangler, tuple): - func = func_with_wrangler[0] - else: - func = func_with_wrangler + func, _ = _split_function_and_wrangler(func_with_wrangler) function_proto = func.to_function_proto() onnx.checker.check_function(function_proto) # type: ignore[attr-defined] @@ -127,16 +130,11 @@ def run_test_output_match( ) onnx_function_and_wrangler = ops_test_data.OPINFO_FUNCTION_MAPPING[op.name] - input_wrangler = None - if isinstance(onnx_function_and_wrangler, tuple): - # Obtain the input_wrangler that manipulates the OpInfo inputs - # to match the aten operator signature - # An example is nn.functional.upsample_nearest2d, which has a different signature - # than the aten operator upsample_nearest2d - onnx_function, input_wrangler = onnx_function_and_wrangler - else: - assert callable(onnx_function_and_wrangler) - onnx_function = onnx_function_and_wrangler + # Obtain the input_wrangler that manipulates the OpInfo inputs + # to match the aten operator signature + # An example is nn.functional.upsample_nearest2d, which has a different signature + # than the aten operator upsample_nearest2d + onnx_function, input_wrangler = _split_function_and_wrangler(onnx_function_and_wrangler) for i, cpu_sample in enumerate(samples): inputs = (cpu_sample.input, *cpu_sample.args)