Skip to content

Commit 8826d63

Browse files
committed
Update on "Implement aten::index | feat(torchlib) (#862)"
--- **This change implements the logic for `aten::index` and adds tests for different nd index combinations and permutations.** ## Understanding `aten::index` For `arg0` with shape `[7, 3, 4, 5, 6]` The indexing operation `arg0[0, :, 1:2, tensor([[4,5]])]` will be translated to ``` +> select: i64[3, 4, 5, 6] = torch.ops.aten.select.int(arg0, 0, 0); +> slice_1: i64[3, 4, 5, 6] = torch.ops.aten.slice.Tensor(select, 0, 0, 9223372036854775807); +> slice_2: i64[3, 1, 5, 6] = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 2); +> index: i64[3, 1, 1, 2, 6] = torch.ops.aten.index.Tensor(slice_2, [None, None, arg1]); ``` Here, - `indices = [None, None, arg1]` is equivalent to `indices = [None, None, arg1, None]` - The operation `arg0[0, :, 1:2, tensor([[4,5]])]` is equivalent to `arg0[0, :, 1:2, tensor([[4,5]]), :]` None in `indices` are like fillers for dimensions that cannot be removed in the process. ## Gather op reference - https://github.com/openxla/xla/blob/main/docs/operation_semantics.md?rgh-link-date=2023-07-13T01%3A09%3A16Z#gather - https://www.pathpartnertech.com/gather-scatter-operation-in-deep-learning-framework/ --------- Co-authored-by: BowenBao <bowbaomicrosoft.com> [ghstack-poisoned]
2 parents 7c26905 + 64ed9f9 commit 8826d63

2 files changed

Lines changed: 3 additions & 4 deletions

File tree

onnxscript/tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -660,12 +660,11 @@ def sample_inputs_bernoulli_p_deterministic(op_info, device, dtype, requires_gra
660660
supports_out=False,
661661
),
662662
opinfo_core.OpInfo(
663-
"aten.index.Tensor",
663+
"ops.aten.index.Tensor",
664+
aten_name="index.Tensor",
664665
dtypes=common_dtype.all_types_and_complex_and(
665666
torch.bool, torch.float16, torch.bfloat16, torch.chalf
666667
),
667-
aten_name="index",
668-
op=torch.ops.aten.index.Tensor,
669668
sample_inputs_func=sample_inputs_index,
670669
),
671670
opinfo_core.OpInfo(

onnxscript/tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -607,7 +607,7 @@ def _where_input_wrangler(
607607
TorchLibOpInfo("gt", core_ops.aten_gt),
608608
# TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB
609609
# TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero), # no test case in OPS_DB
610-
TorchLibOpInfo("aten.index.Tensor", core_ops.aten_index, trace_only=True),
610+
TorchLibOpInfo("ops.aten.index.Tensor", core_ops.aten_index, trace_only=True),
611611
TorchLibOpInfo(
612612
"index_put_bool",
613613
core_ops.aten_index_put_bool,

0 commit comments

Comments
 (0)