diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index eaa6200888..5763f71c08 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -3,6 +3,7 @@
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------
 # mypy: disable-error-code=misc
+# mypy: disable-error-code=arg-type
 # mypy: disable-error-code=type-arg
 # mypy: disable-error-code=valid-type
 # mypy: disable-error-code=assignment
@@ -17,6 +18,8 @@
 from typing import Any, Optional, Sequence
 
 from onnxscript import INT64, TensorType
+from onnxscript.function_libs.torch_aten.typing import TFloat
+from onnxscript.onnx_opset import default_opset as op
 
 
 def aten_abs(self: TensorType) -> TensorType:
@@ -4109,10 +4112,10 @@ def aten_select_scatter(self: TensorType, src: TensorType, dim: int, index: int)
     raise NotImplementedError()
 
 
-def aten_selu(self: TensorType) -> TensorType:
+def aten_selu(self: TFloat) -> TensorType:
     # selu(Tensor self) -> Tensor
 
-    raise NotImplementedError()
+    return op.Selu(self)
 
 
 def aten_set_data(self: TensorType, new_data: TensorType) -> Any:
diff --git a/onnxscript/function_libs/torch_aten/ops/nn.py b/onnxscript/function_libs/torch_aten/ops/nn.py
index 7ad51fc9b4..288f89b330 100644
--- a/onnxscript/function_libs/torch_aten/ops/nn.py
+++ b/onnxscript/function_libs/torch_aten/ops/nn.py
@@ -3,6 +3,7 @@
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------
 # mypy: disable-error-code=misc
+# mypy: disable-error-code=arg-type
 # mypy: disable-error-code=type-arg
 # mypy: disable-error-code=valid-type
 # mypy: disable-error-code=assignment
