Skip to content

Commit e07a37d

Browse files
authored
feat(atenlib): add ops(split) (#337)
1 parent 4dc2a4d commit e07a37d

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4552,10 +4552,11 @@ def aten_sparse_mask(self: TensorType, mask: TensorType) -> TensorType:
45524552
raise NotImplementedError()
45534553

45544554

4555-
def aten_split(self: TensorType, split_size: INT64, dim: int = 0) -> TensorType:
4555+
@torch_op("aten::split")
4556+
def aten_split(self: TTensor, split_size: INT64, dim: int = 0) -> TTensor:
45564557
# split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[]
45574558

4558-
raise NotImplementedError()
4559+
return op.SplitToSequence(self, split_size, axis=dim)
45594560

45604561

45614562
def aten_split_copy(self: TensorType, split_size: INT64, dim: int = 0) -> TensorType:

onnxscript/test/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,14 @@ def _softmax_input_wrangler(
232232
return args, kwargs
233233

234234

235+
def _split_input_wrangler(
236+
args: list[Any], kwargs: dict[str, Any]
237+
) -> tuple[list[Any], dict[str, Any]]:
238+
if len(args) >= 3:
239+
kwargs["dim"] = args.pop(2)
240+
return args, kwargs
241+
242+
235243
def _topk_input_wrangler(
236244
args: list[Any], kwargs: dict[str, Any]
237245
) -> tuple[list[Any], dict[str, Any]]:
@@ -353,6 +361,7 @@ def _topk_input_wrangler(
353361
"sinh": core_ops.aten_sinh,
354362
"slice": core_ops.aten_slice,
355363
"softmax": (special_ops.aten_special_softmax, _softmax_input_wrangler),
364+
"split": (core_ops.aten_split, _split_input_wrangler),
356365
"sqrt": core_ops.aten_sqrt,
357366
"sub": core_ops.aten_sub,
358367
"t": core_ops.aten_t,

0 commit comments

Comments
 (0)