From 55f8dd91c65584fc9e9aa6227c01655ee3efc5f0 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 9 Dec 2022 19:29:24 +0000 Subject: [PATCH 01/18] fix: annotate script() To allow mypy to analyze typing for annotated functions. Otherwise it complains that "Untyped decorator makes function "ones_like" untyped [misc]" [ghstack-poisoned] --- onnxscript/main.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/onnxscript/main.py b/onnxscript/main.py index 25529e0532..7b95cd44ed 100644 --- a/onnxscript/main.py +++ b/onnxscript/main.py @@ -9,6 +9,7 @@ import inspect import sys import textwrap +from typing import Any, Callable, Optional import onnx.helper @@ -52,7 +53,9 @@ def script_check(f: ast.FunctionDef, opset, global_names, source, default_opset= return convert.top_level_stmt(f) -def script(opset=None, default_opset=None, **kwargs): +def script( + opset: Optional[values.Opset] = None, default_opset: Optional[Any] = None, **kwargs: Any +) -> Callable[[Callable[..., Any]], onnxscript.OnnxFunction]: """Main decorator. Declares a function as an onnx function. Args: From 8a3a58757bc0ec12b3350aff05820e8a7db3cad4 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 9 Dec 2022 20:37:52 +0000 Subject: [PATCH 02/18] feat(atenlib): clamp, lt, gt [ghstack-poisoned] --- .../function_libs/torch_aten/ops/common.py | 13 ++++ .../function_libs/torch_aten/ops/core.py | 62 ++++++++++++++----- .../torch_aten/ops_correctness_test.py | 32 +++++++++- 3 files changed, 91 insertions(+), 16 deletions(-) create mode 100644 onnxscript/function_libs/torch_aten/ops/common.py diff --git a/onnxscript/function_libs/torch_aten/ops/common.py b/onnxscript/function_libs/torch_aten/ops/common.py new file mode 100644 index 0000000000..e433fab177 --- /dev/null +++ b/onnxscript/function_libs/torch_aten/ops/common.py @@ -0,0 +1,13 @@ +"""Commonly shared functions for the function library.""" +from __future__ import annotations + +import onnx.helper + +from onnxscript.onnx_opset import opset18 as op + + +def ones_like(x, onnx_dtype: int): + shape = op.Shape(x) + return op.ConstantOfShape( + shape, value=onnx.helper.make_tensor("one", onnx_dtype, [1], [1]) + ) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 7fd0593f3c..ce5920dd95 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -17,8 +17,11 @@ from typing import Any, Optional, Sequence +import onnx.helper + from onnxscript import BOOL, INT64 -from onnxscript.onnx_opset import default_opset as op +from onnxscript.function_libs.torch_aten.ops import common +from onnxscript.onnx_opset import opset18 as op from onnxscript.onnx_types import TensorType @@ -747,16 +750,31 @@ def aten_clamp( raise NotImplementedError() -def aten_clamp_max(self: TensorType, max: float) -> TensorType: +def aten_clamp_max_scalar(self, max_): # clamp_max(Tensor self, Scalar max) -> Tensor - raise NotImplementedError() + max_ = op.CastLike(max_, self) + return op.Clip(self, None, max_) -def aten_clamp_min(self: TensorType, min: float) -> TensorType: +def aten_clamp_max_tensor(self, max_): + # clamp_max(Tensor self, Scalar max) -> Tensor + + return op.Min(self, max_) + + +def aten_clamp_min_scalar(self, min_): # clamp_min(Tensor self, Scalar min) -> Tensor + # NOTE: min_ is a rank 0 tensor. + # TODO(justinchuby): Specify the type constraints. 
+ min_ = op.CastLike(min_, self) + return op.Clip(self, min_, None) - raise NotImplementedError() + +def aten_clamp_min_tensor(self, min_): + # clamp_min(Tensor self, Tensor min) -> Tensor + # TODO(justinchuby): Specify the type constraints. + return op.Max(self, min_) def aten_clip( @@ -1958,10 +1976,12 @@ def aten_gru_cell( raise NotImplementedError() -def aten_gt(self: TensorType, other: TensorType) -> TensorType: +def aten_gt(self, other): # gt.Tensor(Tensor self, Tensor other) -> Tensor - raise NotImplementedError() + # TODO(justinchuby): Input spec: non bool tensor + # Boolean inputs can be pre-casted by policy + return op.Greater(self, other) def aten_hamming_window(window_length: int) -> TensorType: @@ -2572,10 +2592,12 @@ def aten_lstm_mps_backward( raise NotImplementedError() -def aten_lt(self: TensorType, other: TensorType) -> TensorType: +def aten_lt(self, other): # lt.Tensor(Tensor self, Tensor other) -> Tensor - raise NotImplementedError() + # TODO(justinchuby): Input spec: non bool tensor + # Boolean inputs can be pre-casted by policy + return op.Less(self, other) def aten_lu_solve(self: TensorType, LU_data: TensorType, LU_pivots: TensorType) -> TensorType: @@ -3440,10 +3462,17 @@ def aten_ones(size: INT64) -> TensorType: raise NotImplementedError() -def aten_ones_like(self: TensorType, memory_format: Optional[str] = None) -> TensorType: +def aten_ones_like(self, dtype: Optional[int] = None): + """ones_like. + + Note: dtype is a torch enum. We need to convert it to ONNX dtype. + """ # ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor - raise NotImplementedError() + # TODO(justinchuby): Create a helper to convert torch dtype to ONNX dtype + if dtype is None: + dtype = onnx.TensorProto.FLOAT + return common.ones_like(self, dtype) def aten_or(self: TensorType, other: TensorType) -> TensorType: @@ -3916,10 +3945,13 @@ def aten_renorm(self: TensorType, p: float, dim: int, maxnorm: float) -> TensorT raise NotImplementedError() -def aten_repeat(self: TensorType, repeats: INT64) -> TensorType: +def aten_repeat(self, repeats: INT64): # repeat(Tensor self, SymInt[] repeats) -> Tensor - raise NotImplementedError() + # FIXME(justinchuby): 'common' is not an instance of type Opset but . 
+ shape = common.ones_like(repeats, onnx.TensorProto.INT64) + expanded = op.Expand(self, shape) + return op.Tile(expanded, repeats) def aten_repeat_interleave( @@ -4012,10 +4044,10 @@ def aten_rot90(self: TensorType, k: int = 1, dims: Sequence[int] = (0, 1)) -> Te raise NotImplementedError() -def aten_round(self: TensorType) -> TensorType: +def aten_round(self): # round(Tensor self) -> Tensor - raise NotImplementedError() + return op.Round(self) def aten_row_indices(self: TensorType) -> TensorType: diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index 1284da351c..d5e370aabf 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -8,6 +8,7 @@ import numpy as np import onnxruntime.capi.onnxruntime_pybind11_state +import parameterized import torch from torch.testing._internal import common_device_type, common_methods_invocations from torch.testing._internal.opinfo import core as opinfo_core @@ -156,12 +157,17 @@ def wrapped(fn): # Modify this section ########################################################## # Ops to be tested for numerical consistency between onnx and pytorch -OPINFO_FUNCTION_MAPPING = { +# Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py +OPINFO_FUNCTION_MAPPING: dict[str, Callable[..., Any]] = { "add": core_ops.aten_add, + "gt": core_ops.aten_gt, + "lt": core_ops.aten_lt, "mul": core_ops.aten_mul, "nn.functional.elu": nn_ops.aten_elu, "nn.functional.relu6": nn_ops.aten_relu6, "nn.functional.selu": core_ops.aten_selu, + # "repeat": core_ops.aten_repeat, + "round": core_ops.aten_round, "sub": core_ops.aten_sub, } @@ -169,6 +175,8 @@ def wrapped(fn): EXPECTED_SKIPS_OR_FAILS = ( xfail("add", dtypes=BOOL_TYPES, reason="Add is not defined on bool tensors"), + xfail("gt", dtypes=BOOL_TYPES, reason="Greater is not defined on bool tensors"), + xfail("lt", dtypes=BOOL_TYPES, reason="Less is not defined on bool tensors"), xfail("mul", dtypes=BOOL_TYPES, reason="Mul is not defined on bool tensors"), xfail( "nn.functional.elu", @@ -185,6 +193,17 @@ def wrapped(fn): dtypes=dtypes_except(torch.float16, torch.float32), reason="ONNX Runtime doesn't support float64 for Selu", ), + xfail( + "round", + variant_name="", + dtypes=dtypes_except(*FLOAT_TYPES), + reason="Round is not defined on non-float tensors", + ), + xfail("round", variant_name="decimals_0", reason="The ATen op does not support decimals"), + xfail("round", variant_name="decimals_3", reason="The ATen op does not support decimals"), + xfail( + "round", variant_name="decimals_neg_3", reason="The ATen op does not support decimals" + ), xfail("sub", dtypes=BOOL_TYPES, reason="Sub is not defined on bool tensors"), ) # END OF SECTION TO MODIFY ##################################################### @@ -193,6 +212,17 @@ def wrapped(fn): OPS_DB = copy.deepcopy(common_methods_invocations.op_db) +class TestFunctionsCompilation(unittest.TestCase): + """Test all functions can be compiled.""" + + @parameterized.parameterized.expand( + list(OPINFO_FUNCTION_MAPPING.items()), + ) + def test_function_compiles(self, _, function): + compiled = onnxscript.script()(function) + compiled.to_function_proto() + + class TestOutputConsistency(unittest.TestCase): """Test output consistency between exported ONNX models and PyTorch eager mode. 
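Patch 02 adds two layers of coverage: TestFunctionsCompilation only checks that every entry in OPINFO_FUNCTION_MAPPING converts to a FunctionProto, and TestOutputConsistency then runs the compiled function eagerly against PyTorch. A minimal sketch of those two steps for one of the ops added above, outside the test harness; the input values are illustrative and not taken from the patch:

    import numpy as np
    import onnxscript
    from onnxscript.function_libs.torch_aten.ops import core as core_ops

    # Step 1: compile, as TestFunctionsCompilation does for every mapped op.
    scripted = onnxscript.script()(core_ops.aten_gt)
    scripted.to_function_proto()  # raises if the function cannot be converted

    # Step 2: call eagerly on numpy inputs, as TestOutputConsistency does.
    out = scripted(np.array([1.0, 3.0], dtype=np.float32),
                   np.array([2.0, 2.0], dtype=np.float32))
    print(out)  # expected: [False  True]
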
From f821b6ae182306d02c6f5870f9b19299196a5270 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 9 Dec 2022 23:50:48 +0000 Subject: [PATCH 03/18] Update on "feat(atenlib): implement aten functions 1/n" Implemented ops - lt - gt - round - clamp_min - clamp_max - clamp - repeat - ones_like [ghstack-poisoned] --- .../function_libs/torch_aten/ops/common.py | 11 +++++----- .../function_libs/torch_aten/ops/core.py | 21 ++++++++++++------- .../torch_aten/ops_correctness_test.py | 2 +- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/common.py b/onnxscript/function_libs/torch_aten/ops/common.py index e433fab177..5d6676dcc6 100644 --- a/onnxscript/function_libs/torch_aten/ops/common.py +++ b/onnxscript/function_libs/torch_aten/ops/common.py @@ -1,13 +1,12 @@ """Commonly shared functions for the function library.""" from __future__ import annotations -import onnx.helper - +import onnxscript from onnxscript.onnx_opset import opset18 as op -def ones_like(x, onnx_dtype: int): +@onnxscript.script() +def ones_like(x, dtype: int): shape = op.Shape(x) - return op.ConstantOfShape( - shape, value=onnx.helper.make_tensor("one", onnx_dtype, [1], [1]) - ) + one_dtype = op.Cast(1, to=dtype) + return op.Expand(one_dtype, shape) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index ce5920dd95..da8ceeb3cf 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -19,12 +19,22 @@ import onnx.helper +import onnxscript from onnxscript import BOOL, INT64 from onnxscript.function_libs.torch_aten.ops import common from onnxscript.onnx_opset import opset18 as op from onnxscript.onnx_types import TensorType +@onnxscript.script() +def _ones_like(x, dtype: int): + """Common function for ones_like.""" + # TODO(justinchuby): Put this in another module + shape = op.Shape(x) + one_dtype = op.Cast(1, to=dtype) + return op.Expand(one_dtype, shape) + + def aten_abs(self: TensorType) -> TensorType: # abs(Tensor self) -> Tensor @@ -3462,16 +3472,14 @@ def aten_ones(size: INT64) -> TensorType: raise NotImplementedError() -def aten_ones_like(self, dtype: Optional[int] = None): +def aten_ones_like(self, dtype: int = onnx.TensorProto.FLOAT): """ones_like. - Note: dtype is a torch enum. We need to convert it to ONNX dtype. + Note: dtype is an onnx enum. Users should convert torch dtype to onnx dtype + before calling this function. """ # ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor - # TODO(justinchuby): Create a helper to convert torch dtype to ONNX dtype - if dtype is None: - dtype = onnx.TensorProto.FLOAT return common.ones_like(self, dtype) @@ -3948,8 +3956,7 @@ def aten_renorm(self: TensorType, p: float, dim: int, maxnorm: float) -> TensorT def aten_repeat(self, repeats: INT64): # repeat(Tensor self, SymInt[] repeats) -> Tensor - # FIXME(justinchuby): 'common' is not an instance of type Opset but . 
- shape = common.ones_like(repeats, onnx.TensorProto.INT64) + shape = _ones_like(repeats, onnx.TensorProto.INT64) expanded = op.Expand(self, shape) return op.Tile(expanded, repeats) diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index d5e370aabf..770f016d4d 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -166,7 +166,7 @@ def wrapped(fn): "nn.functional.elu": nn_ops.aten_elu, "nn.functional.relu6": nn_ops.aten_relu6, "nn.functional.selu": core_ops.aten_selu, - # "repeat": core_ops.aten_repeat, + "repeat": core_ops.aten_repeat, "round": core_ops.aten_round, "sub": core_ops.aten_sub, } From aecc14851604f84e7b599d7e2e080aa0416edaf2 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 10 Dec 2022 00:33:02 +0000 Subject: [PATCH 04/18] Update on "feat(atenlib): implement aten functions 1/n" Implemented ops - lt - gt - round - clamp_min - clamp_max - clamp - repeat - ones_like [ghstack-poisoned] --- .../function_libs/torch_aten/ops/common.py | 12 --------- .../function_libs/torch_aten/ops/core.py | 26 ++++++++----------- .../torch_aten/ops_correctness_test.py | 19 +++++++++++++- 3 files changed, 29 insertions(+), 28 deletions(-) delete mode 100644 onnxscript/function_libs/torch_aten/ops/common.py diff --git a/onnxscript/function_libs/torch_aten/ops/common.py b/onnxscript/function_libs/torch_aten/ops/common.py deleted file mode 100644 index 5d6676dcc6..0000000000 --- a/onnxscript/function_libs/torch_aten/ops/common.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Commonly shared functions for the function library.""" -from __future__ import annotations - -import onnxscript -from onnxscript.onnx_opset import opset18 as op - - -@onnxscript.script() -def ones_like(x, dtype: int): - shape = op.Shape(x) - one_dtype = op.Cast(1, to=dtype) - return op.Expand(one_dtype, shape) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index da8ceeb3cf..b4d10dd31c 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -19,22 +19,11 @@ import onnx.helper -import onnxscript from onnxscript import BOOL, INT64 -from onnxscript.function_libs.torch_aten.ops import common from onnxscript.onnx_opset import opset18 as op from onnxscript.onnx_types import TensorType -@onnxscript.script() -def _ones_like(x, dtype: int): - """Common function for ones_like.""" - # TODO(justinchuby): Put this in another module - shape = op.Shape(x) - one_dtype = op.Cast(1, to=dtype) - return op.Expand(one_dtype, shape) - - def aten_abs(self: TensorType) -> TensorType: # abs(Tensor self) -> Tensor @@ -3953,12 +3942,19 @@ def aten_renorm(self: TensorType, p: float, dim: int, maxnorm: float) -> TensorT raise NotImplementedError() -def aten_repeat(self, repeats: INT64): +def aten_repeat(self, repeats: INT64["M"]): # repeat(Tensor self, SymInt[] repeats) -> Tensor - shape = _ones_like(repeats, onnx.TensorProto.INT64) - expanded = op.Expand(self, shape) - return op.Tile(expanded, repeats) + # FIXME(justinchuby): When repeats.shape == [0] + + # TODO(justinchuby): Make ones_like a function when onnxscript supports it + # shape = ones_like(repeats) := { + one = op.Constant(value_int=1) + repeats_shape = op.Shape(repeats) + shape = op.Expand(one, repeats_shape) + # } + self_expanded = op.Expand(self, shape) + return op.Tile(self_expanded, repeats) 
def aten_repeat_interleave( diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index 770f016d4d..47ab32b10b 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -193,6 +193,7 @@ def wrapped(fn): dtypes=dtypes_except(torch.float16, torch.float32), reason="ONNX Runtime doesn't support float64 for Selu", ), + xfail("repeat", reason="fails when repeats is empty."), xfail( "round", variant_name="", @@ -223,6 +224,22 @@ def test_function_compiles(self, _, function): compiled.to_function_proto() +def _convert_tensor_to_numpy(input: Any) -> Any: + if isinstance(input, torch.Tensor): + return input.detach().cpu().numpy() + if isinstance(input, (tuple, list)): + if len(input) == 0: + return np.array((), dtype=np.int64) + if isinstance(input[0], torch.Tensor): + return [_convert_tensor_to_numpy(x) for x in input] + if isinstance(input[0], (int, float)): + # Just a tuple of numbers + return np.array(input) + return input + + return input + + class TestOutputConsistency(unittest.TestCase): """Test output consistency between exported ONNX models and PyTorch eager mode. @@ -265,7 +282,7 @@ def test_output_match(self, device: str, dtype: torch.dtype, op): inputs=repr(inputs), kwargs=repr(cpu_sample.kwargs), ): - input_numpy = [x.numpy() for x in inputs if isinstance(x, torch.Tensor)] + input_numpy = [_convert_tensor_to_numpy(x) for x in inputs] torch_output = op(*inputs, **cpu_sample.kwargs) try: function_output = scripted_function(*input_numpy, **cpu_sample.kwargs) From 6555a55a80a35b35169bafbd9668ee46a555b8ff Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 10 Dec 2022 01:10:07 +0000 Subject: [PATCH 05/18] Update on "feat(atenlib): implement aten functions 1/n" Implemented ops - lt - gt - round - clamp_min - clamp_max - clamp - repeat - ones_like [ghstack-poisoned] --- .../torch_aten/ops_correctness_test.py | 44 ++++++++++++++++--- 1 file changed, 37 insertions(+), 7 deletions(-) diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index 47ab32b10b..4ccf42e7f3 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -70,14 +70,15 @@ class DecorateMeta: decorator: Callable[..., Any] dtypes: Optional[Collection[torch.dtype]] reason: str + matcher: Optional[Callable[[Any], bool]] = None def xfail( op_name: str, variant_name: str = "", *, + reason: str, dtypes: Optional[Collection[torch.dtype]] = None, - reason: Optional[str] = None, ): """Expects an OpInfo test to fail. @@ -87,8 +88,6 @@ def xfail( dtypes: The dtypes to expect the failure. reason: The reason for the failure. """ - if reason is None: - raise ValueError("Please specify a reason.") return DecorateMeta( op_name=op_name, variant_name=variant_name, @@ -102,8 +101,9 @@ def skip( op_name: str, variant_name: str = "", *, + reason: str, dtypes: Optional[Collection[torch.dtype]] = None, - reason: Optional[str] = None, + matcher: Optional[Callable[[Any], Any]] = None, ): """Skips an OpInfo test. @@ -112,15 +112,16 @@ def skip( variant_name: Optional OpInfo variant_test_name. dtypes: The dtypes to skip. reason: The reason for skipping. + matcher: A function that matches the test sample input. It is used only when + xfail is in the SKIP_SUBTESTS list. 
""" - if reason is None: - raise ValueError("Please specify a reason.") return DecorateMeta( op_name=op_name, variant_name=variant_name, decorator=unittest.skip(f"Don't care: {reason}"), dtypes=dtypes, reason=reason, + matcher=matcher, ) @@ -193,7 +194,7 @@ def wrapped(fn): dtypes=dtypes_except(torch.float16, torch.float32), reason="ONNX Runtime doesn't support float64 for Selu", ), - xfail("repeat", reason="fails when repeats is empty."), + # xfail("repeat", reason="fails when repeats is empty."), xfail( "round", variant_name="", @@ -207,6 +208,19 @@ def wrapped(fn): ), xfail("sub", dtypes=BOOL_TYPES, reason="Sub is not defined on bool tensors"), ) + + +SKIP_SUBTESTS = ( + skip( + "repeat", + reason="repeating when input is a scalar and repeats is empty is not supported.", + matcher=lambda sample: sample.args[0] == (), + ), +) +OP_WITH_SKIPPED_SUBTESTS = frozenset( + meta.op_name for meta in SKIP_SUBTESTS +) + # END OF SECTION TO MODIFY ##################################################### @@ -240,6 +254,19 @@ def _convert_tensor_to_numpy(input: Any) -> Any: return input +def _should_skip_test_sample(op_name: str, sample) -> Optional[str]: + """Returns a reason if a test sample should be skipped.""" + if op_name not in OP_WITH_SKIPPED_SUBTESTS: + return None + for decorator_meta in SKIP_SUBTESTS: + # Linear search on SKIP_SUBTESTS. That's fine because the list is small. + if decorator_meta.op_name == op_name: + assert decorator_meta.matcher is not None, "Matcher must be defined" + if decorator_meta.matcher(sample): + return decorator_meta.reason + return None + + class TestOutputConsistency(unittest.TestCase): """Test output consistency between exported ONNX models and PyTorch eager mode. @@ -282,6 +309,9 @@ def test_output_match(self, device: str, dtype: torch.dtype, op): inputs=repr(inputs), kwargs=repr(cpu_sample.kwargs), ): + skip_reason = _should_skip_test_sample(op.name, cpu_sample) + if skip_reason is not None: + self.skipTest(skip_reason) input_numpy = [_convert_tensor_to_numpy(x) for x in inputs] torch_output = op(*inputs, **cpu_sample.kwargs) try: From f8385b0354089e438f267f4680ce4324fa3df9b6 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 10 Dec 2022 01:16:44 +0000 Subject: [PATCH 06/18] Update base for Update on "feat(atenlib): implement aten functions 1/n" Implemented ops - lt - gt - round - clamp_min - clamp_max - clamp - repeat - ones_like [ghstack-poisoned] --- .../test/function_libs/torch_aten/ops_correctness_test.py | 3 ++- pyproject.toml | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index 1284da351c..9cb81eb229 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -156,7 +156,8 @@ def wrapped(fn): # Modify this section ########################################################## # Ops to be tested for numerical consistency between onnx and pytorch -OPINFO_FUNCTION_MAPPING = { +# Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py +OPINFO_FUNCTION_MAPPING: dict[str, Callable[..., Any]] = { "add": core_ops.aten_add, "mul": core_ops.aten_mul, "nn.functional.elu": nn_ops.aten_elu, diff --git a/pyproject.toml b/pyproject.toml index c0ac8b2d7a..be52d4c11e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,6 @@ filterwarnings = ["ignore::UserWarning", "ignore::DeprecationWarning"] 
[tool.mypy] follow_imports = "silent" # TODO: Remove when we fix all the mypy errors strict_optional = true -warn_return_any = true warn_no_return = true warn_unused_ignores = true warn_redundant_casts = true From 060f9dbd80447d126282fdf5c1757202f38a447f Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 10 Dec 2022 19:35:50 +0000 Subject: [PATCH 07/18] Update base for Update on "feat(atenlib): implement aten functions 1/n" Implemented ops - lt - gt - round - clamp_min - clamp_max - clamp - repeat - ones_like Test: - Create the ability to skip particular sub tests - Create a helper function to transform torch inputs into numpy for onnxscript to run on [ghstack-poisoned] From 9bb4038aff85c94400891b00e3199dcecdaaad52 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 12 Dec 2022 17:09:58 +0000 Subject: [PATCH 08/18] Update base for Update on "feat(atenlib): implement aten functions 1/n" Implemented ops - lt - gt - round - clamp_min - clamp_max - clamp - repeat - ones_like Test: - Create the ability to skip particular sub tests - Create a helper function to transform torch inputs into numpy for onnxscript to run on [ghstack-poisoned] From cbfb867df0d2114677e14951eb5fc18ca6a1cbf4 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 12 Dec 2022 17:38:56 +0000 Subject: [PATCH 09/18] Update base for Update on "feat(atenlib): implement aten functions 1/n" Implemented ops - lt - gt - round - clamp_min - clamp_max - clamp - repeat - ones_like Test: - Create the ability to skip particular sub tests - Create a helper function to transform torch inputs into numpy for onnxscript to run on [ghstack-poisoned] --- onnxscript/backend/onnx_backend.py | 2 +- onnxscript/backend/onnx_export.py | 5 ++++- onnxscript/irbuilder.py | 10 +++++----- onnxscript/main.py | 4 +++- onnxscript/test/common/onnx_script_test_case.py | 4 ++-- onnxscript/utils.py | 2 +- pyproject.toml | 6 ++---- requirements-dev.txt | 2 +- 8 files changed, 19 insertions(+), 16 deletions(-) diff --git a/onnxscript/backend/onnx_backend.py b/onnxscript/backend/onnx_backend.py index 576dfdd46e..1d917ffdc0 100644 --- a/onnxscript/backend/onnx_backend.py +++ b/onnxscript/backend/onnx_backend.py @@ -74,7 +74,7 @@ def _read_proto_from_file(full): seq = onnx.SequenceProto() try: seq.ParseFromString(serialized) - loaded = to_list(seq) + loaded = to_list(seq) # type: ignore[assignment] except Exception: # pylint: disable=W0703 try: loaded = onnx.load_model_from_string(serialized) diff --git a/onnxscript/backend/onnx_export.py b/onnxscript/backend/onnx_export.py index b19f5e3b06..8104f9aeb7 100644 --- a/onnxscript/backend/onnx_export.py +++ b/onnxscript/backend/onnx_export.py @@ -207,7 +207,10 @@ def _python_make_node_graph(self, graph, opsets, indent=0, output_names=None): if hasattr(graph, "initializer"): for init in graph.initializer: node = make_node( - "Constant", [], [self._rename_variable(init.name)], value=init + "Constant", + [], + [self._rename_variable(init.name)], # type: ignore[list-item] + value=init, ) code.append(self._python_make_node(node, opsets, indent=indent)) if hasattr(graph, "sparse_initializer") and len(graph.sparse_initializer) > 0: diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 8b769e6978..213d1b8264 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -7,7 +7,7 @@ import io import logging import warnings -from typing import Any, Dict, Optional, Sequence, Tuple, Union +from typing import Any, Optional, Sequence import onnx from onnx import ValueInfoProto, helper @@ -185,7 +185,7 @@ def 
__init__(self, name: str, domain: str = "") -> None: self.stmts: list[IRStmt] = [] self.attrs: list[str] = [] # attribute parameters self.attr_protos: list[ - onnx.AttributeProto + IRAttributeValue ] = [] # attribute parameters with default value self.called_functions: dict[str, onnx.FunctionProto] = {} self.docstring: str = "" @@ -218,7 +218,7 @@ def append_input(self, name: IRVar) -> None: def append_output(self, name: IRVar) -> None: self.outputs.append(name) - def add_attr_parameter(self, attr: Union[str, IRAttributeValue]) -> None: + def add_attr_parameter(self, attr: str | IRAttributeValue) -> None: if isinstance(attr, IRAttributeValue): self.attr_protos.append(attr) else: @@ -324,7 +324,7 @@ def to_proto(f): def to_graph_and_functions( self, use_default_type: bool = True - ) -> Tuple[onnx.GraphProto, dict[str, onnx.FunctionProto]]: + ) -> tuple[onnx.GraphProto, dict[str, onnx.FunctionProto]]: """Converts this instance into a `onnx.GraphProto` and a map from function-name to `onnx.FunctionProto`. @@ -360,7 +360,7 @@ def to_graph_proto(self, use_default_type: bool = True) -> onnx.GraphProto: graph, _ = self.to_graph_and_functions(use_default_type=use_default_type) return graph - def get_opset_import(self) -> Dict[str, int]: + def get_opset_import(self) -> dict[str, int]: func_opset_imports = {} for s in self.stmts: if s.callee.opset.domain not in func_opset_imports: diff --git a/onnxscript/main.py b/onnxscript/main.py index 7b95cd44ed..98a3c1fb7e 100644 --- a/onnxscript/main.py +++ b/onnxscript/main.py @@ -54,7 +54,9 @@ def script_check(f: ast.FunctionDef, opset, global_names, source, default_opset= def script( - opset: Optional[values.Opset] = None, default_opset: Optional[Any] = None, **kwargs: Any + opset: Optional[values.Opset] = None, + default_opset: Optional[values.Opset] = None, + **kwargs: Any, ) -> Callable[[Callable[..., Any]], onnxscript.OnnxFunction]: """Main decorator. Declares a function as an onnx function. 
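With the parameters and the return type both annotated, mypy can see through the decorator under the disallow_untyped_decorators setting enabled in pyproject.toml below, instead of reporting the error quoted in patch 01 ('Untyped decorator makes function "ones_like" untyped [misc]'). A minimal sketch of a decorated function that the annotation now types as onnxscript.OnnxFunction; the function body and the use of default_opset are illustrative, not part of the patch:

    import onnxscript
    from onnxscript.onnx_opset import opset18 as op

    @onnxscript.script(default_opset=op)  # default_opset must now be a values.Opset
    def double(x):
        return x + x  # no explicit op.* call, so the default opset is the one applied

    func: onnxscript.OnnxFunction = double  # accepted: the decorator returns an OnnxFunction
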
diff --git a/onnxscript/test/common/onnx_script_test_case.py b/onnxscript/test/common/onnx_script_test_case.py index 114c25c5a7..08028c8149 100644 --- a/onnxscript/test/common/onnx_script_test_case.py +++ b/onnxscript/test/common/onnx_script_test_case.py @@ -54,11 +54,11 @@ def setUpClass(cls): try: # experimental version # pylint: disable=no-value-for-parameter - cls.all_test_cases = node_test.collect_testcases() # type: ignore[attr-defined] + cls.all_test_cases = node_test.collect_testcases() # type: ignore[attr-defined,call-arg] # pylint: enable=no-value-for-parameter except TypeError: # official version - cls.all_test_cases = node_test.collect_testcases(None) # type: ignore[attr-defined] + cls.all_test_cases = node_test.collect_testcases(None) # type: ignore[attr-defined,arg-type] def _create_model_from_param( self, param: FunctionTestParams, onnx_case_model: onnx.ModelProto diff --git a/onnxscript/utils.py b/onnxscript/utils.py index 7a6cf14ea4..54f3ace6cb 100644 --- a/onnxscript/utils.py +++ b/onnxscript/utils.py @@ -20,7 +20,7 @@ from onnx.printer import to_text as proto2text except ImportError: - def proto2text(x): # pylint: disable=unused-argument + def proto2text(_: Any) -> str: return "" diff --git a/pyproject.toml b/pyproject.toml index be52d4c11e..0e529617d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,18 +50,16 @@ filterwarnings = ["ignore::UserWarning", "ignore::DeprecationWarning"] follow_imports = "silent" # TODO: Remove when we fix all the mypy errors strict_optional = true warn_no_return = true -warn_unused_ignores = true +warn_unused_ignores = false # TODO: CI and local are inconsistent when this is true. Investigate. warn_redundant_casts = true warn_incomplete_stub = true # TODO disallow_untyped_calls = true check_untyped_defs = true -disallow_any_generics = true -no_implicit_optional = true +disallow_any_generics = false # TODO disallow_incomplete_defs = true # TODO disallow_subclassing_any = true disallow_untyped_decorators = true warn_unused_configs = true -show_error_codes = true show_column_numbers = true [[tool.mypy.overrides]] diff --git a/requirements-dev.txt b/requirements-dev.txt index a2c46de42c..39c692f76a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -25,7 +25,7 @@ pytest-subtests pytest-xdist pytest!=7.1.0 pyyaml -torch +torch>=1.13 # Lint lintrunner From 27008e1c9b15a2afd87b61101371f55920e49cf4 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 12 Dec 2022 18:06:56 +0000 Subject: [PATCH 10/18] Update base for Update on "feat(atenlib): implement aten functions 1/n" Implemented ops - lt - gt - round - clamp_min - clamp_max - clamp - repeat - ones_like Test: - Create the ability to skip particular sub tests - Create a helper function to transform torch inputs into numpy for onnxscript to run on [ghstack-poisoned] --- onnxscript/irbuilder.py | 4 +++- onnxscript/utils.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 213d1b8264..fe9a20d518 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -472,5 +472,7 @@ def make_attr_ref(self, attrname: str, refname: str, pytype: type) -> IRAttribut a = onnx.AttributeProto() a.name = attrname a.ref_attr_name = refname - a.type = ta.pytype_to_attrtype(pytype) + type_ = ta.pytype_to_attrtype(pytype) + assert type_ is not None + a.type = type_ return IRAttributeValue(a) diff --git a/onnxscript/utils.py b/onnxscript/utils.py index 54f3ace6cb..988016bc79 100644 --- a/onnxscript/utils.py +++ 
b/onnxscript/utils.py @@ -20,7 +20,7 @@ from onnx.printer import to_text as proto2text except ImportError: - def proto2text(_: Any) -> str: + def proto2text(_: Any) -> str: # type: ignore[misc] return "" From c5871c8f77fc3ca17e61c3fa323457791b41b630 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 12 Dec 2022 18:42:25 +0000 Subject: [PATCH 11/18] Update base for Update on "feat(atenlib): implement aten functions 1/n" Implemented ops - lt - gt - round - clamp_min - clamp_max - clamp - repeat - ones_like Test: - Create the ability to skip particular sub tests - Create a helper function to transform torch inputs into numpy for onnxscript to run on [ghstack-poisoned] --- azure-pipelines.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 715429b904..bb5fa36304 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -25,12 +25,13 @@ steps: python -m pip install -r requirements-dev.txt displayName: 'Install dependencies' +# TODO(#249): Fix tests for onnx 1.13 - script: | if [ '$(onnx.standard)' == '1' ] then python -m pip uninstall -y onnx-function-experiment python -m pip uninstall -y ort-function-experiment-nightly - python -m pip install onnx onnxruntime + python -m pip install onnx==1.12 onnxruntime fi displayName: 'Install onnx' From 691772b33844946d854851c3bc4f2273aad23da4 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 13 Dec 2022 16:40:19 +0000 Subject: [PATCH 12/18] feat(atenlib): ops 2/n [ghstack-poisoned] --- onnxscript/function_libs/torch_aten/ops/core.py | 7 +++++-- .../test/function_libs/torch_aten/ops_correctness_test.py | 1 + 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 377e25eb86..477275310f 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -3469,10 +3469,13 @@ def aten_numpy_T(self: TensorType) -> TensorType: raise NotImplementedError() -def aten_ones(size: INT64) -> TensorType: +def aten_ones(size: INT64, dtype: int = -1) -> TensorType: # ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor - raise NotImplementedError() + one = op.Constant(value_float=1) + if dtype != -1: + one = op.Cast(one, to=dtype) # type: ignore[arg-type] + return op.Expand(one, size) # type: ignore[arg-type] def aten_ones_like(self, dtype: int = -1): diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index a2a9dbac7c..988249ecbe 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -171,6 +171,7 @@ def wrapped(fn): "nn.functional.elu": nn_ops.aten_elu, "nn.functional.relu6": nn_ops.aten_relu6, "nn.functional.selu": core_ops.aten_selu, + "ones": core_ops.aten_ones, "ones_like": core_ops.aten_ones_like, "repeat": core_ops.aten_repeat, "round": core_ops.aten_round, From f160dfa5e2fe3c682e69fe9e5426cf5da49cee25 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 13 Dec 2022 17:19:59 +0000 Subject: [PATCH 13/18] Update on "feat(atenlib): ops 2/n" [ghstack-poisoned] --- .../function_libs/torch_aten/ops/core.py | 21 ++++++++++++------- .../torch_aten/ops_correctness_test.py | 11 +++++++++- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 477275310f..8c6d09f318 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -3463,12 +3463,6 @@ def aten_nuclear_norm(self: TensorType, keepdim: bool = False) -> TensorType: raise NotImplementedError() -def aten_numpy_T(self: TensorType) -> TensorType: - # numpy_T(Tensor(a) self) -> Tensor(a) - - raise NotImplementedError() - - def aten_ones(size: INT64, dtype: int = -1) -> TensorType: # ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor @@ -4464,7 +4458,13 @@ def aten_symeig( def aten_t(self: TensorType) -> TensorType: # t(Tensor(a) self) -> Tensor(a) - raise NotImplementedError() + # TODO(justinchuby): Make rank a function + rank = op.Shape(op.Shape(self)) + if rank == 0 or rank == 1: + result = self + else: + result = op.Transpose(self, perm=[1, 0]) + return result def aten_t_copy(self: TensorType) -> TensorType: @@ -4609,6 +4609,13 @@ def aten_trace_backward(grad: TensorType, sizes: INT64) -> TensorType: raise NotImplementedError() +def aten_transpose(self, dim0: int, dim1: int): + # transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a) + + # FIXME(justinchuby): onnxscript raises Unsupported expression type + return op.Transpose(self, [dim0, dim1]) + + def aten_triangular_solve( self: TensorType, A: TensorType, diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index 988249ecbe..fb298a9518 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -162,7 +162,7 @@ def wrapped(fn): # Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py OPINFO_FUNCTION_MAPPING: dict[str, Callable[..., Any]] = { "add": core_ops.aten_add, - # "clamp": core_ops.aten_clamp, # TODO(justinchuby): Enable + "clamp": core_ops.aten_clamp, "clamp_max": core_ops.aten_clamp_max_tensor, "clamp_min": core_ops.aten_clamp_min_tensor, "gt": core_ops.aten_gt, @@ -176,12 +176,15 @@ def wrapped(fn): "repeat": core_ops.aten_repeat, "round": core_ops.aten_round, "sub": core_ops.aten_sub, + "t": core_ops.aten_t, + # "transpose": core_ops.aten_transpose, # TODO(justinchuby): Enable when onnxscript errors are fixed } TESTED_OPS = frozenset(OPINFO_FUNCTION_MAPPING) EXPECTED_SKIPS_OR_FAILS = ( xfail("add", dtypes=BOOL_TYPES, reason="Add is not defined on bool tensors"), + skip("clamp", reason="Enable when onnxscript errors are fixed"), xfail("clamp_max", dtypes=BOOL_TYPES, reason="Min is not defined on bool tensors"), xfail("clamp_min", dtypes=BOOL_TYPES, reason="Max is not defined on bool tensors"), xfail("gt", dtypes=BOOL_TYPES, reason="Greater is not defined on bool tensors"), @@ -214,6 +217,7 @@ def wrapped(fn): "round", variant_name="decimals_neg_3", reason="The ATen op does not support decimals" ), xfail("sub", dtypes=BOOL_TYPES, reason="Sub is not defined on bool tensors"), + xfail("transpose", reason="Enable when onnxscript errors are fixed"), ) @@ -241,6 +245,11 @@ def wrapped(fn): OPS_DB = copy.deepcopy(common_methods_invocations.op_db) +ALL_OPS_IN_DB = frozenset(op_info.name for op_info in OPS_DB) +# Assert all ops in OPINFO_FUNCTION_MAPPING are in the OPS_DB +for op_name in OPINFO_FUNCTION_MAPPING: + assert op_name in ALL_OPS_IN_DB, f"{op_name} not in OPS_DB" + TORCH_TYPE_TO_ONNX = { torch.bool: onnx.TensorProto.BOOL, From e8f07c9148b7db2b0a583b70a6501e2d198e783c Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 13 Dec 2022 21:00:28 +0000 Subject: [PATCH 14/18] Update base for Update on "feat(atenlib): ops 2/n" [ghstack-poisoned] From 7c0e3055c4359ff22829824c1ff69f0937b84763 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 14 Dec 2022 00:03:20 +0000 Subject: [PATCH 15/18] Update base for Update on "feat(atenlib): ops 2/n" [ghstack-poisoned] From b7b03eea79254e96fac0208eb573b7779d507a68 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 14 Dec 2022 00:04:26 +0000 Subject: [PATCH 16/18] Update 
base for Update on "feat(atenlib): ops 2/n" [ghstack-poisoned] From 93ed77b1d26673434db3b815ee8a53c54ec27c81 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 14 Dec 2022 18:36:02 +0000 Subject: [PATCH 17/18] Update base for Update on "feat(atenlib): ops 2/n" [ghstack-poisoned] From 57676c8063f426cb85f38cb250f05c4c4fc6404e Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 15 Dec 2022 05:49:05 +0000 Subject: [PATCH 18/18] Update base for Update on "feat(atenlib): ops 2/n" [ghstack-poisoned]
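For reference, the clamp_min/clamp_max tensor variants enabled in the test mapping of patch 13 each reduce to a single ONNX op, so their eager behaviour is easy to sanity-check. A minimal sketch with illustrative inputs, assuming the functions compile as the test mapping implies:

    import numpy as np
    import onnxscript
    from onnxscript.function_libs.torch_aten.ops import core as core_ops

    clamp_min = onnxscript.script()(core_ops.aten_clamp_min_tensor)  # wraps op.Max(self, min_)
    out = clamp_min(np.array([-1.0, 0.5, 2.0], dtype=np.float32),
                    np.array(0.0, dtype=np.float32))
    print(out)  # expected: [0.  0.5 2. ]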