@@ -12,11 +13,19 @@
 - All functions should not have the script() decorator. This is because
     we want to delay the compilation of the function.
 """
+
+# pylint: disable=unused-argument
+
 from __future__ import annotations
 
 from typing import Optional, Sequence
 
+from beartype.vale import Is
+from typing_extensions import Annotated
+
 from onnxscript import INT64, TensorType
+from onnxscript.function_libs.torch_aten.typing import TFloat
+from onnxscript.onnx_opset import default_opset as op
 
 
 def aten_adaptive_avg_pool2d(self: TensorType, output_size: INT64) -> TensorType:
@@ -185,11 +194,16 @@ def aten_cross_entropy_loss(
 
 
 def aten_elu(
-    self: TensorType, alpha: float = 1, scale: float = 1, input_scale: float = 1
+    self: TFloat,
+    alpha: float = 1.0,
+    scale: Annotated[float, Is[lambda x: x == 1.0]] = 1.0,
+    input_scale: Annotated[float, Is[lambda x: x == 1.0]] = 1.0,
 ) -> TensorType:
     # elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor
 
-    raise NotImplementedError()
+    # del scale
+    # del input_scale
+    return op.Elu(self, alpha=alpha)
 
 
 def aten_elu_backward(
@@ -773,10 +787,11 @@ def aten_reflection_pad3d_backward(
     raise NotImplementedError()
 
 
-def aten_relu6(self: TensorType) -> TensorType:
+# TODO(justinchuby): Use TFloat as return type
+def aten_relu6(self: TFloat) -> TensorType:
     # relu6(Tensor self) -> Tensor
 
-    raise NotImplementedError()
+    return op.Min(op.Relu(self), op.Constant(value_float=6.0))  # type: ignore[arg-type]
 
 
 def aten_replication_pad1d(self: TensorType, padding: INT64) -> TensorType:
diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
new file mode 100644
index 0000000000..85ebee1836
--- /dev/null
+++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
@@ -0,0 +1,256 @@
+"""Test op correctness by comparing with PyTorch results."""
+from __future__ import annotations
+
+import copy
+import dataclasses
+import unittest
+from typing import Any, Callable, Collection, Iterable, Optional, Sequence, TypeVar
+
+import numpy as np
+import onnxruntime.capi.onnxruntime_pybind11_state
+import torch
+from torch.testing._internal import common_device_type, common_methods_invocations
+from torch.testing._internal.opinfo import core as opinfo_core
+
+import onnxscript
+from onnxscript.function_libs.torch_aten.ops import core as core_ops
+from onnxscript.function_libs.torch_aten.ops import nn as nn_ops
+
+T = TypeVar("T")
+
+SUPPORTED_DTYPES = (
+    # Boolean
+    torch.bool,
+    # Integers
+    torch.uint8,
+    torch.int8,
+    torch.int16,
+    torch.int32,
+    torch.int64,
+    # Floating types
+    torch.float16,
+    torch.float32,
+    torch.float64,
+)
+
+# Convenience tuples for creating dtype lists when skipping or xfailing tests
+
+BOOL_TYPES = (torch.bool,)
+
+INT_TYPES = (
+    torch.int8,
+    torch.int16,
+    torch.int32,
+    torch.int64,
+    torch.uint8,
+)
+
+FLOAT_TYPES = (
+    torch.float16,
+    torch.float32,
+    torch.float64,
+)
+
+
+def dtypes_except(*dtypes: torch.dtype) -> Sequence[torch.dtype]:
+    """Returns all dtypes except the ones specified."""
+    return tuple(dtype for dtype in SUPPORTED_DTYPES if dtype not in dtypes)
+
+
+@dataclasses.dataclass
+class DecorateMeta:
+    """A dataclass for storing information about a test case to skip or xfail.
+
+    Adapted from functorch: functorch/test/common_utils.py
+    """
+
+    op_name: str
+    variant_name: str
+    decorator: Callable[..., Any]
+    dtypes: Optional[Collection[torch.dtype]]
+    reason: str
+
+
+def xfail(
+    op_name: str,
+    variant_name: str = "",
+    *,
+    dtypes: Optional[Collection[torch.dtype]] = None,
+    reason: Optional[str] = None,
+):
+    """Expects an OpInfo test to fail.
+
+    Args:
+        op_name: The name of the operator.
+        variant_name: Optional OpInfo variant_test_name.
+        dtypes: The dtypes to expect the failure.
+        reason: The reason for the failure.
+    """
+    if reason is None:
+        raise ValueError("Please specify a reason.")
+    return DecorateMeta(
+        op_name=op_name,
+        variant_name=variant_name,
+        decorator=unittest.expectedFailure,
+        dtypes=dtypes,
+        reason=reason,
+    )
+
+
+def skip(
+    op_name: str,
+    variant_name: str = "",
+    *,
+    dtypes: Optional[Collection[torch.dtype]] = None,
+    reason: Optional[str] = None,
+):
+    """Skips an OpInfo test.
+
+    Args:
+        op_name: The name of the operator.
+        variant_name: Optional OpInfo variant_test_name.
+        dtypes: The dtypes to skip.
+        reason: The reason for skipping.
+    """
+    if reason is None:
+        raise ValueError("Please specify a reason.")
+    return DecorateMeta(
+        op_name=op_name,
+        variant_name=variant_name,
+        decorator=unittest.skip(f"Don't care: {reason}"),
+        dtypes=dtypes,
+        reason=reason,
+    )
+
+
+def add_decorate_info(
+    all_opinfos: Sequence[opinfo_core.OpInfo],
+    test_class_name: str,
+    base_test_name: str,
+    skip_or_xfails: Iterable[DecorateMeta],
+) -> Callable[[T], T]:
+    """Decorates OpInfo tests with decorators based on the skip_or_xfails list."""
+    ops_mapping = {(info.name, info.variant_test_name): info for info in all_opinfos}
+    for decorate_meta in skip_or_xfails:
+        opinfo = ops_mapping.get((decorate_meta.op_name, decorate_meta.variant_name))
+        assert (
+            opinfo is not None
+        ), f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?"
+        decorators = list(opinfo.decorators)
+        new_decorator = opinfo_core.DecorateInfo(
+            decorate_meta.decorator,
+            test_class_name,
+            base_test_name,
+            dtypes=decorate_meta.dtypes,
+        )
+        decorators.append(new_decorator)
+        opinfo.decorators = tuple(decorators)
+
+    # This decorator doesn't modify fn in any way
+    def wrapped(fn):
+        return fn
+
+    return wrapped
+
+
+# Modify this section ##########################################################
+
+# Ops to be tested for numerical consistency between onnx and pytorch
+OPINFO_FUNCTION_MAPPING = {
+    "nn.functional.elu": nn_ops.aten_elu,
+    "nn.functional.relu6": nn_ops.aten_relu6,
+    "nn.functional.selu": core_ops.aten_selu,
+}
+
+TESTED_OPS = frozenset(OPINFO_FUNCTION_MAPPING)
+
+EXPECTED_SKIPS_OR_FAILS = (
+    xfail(
+        "nn.functional.elu",
+        dtypes=dtypes_except(torch.float16, torch.float32),
+        reason="ONNX Runtime doesn't support float64 for Elu",
+    ),
+    xfail(
+        "nn.functional.relu6",
+        dtypes=dtypes_except(torch.float16, torch.float32),
+        reason="ONNX Runtime doesn't support float64 for Relu",
+    ),
+    xfail(
+        "nn.functional.selu",
+        dtypes=dtypes_except(torch.float16, torch.float32),
+        reason="ONNX Runtime doesn't support float64 for Selu",
+    ),
+)
+# END OF SECTION TO MODIFY #####################################################
+
+
+OPS_DB = copy.deepcopy(common_methods_invocations.op_db)
+
+
+class TestOutputConsistency(unittest.TestCase):
+    """Test output consistency between exported ONNX models and PyTorch eager mode.
+
+    This is a parameterized test suite.
+    """
+
+    def setUp(self) -> None:
+        torch.manual_seed(42)
+        np.random.seed(42)
+
+    @common_device_type.ops(  # type: ignore[misc]
+        [info for info in OPS_DB if info.name in TESTED_OPS],
+        allowed_dtypes=SUPPORTED_DTYPES,
+    )
+    @add_decorate_info(
+        OPS_DB,
+        "TestOutputConsistency",
+        "test_output_match",
+        skip_or_xfails=EXPECTED_SKIPS_OR_FAILS,
+    )
+    def test_output_match(self, device: str, dtype: torch.dtype, op):
+        """Base test method for testing each opset, used by instantiate_device_type_tests."""
+        # device is provided by instantiate_device_type_tests, but we only want to run in cpu.
+        assert device == "cpu"
+
+        samples = op.sample_inputs(
+            device,
+            dtype,
+            requires_grad=False,
+        )
+
+        onnx_function = OPINFO_FUNCTION_MAPPING[op.name]
+        scripted_function = onnxscript.script()(onnx_function)
+
+        for (i, cpu_sample) in enumerate(samples):
+            inputs = (cpu_sample.input, *cpu_sample.args)
+            # Provide the repr to subtest because tensors are not serializable in parallel test runs
+            with self.subTest(
+                sample_num=i,
+                inputs=repr(inputs),
+                kwargs=repr(cpu_sample.kwargs),
+            ):
+                input_numpy = [x.numpy() for x in inputs if isinstance(x, torch.Tensor)]
+                torch_output = op(*inputs, **cpu_sample.kwargs)
+                try:
+                    function_output = scripted_function(*input_numpy, **cpu_sample.kwargs)
+                # pylint: disable=c-extension-no-member
+                except onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented:
+                    self.skipTest(
+                        f"ONNX Runtime doesn't support running {op.name} with dtype {dtype}",
+                    )
+                # pylint: enable=c-extension-no-member
+
+                # Use torch testing to ensure dtypes and shapes match
+                torch.testing.assert_close(
+                    torch.tensor(function_output),
+                    torch_output,
+                )
+
+
+common_device_type.instantiate_device_type_tests(
+    TestOutputConsistency, globals(), only_for="cpu"
+)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 0ee6feff7a..a2c46de42c 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -12,15 +12,19 @@ sphinx-gallery
 pydata_sphinx_theme
 
 # ATen lib
+beartype
 types-PyYAML
+typing_extensions
 
 # Testing
-pytest!=7.1.0
-pytest-cov
+expecttest
+parameterized
 pytest-azurepipelines
+pytest-cov
 pytest-subtests
 pytest-xdist
-parameterized
+pytest!=7.1.0
+pyyaml
 torch
 
 # Lint