diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index d1f236d292..a1c4a40245 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -4552,10 +4552,11 @@ def aten_sparse_mask(self: TensorType, mask: TensorType) -> TensorType: raise NotImplementedError() -def aten_split(self: TensorType, split_size: INT64, dim: int = 0) -> TensorType: +@torch_op("aten::split") +def aten_split(self: TTensor, split_size: INT64, dim: int = 0) -> TTensor: # split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[] - raise NotImplementedError() + return op.SplitToSequence(self, split_size, axis=dim) def aten_split_copy(self: TensorType, split_size: INT64, dim: int = 0) -> TensorType: diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index b42d516d02..4c05107f9d 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -232,6 +232,14 @@ def _softmax_input_wrangler( return args, kwargs +def _split_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + if len(args) >= 3: + kwargs["dim"] = args.pop(2) + return args, kwargs + + def _topk_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: @@ -353,6 +361,7 @@ def _topk_input_wrangler( "sinh": core_ops.aten_sinh, "slice": core_ops.aten_slice, "softmax": (special_ops.aten_special_softmax, _softmax_input_wrangler), + "split": (core_ops.aten_split, _split_input_wrangler), "sqrt": core_ops.aten_sqrt, "sub": core_ops.aten_sub, "t": core_ops.aten_t,