Skip to content

Commit edfa7c1

Browse files
add ops: slice (#304)
Co-authored-by: Justin Chu <[email protected]>
1 parent cf27ba8 commit edfa7c1

File tree

2 files changed

+24
-3
lines changed

2 files changed

+24
-3
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4409,16 +4409,30 @@ def aten_sinh(self: TFloat) -> TFloat:
44094409
return op.Sinh(self)
44104410

44114411

4412+
@torch_op("aten::slice")
44124413
def aten_slice(
4413-
self: TensorType,
4414+
self: TTensor,
44144415
dim: int = 0,
44154416
start: Optional[INT64] = None,
44164417
end: Optional[INT64] = None,
44174418
step: INT64 = 1,
4418-
) -> TensorType:
4419+
) -> TTensor:
44194420
# slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
44204421

4421-
raise NotImplementedError()
4422+
# TODO: using OptionalHasElement() to check start/end value
4423+
start = op.Cast(start, to=INT64.dtype)
4424+
start = op.Reshape(start, op.Constant(value_ints=[-1]))
4425+
4426+
end = op.Cast(end, to=INT64.dtype)
4427+
end = op.Reshape(end, op.Constant(value_ints=[-1]))
4428+
4429+
dim = op.Cast(dim, to=INT64.dtype)
4430+
dim = op.Reshape(dim, op.Constant(value_ints=[-1]))
4431+
4432+
step = op.Cast(step, to=INT64.dtype)
4433+
step = op.Reshape(step, op.Constant(value_ints=[-1]))
4434+
4435+
return op.Slice(self, start, end, dim, step)
44224436

44234437

44244438
def aten_slice_backward(

onnxscript/test/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,7 @@ def _topk_input_wrangler(
314314
"sign": core_ops.aten_sign,
315315
"sin": core_ops.aten_sin,
316316
"sinh": core_ops.aten_sinh,
317+
"slice": core_ops.aten_slice,
317318
"sqrt": core_ops.aten_sqrt,
318319
"sub": core_ops.aten_sub,
319320
"t": core_ops.aten_t,
@@ -438,6 +439,12 @@ def _topk_input_wrangler(
438439
matcher=lambda sample: "scale_factor" in sample.kwargs,
439440
reason="fixme: the scale_factor tests",
440441
),
442+
skip(
443+
"slice",
444+
# kwargs {dim, start, end, step} is empty, we cannot give the default value
445+
matcher=lambda sample: len(sample.kwargs) == 0,
446+
reason="start and end must be 1-D array, cannot be optional, due to ort 1.13 does not support yet",
447+
),
441448
)
442449

443450
duplicate_opinfo(

0 commit comments

Comments
 (0)