Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions onnxscript/function_libs/torch_aten/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2174,10 +2174,15 @@ def aten_index_reduce(
raise NotImplementedError()


def aten_index_select(self: TensorType, dim: int, index: TensorType) -> TensorType:
@torch_op("aten::index_select")
def aten_index_select(self: TTensor, dim: int, index: TInt) -> TTensor:
# index_select(Tensor self, int dim, Tensor index) -> Tensor

raise NotImplementedError()
# Index can be a scalar. Reshape it to a rank 1 tensor.
index = op.Reshape(index, (-1,))
index = op.Cast(index, to=INT64.dtype)

return op.Gather(self, index, axis=dim)


def aten_index_select_backward(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def wrapped(fn):
"exp2": core_ops.aten_exp2,
"fmod": core_ops.aten_fmod,
"gt": core_ops.aten_gt,
"index_select": core_ops.aten_index_select,
"isinf": core_ops.aten_isinf,
"lt": core_ops.aten_lt,
"matmul": core_ops.aten_matmul,
Expand Down