Skip to content

Commit cfa360b

Browse files
committed
Update on "Implement aten::index | feat(torchlib) (#862)"
--- **This change implements the logic for `aten::index` and adds tests for different nd index combinations and permutations.** ## Understanding `aten::index` For `arg0` with shape `[7, 3, 4, 5, 6]` The indexing operation `arg0[0, :, 1:2, tensor([[4,5]])]` will be translated to ``` +> select: i64[3, 4, 5, 6] = torch.ops.aten.select.int(arg0, 0, 0); +> slice_1: i64[3, 4, 5, 6] = torch.ops.aten.slice.Tensor(select, 0, 0, 9223372036854775807); +> slice_2: i64[3, 1, 5, 6] = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 2); +> index: i64[3, 1, 1, 2, 6] = torch.ops.aten.index.Tensor(slice_2, [None, None, arg1]); ``` Here, - `indices = [None, None, arg1]` is equivalent to `indices = [None, None, arg1, None]` - The operation `arg0[0, :, 1:2, tensor([[4,5]])]` is equivalent to `arg0[0, :, 1:2, tensor([[4,5]]), :]` None in `indices` are like fillers for dimensions that cannot be removed in the process. ## Gather op reference - https://github.com/openxla/xla/blob/main/docs/operation_semantics.md?rgh-link-date=2023-07-13T01%3A09%3A16Z#gather - https://www.pathpartnertech.com/gather-scatter-operation-in-deep-learning-framework/ --------- Co-authored-by: BowenBao <bowbaomicrosoft.com> [ghstack-poisoned]
2 parents 8826d63 + 1c037f4 commit cfa360b

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

  • onnxscript/function_libs/torch_lib/ops

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3040,7 +3040,7 @@ def _shape_of_broadcast_tensors(*args: TensorType) -> INT64:
30403040
return op.Shape(broadcasted)
30413041

30423042

3043-
@torch_op("aten::index", trace_only=True)
3043+
@torch_op(("aten::index.Tensor", "aten::_unsafe_index.Tensor"), trace_only=True)
30443044
def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorType:
30453045
"""index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
30463046

0 commit comments

Comments
 (0)