Skip to content

Commit 68d4b9f

Browse files
authored
Add Op (Slice - complex) | feat torchlib (#2089)
Fix pytorch/pytorch#147896
1 parent 4c1cda2 commit 68d4b9f

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7698,6 +7698,21 @@ def aten_sinh(self: TFloat) -> TFloat:
76987698
return op.Sinh(self)
76997699

77007700

7701+
@torch_op(("aten::slice.Tensor"), trace_only=True, complex=True)
7702+
def aten_slice_complex(
7703+
self: TTensor,
7704+
dim: int = 0,
7705+
start: Optional[INT64] = None,
7706+
end: Optional[INT64] = None,
7707+
step: Optional[INT64] = None,
7708+
) -> TTensor:
7709+
"""slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)"""
7710+
if dim < 0:
7711+
# Account for the complex dimension in ONNX
7712+
dim = dim - 1
7713+
return aten_slice(self, dim, start, end, step)
7714+
7715+
77017716
@torch_op(("aten::slice.Tensor"), trace_only=True)
77027717
def aten_slice(
77037718
self: TTensor,

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2048,6 +2048,7 @@ def _where_input_wrangler(
20482048
),
20492049
TorchLibOpInfo("ops.aten.slice_scatter", core_ops.aten_slice_scatter),
20502050
TorchLibOpInfo("slice", core_ops.aten_slice),
2051+
TorchLibOpInfo("slice", core_ops.aten_slice_complex, complex=True),
20512052
TorchLibOpInfo(
20522053
"sum",
20532054
core_ops.aten_sum_dim_IntList,

0 commit comments

Comments
 (0)