Skip to content

Commit 7eadba1

Browse files
committed
Update on "Implement aten::index | feat(torchlib) (#862)"
--- **This change implements the logic for `aten::index` and adds tests for different nd index combinations and permutations.** ## Understanding `aten::index` For `arg0` with shape `[7, 3, 4, 5, 6]` The indexing operation `arg0[0, :, 1:2, tensor([[4,5]])]` will be translated to ``` +> select: i64[3, 4, 5, 6] = torch.ops.aten.select.int(arg0, 0, 0); +> slice_1: i64[3, 4, 5, 6] = torch.ops.aten.slice.Tensor(select, 0, 0, 9223372036854775807); +> slice_2: i64[3, 1, 5, 6] = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 2); +> index: i64[3, 1, 1, 2, 6] = torch.ops.aten.index.Tensor(slice_2, [None, None, arg1]); ``` Here, - `indices = [None, None, arg1]` is equivalent to `indices = [None, None, arg1, None]` - The operation `arg0[0, :, 1:2, tensor([[4,5]])]` is equivalent to `arg0[0, :, 1:2, tensor([[4,5]]), :]` None in `indices` are like fillers for dimensions that cannot be removed in the process. ## Gather op reference - https://github.com/openxla/xla/blob/main/docs/operation_semantics.md?rgh-link-date=2023-07-13T01%3A09%3A16Z#gather - https://www.pathpartnertech.com/gather-scatter-operation-in-deep-learning-framework/ --------- Co-authored-by: BowenBao <bowbaomicrosoft.com> [ghstack-poisoned]
1 parent 5706948 commit 7eadba1

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

onnxscript/tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -607,7 +607,7 @@ def _where_input_wrangler(
607607
TorchLibOpInfo("gt", core_ops.aten_gt),
608608
# TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB
609609
# TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero), # no test case in OPS_DB
610-
TorchLibOpInfo("aten.index.Tensor", core_ops.aten_index),
610+
TorchLibOpInfo("aten.index.Tensor", core_ops.aten_index, trace_only=True),
611611
TorchLibOpInfo(
612612
"index_put_bool",
613613
core_ops.aten_index_put_bool,

0 commit comments

Comments
 (0)