From 0cc23506a4f958a4e815bcded36437e5de929ecd Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 5 Jan 2023 00:43:32 +0000 Subject: [PATCH 1/7] feat(atenlib): index_select --- onnxscript/function_libs/torch_aten/ops/core.py | 16 ++++++++++++++-- .../torch_aten/ops_correctness_test.py | 1 + 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index dd059158a7..08352c2e65 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -14,6 +14,7 @@ from typing import Any, Optional, Sequence, Union from onnxscript import BOOL, DOUBLE, FLOAT, INT64 +import onnx.helper from onnxscript.function_libs.torch_aten.registration import torch_op from onnxscript.function_libs.torch_aten.typing import ( TFloat, @@ -2174,10 +2175,21 @@ def aten_index_reduce( raise NotImplementedError() -def aten_index_select(self: TensorType, dim: int, index: TensorType) -> TensorType: +@torch_op("aten::index_select") +def aten_index_select(self: TTensor, dim: int, index: TInt) -> TTensor: # index_select(Tensor self, int dim, Tensor index) -> Tensor - raise NotImplementedError() + if op.Size(op.Shape(index)) == 0: + # Index is a scalar. Reshape it to a size 1 tensor. + index = op.Expand( + index, + op.Constant( + value=onnx.helper.make_tensor("size_one", INT64.dtype, [1], [1]) + ), + ) + + index = op.Cast(index, to=INT64.dtype) + return op.Gather(self, index, axis=dim) def aten_index_select_backward( diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index cb4aeaadd0..af77aa6994 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -170,6 +170,7 @@ def wrapped(fn): "exp2": core_ops.aten_exp2, "fmod": core_ops.aten_fmod, "gt": core_ops.aten_gt, + "index_select": core_ops.aten_index_select, "isinf": core_ops.aten_isinf, "lt": core_ops.aten_lt, "matmul": core_ops.aten_matmul, From 2891c6c9e95748ab1299cc8220b39f8b7acf8e59 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 5 Jan 2023 00:46:38 +0000 Subject: [PATCH 2/7] format --- onnxscript/function_libs/torch_aten/ops/core.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 08352c2e65..9191b5e5af 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -13,8 +13,9 @@ from typing import Any, Optional, Sequence, Union -from onnxscript import BOOL, DOUBLE, FLOAT, INT64 import onnx.helper + +from onnxscript import BOOL, DOUBLE, FLOAT, INT64 from onnxscript.function_libs.torch_aten.registration import torch_op from onnxscript.function_libs.torch_aten.typing import ( TFloat, @@ -2183,9 +2184,7 @@ def aten_index_select(self: TTensor, dim: int, index: TInt) -> TTensor: # Index is a scalar. Reshape it to a size 1 tensor. 
index = op.Expand( index, - op.Constant( - value=onnx.helper.make_tensor("size_one", INT64.dtype, [1], [1]) - ), + op.Constant(value=onnx.helper.make_tensor("size_one", INT64.dtype, [1], [1])), ) index = op.Cast(index, to=INT64.dtype) From f908cf508a82bdab5d338a5374f45d923b807815 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 5 Jan 2023 15:40:28 -0800 Subject: [PATCH 3/7] Update core.py --- onnxscript/function_libs/torch_aten/ops/core.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 9191b5e5af..927c0f0470 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -13,8 +13,6 @@ from typing import Any, Optional, Sequence, Union -import onnx.helper - from onnxscript import BOOL, DOUBLE, FLOAT, INT64 from onnxscript.function_libs.torch_aten.registration import torch_op from onnxscript.function_libs.torch_aten.typing import ( @@ -2180,14 +2178,10 @@ def aten_index_reduce( def aten_index_select(self: TTensor, dim: int, index: TInt) -> TTensor: # index_select(Tensor self, int dim, Tensor index) -> Tensor - if op.Size(op.Shape(index)) == 0: - # Index is a scalar. Reshape it to a size 1 tensor. - index = op.Expand( - index, - op.Constant(value=onnx.helper.make_tensor("size_one", INT64.dtype, [1], [1])), - ) - + # Index can be a scalar. Reshape it to a rank 1 tensor. + index = op.Reshape(index, (-1,)) index = op.Cast(index, to=INT64.dtype) + return op.Gather(self, index, axis=dim) From dc233359b48beb87c4054b9e3bd7ed024dbc6bd8 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 7 Jan 2023 01:06:42 +0000 Subject: [PATCH 4/7] Disable index_select compilation --- onnxscript/function_libs/torch_aten/ops/core.py | 7 +++++-- .../test/function_libs/torch_aten/ops_correctness_test.py | 4 ++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 927c0f0470..747b371c9e 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -2174,12 +2174,15 @@ def aten_index_reduce( raise NotImplementedError() -@torch_op("aten::index_select") +# @torch_op("aten::index_select") def aten_index_select(self: TTensor, dim: int, index: TInt) -> TTensor: # index_select(Tensor self, int dim, Tensor index) -> Tensor + if op.Size(op.Shape(self)) == 0: + return self + # Index can be a scalar. Reshape it to a rank 1 tensor. 
- index = op.Reshape(index, (-1,)) + index = op.Reshape(index, op.Constant(value_floats=[-1])) index = op.Cast(index, to=INT64.dtype) return op.Gather(self, index, axis=dim) diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index af77aa6994..dab8fcd4d3 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -366,6 +366,10 @@ def test_output_match(self, device: str, dtype: torch.dtype, op): rtol = None atol = None + if not isinstance(function_output, np.ndarray): + # An onnxscript tensor + function_output = function_output.value + # Use torch testing to ensure dtypes and shapes match torch.testing.assert_close( torch.tensor(function_output), From df6f875e31471d875d5b78677a11323292990ef9 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 7 Jan 2023 01:58:02 +0000 Subject: [PATCH 5/7] Trace only ops --- .../function_libs/torch_aten/ops/core.py | 16 ++++++---- .../function_libs/torch_aten/registration.py | 30 ++++++++++++++----- .../torch_aten/ops_correctness_test.py | 3 +- 3 files changed, 34 insertions(+), 15 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 24019cd8a2..1446e2e212 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -2165,18 +2165,20 @@ def aten_index_reduce( raise NotImplementedError() -# @torch_op("aten::index_select") +@torch_op("aten::index_select", trace_only=True) # FIXME(#277): Script when attributes can come before inputs def aten_index_select(self: TTensor, dim: int, index: TInt) -> TTensor: # index_select(Tensor self, int dim, Tensor index) -> Tensor if op.Size(op.Shape(self)) == 0: - return self + result = self + else: + # Index can be a scalar. Reshape it to a rank 1 tensor. + index = op.Reshape(index, op.Constant(value_floats=[-1])) + index = op.Cast(index, to=INT64.dtype) - # Index can be a scalar. Reshape it to a rank 1 tensor. - index = op.Reshape(index, op.Constant(value_floats=[-1])) - index = op.Cast(index, to=INT64.dtype) + result = op.Gather(self, index, axis=dim) - return op.Gather(self, index, axis=dim) + return result def aten_index_select_backward( @@ -4678,10 +4680,12 @@ def aten_trace_backward(grad: TensorType, sizes: INT64) -> TensorType: raise NotImplementedError() +@torch_op("aten::transpose", trace_only=True) def aten_transpose(self, dim0: int, dim1: int): # transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a) # FIXME(justinchuby): onnxscript raises Unsupported expression type + # Script the function when this is fixed return op.Transpose(self, [dim0, dim1]) diff --git a/onnxscript/function_libs/torch_aten/registration.py b/onnxscript/function_libs/torch_aten/registration.py index 01c75b9f4e..86ef7b4835 100644 --- a/onnxscript/function_libs/torch_aten/registration.py +++ b/onnxscript/function_libs/torch_aten/registration.py @@ -48,19 +48,33 @@ def __repr__(self): def torch_op( - name, overload: bool = False, registry: Optional[Registry] = None -) -> Callable[[Callable[..., Any]], onnxscript.OnnxFunction]: - """Register a torch op.""" + name, + *, + overload: bool = False, + registry: Optional[Registry] = None, + trace_only: bool = False, +) -> Callable[[Callable[..., Any]], onnxscript.OnnxFunction | Callable[..., Any]]: + """Register a torch op. + + Args: + name: ATen name of the function. E.g. 
"aten::add". + overload: Whether the function is an overload (not default). + registry: Registry to register the function to. If None, the default registry is used. + trace_only: Whether the function should only be traced and not compiled. + """ if registry is None: registry = default_registry - def wrapper(func: Callable[..., Any]) -> onnxscript.OnnxFunction: + def wrapper(func: Callable[..., Any]) -> onnxscript.OnnxFunction | Callable[..., Any]: - # Compile the function - compiled = onnxscript.script()(func) + if trace_only: + processed_func = func + else: + # Compile the function + processed_func = onnxscript.script()(func) assert registry is not None - registry.register(compiled, name, overload=overload) - return compiled + registry.register(processed_func, name, overload=overload) + return processed_func return wrapper diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index b92283aa67..406328c1ec 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -182,6 +182,7 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]: OPINFO_FUNCTION_MAPPING: dict[ str, onnxscript.OnnxFunction + | Callable[..., Any] | tuple[ onnxscript.OnnxFunction | Callable[..., Any], Callable[[dict[str, Any]], dict[str, Any]], @@ -253,7 +254,7 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]: "t": core_ops.aten_t, "tan": core_ops.aten_tan, "tanh": core_ops.aten_tanh, - # "transpose": core_ops.aten_transpose, # TODO(justinchuby): Enable when onnxscript errors are fixed, + "transpose": core_ops.aten_transpose, "unsqueeze": core_ops.aten_unsqueeze, "where": core_ops.aten_where, "zeros": core_ops.aten_zeros, From 1a7fa36bda1475958a447837bcf9fabd6f44addb Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 7 Jan 2023 01:59:35 +0000 Subject: [PATCH 6/7] format --- onnxscript/function_libs/torch_aten/ops/core.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 1446e2e212..701b5e4e2c 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -2165,7 +2165,8 @@ def aten_index_reduce( raise NotImplementedError() -@torch_op("aten::index_select", trace_only=True) # FIXME(#277): Script when attributes can come before inputs +# FIXME(#277): Script when attributes can come before inputs +@torch_op("aten::index_select", trace_only=True) def aten_index_select(self: TTensor, dim: int, index: TInt) -> TTensor: # index_select(Tensor self, int dim, Tensor index) -> Tensor From e6b8a9f847622ca508d1cf805d8881f8beb8b691 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 7 Jan 2023 02:25:38 +0000 Subject: [PATCH 7/7] Fix transpose --- onnxscript/function_libs/torch_aten/ops/core.py | 15 ++++++++++++++- .../torch_aten/ops_correctness_test.py | 1 - 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 701b5e4e2c..90b1ae453e 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -4687,7 +4687,20 @@ def aten_transpose(self, dim0: int, dim1: int): # FIXME(justinchuby): onnxscript raises Unsupported expression type # Script the function when this is fixed - 
return op.Transpose(self, [dim0, dim1]) + self_rank = op.Size(op.Shape(self)) + + if self_rank == 0: + result = self + else: + # Python code, change when onnxscript supports this + self_rank_val = self_rank.value # type: ignore[attr-defined] + dims = list(range(self_rank_val)) + dims[dim0], dims[dim1] = dims[dim1], dims[dim0] + # Python code ends + + result = op.Transpose(self, perm=dims) + + return result def aten_triangular_solve( diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index 406328c1ec..d5d167ab83 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -278,7 +278,6 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]: xfail("round", variant_name="decimals_0", reason="The op does not support decimals"), xfail("round", variant_name="decimals_3", reason="The op does not support decimals"), xfail("round", variant_name="decimals_neg_3", reason="The op does not support decimals"), - xfail("transpose", reason="Enable when onnxscript errors are fixed"), )
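
The sketch below is not part of the patch series; it restates, in plain numpy, the two behaviors the final patches encode: why aten_index_select reshapes its index to rank 1 before Gather, and the identity-permutation swap that the trace-only aten_transpose builds for ONNX Transpose's perm attribute. The helper name transpose_perm is made up for illustration, and numpy stands in for the ONNX Gather/Transpose semantics.

import numpy as np

x = np.arange(12, dtype=np.float32).reshape(3, 4)

# A scalar index drops the gathered axis, while a rank-1 index keeps it;
# torch.index_select always keeps the axis, hence the Reshape to [-1].
assert np.take(x, 1, axis=1).shape == (3,)
assert np.take(x, np.array([1]), axis=1).shape == (3, 1)

# The permutation aten_transpose builds in Python while tracing: start from the
# identity permutation and swap dim0 with dim1.
def transpose_perm(rank, dim0, dim1):
    dims = list(range(rank))
    dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
    return dims

y = np.arange(24).reshape(2, 3, 4)
assert np.array_equal(np.transpose(y, transpose_perm(y.ndim, 0, 2)), np.swapaxes(y, 0, 2))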