Skip to content

Commit 9d941c5

Browse files
committed
Update on "Implement aten::index | feat(torchlib) (#862)"
--- **This change implements the logic for `aten::index` and adds tests for different nd index combinations and permutations.** ## Understanding `aten::index` For `arg0` with shape `[7, 3, 4, 5, 6]` The indexing operation `arg0[0, :, 1:2, tensor([[4,5]])]` will be translated to ``` +> select: i64[3, 4, 5, 6] = torch.ops.aten.select.int(arg0, 0, 0); +> slice_1: i64[3, 4, 5, 6] = torch.ops.aten.slice.Tensor(select, 0, 0, 9223372036854775807); +> slice_2: i64[3, 1, 5, 6] = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 2); +> index: i64[3, 1, 1, 2, 6] = torch.ops.aten.index.Tensor(slice_2, [None, None, arg1]); ``` Here, - `indices = [None, None, arg1]` is equivalent to `indices = [None, None, arg1, None]` - The operation `arg0[0, :, 1:2, tensor([[4,5]])]` is equivalent to `arg0[0, :, 1:2, tensor([[4,5]]), :]` None in `indices` are like fillers for dimensions that cannot be removed in the process. ## Gather op reference - https://github.com/openxla/xla/blob/main/docs/operation_semantics.md?rgh-link-date=2023-07-13T01%3A09%3A16Z#gather - https://www.pathpartnertech.com/gather-scatter-operation-in-deep-learning-framework/ --------- Co-authored-by: BowenBao <bowbaomicrosoft.com> [ghstack-poisoned]
2 parents cfa360b + 169debc commit 9d941c5

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

onnxscript/tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,7 @@ def sample_inputs_index(op_info, device, dtype, requires_grad, **kwargs):
460460
([None, None, None, index_1d],),
461461
([index_1d, None],),
462462
([index_1d, None, None],),
463+
# Extra index
463464
([None, index_1d, None, index_1d],),
464465
([index_1d, None, index_1d, None],),
465466
([None, index_1d, index_1d, None],),
@@ -468,6 +469,7 @@ def sample_inputs_index(op_info, device, dtype, requires_grad, **kwargs):
468469
([None, None, None, index_2d],),
469470
([index_2d, None],),
470471
([index_2d, None, None],),
472+
# Extra index
471473
([None, index_2d, None, index_2d],),
472474
([index_2d, None, index_2d, None],),
473475
([None, index_2d, index_2d, None],),
@@ -476,11 +478,15 @@ def sample_inputs_index(op_info, device, dtype, requires_grad, **kwargs):
476478
([None, None, None, index_3d],),
477479
([index_3d, None],),
478480
([index_3d, None, None],),
481+
# Extra index
479482
([None, index_3d, None, index_3d],),
480483
([index_3d, None, index_3d, None],),
481484
([None, index_3d, index_3d, None],),
482485
# Mixed indices
483486
([None, index_3d, index_1d, index_2d],),
487+
# All indices are not None
488+
([index_2d, index_3d, index_1d],),
489+
([index_2d, index_3d, index_1d, index_2d],),
484490
]
485491

486492
for args in test_args:

0 commit comments

Comments
 (0)