From 913f230237153a60b9457d79b30fae046b9e6d26 Mon Sep 17 00:00:00 2001
From: xiaowuhu
Date: Thu, 19 Jan 2023 11:55:46 +0800
Subject: [PATCH 1/2] Update core.py

---
 onnxscript/function_libs/torch_aten/ops/core.py | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index e7cd5d764f..d19b3a2301 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -4549,10 +4549,25 @@ def aten_sparse_mask(self: TensorType, mask: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
+@torch_op("aten::split")
 def aten_split(self: TensorType, split_size: INT64, dim: int = 0) -> TensorType:
     # split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[]
 
-    raise NotImplementedError()
+    return op.SplitToSequence(self, split_size, axis=dim)
+
+
+def test_aten_split():
+    import numpy as np
+    a = np.arange(10, dtype=np.float32).reshape(5,2)
+    print(a)
+    s = 2
+    dim = 0
+    b = aten_split(a, s, dim=dim)
+    print(b)
+    print("------------------")
+
+test_aten_split()
+
 
 def aten_split_copy(self: TensorType, split_size: INT64, dim: int = 0) -> TensorType:

From 41234a23ea66beb9c73991a497dd24b5d0ae9a64 Mon Sep 17 00:00:00 2001
From: xiaowuhu
Date: Thu, 19 Jan 2023 12:30:50 +0800
Subject: [PATCH 2/2] add split

---
 onnxscript/function_libs/torch_aten/ops/core.py | 16 +---------------
 .../torch_aten/ops_correctness_test.py          |  9 +++++++++
 2 files changed, 10 insertions(+), 15 deletions(-)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index d19b3a2301..ab60c8d994 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -4550,26 +4550,12 @@ def aten_sparse_mask(self: TensorType, mask: TensorType) -> TensorType:
 
 
 @torch_op("aten::split")
-def aten_split(self: TensorType, split_size: INT64, dim: int = 0) -> TensorType:
+def aten_split(self: TTensor, split_size: INT64, dim: int = 0) -> TTensor:
     # split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[]
 
     return op.SplitToSequence(self, split_size, axis=dim)
 
 
-def test_aten_split():
-    import numpy as np
-    a = np.arange(10, dtype=np.float32).reshape(5,2)
-    print(a)
-    s = 2
-    dim = 0
-    b = aten_split(a, s, dim=dim)
-    print(b)
-    print("------------------")
-
-test_aten_split()
-
-
-
 def aten_split_copy(self: TensorType, split_size: INT64, dim: int = 0) -> TensorType:
 
     # split_copy.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]
diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
index c623aab171..91ef7839ac 100644
--- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
+++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
@@ -224,6 +224,14 @@ def _softmax_input_wrangler(
     return args, kwargs
 
 
+def _split_input_wrangler(
+    args: list[Any], kwargs: dict[str, Any]
+) -> tuple[list[Any], dict[str, Any]]:
+    if len(args) >= 3:
+        kwargs["dim"] = args.pop(2)
+    return args, kwargs
+
+
 def _topk_input_wrangler(
     args: list[Any], kwargs: dict[str, Any]
 ) -> tuple[list[Any], dict[str, Any]]:
@@ -345,6 +353,7 @@ def _topk_input_wrangler(
     "sinh": core_ops.aten_sinh,
     "slice": core_ops.aten_slice,
     "softmax": (special_ops.aten_special_softmax, _softmax_input_wrangler),
+    "split": (core_ops.aten_split, _split_input_wrangler),
     "sqrt": core_ops.aten_sqrt,
     "sub": core_ops.aten_sub,
     "t": core_ops.aten_t,