Skip to content

Commit 7c26905

Browse files
committed
Update on "Implement aten::index | feat(torchlib) (#862)"
--- **This change implements the logic for `aten::index` and adds tests for different nd index combinations and permutations.** ## Understanding `aten::index` For `arg0` with shape `[7, 3, 4, 5, 6]` The indexing operation `arg0[0, :, 1:2, tensor([[4,5]])]` will be translated to ``` +> select: i64[3, 4, 5, 6] = torch.ops.aten.select.int(arg0, 0, 0); +> slice_1: i64[3, 4, 5, 6] = torch.ops.aten.slice.Tensor(select, 0, 0, 9223372036854775807); +> slice_2: i64[3, 1, 5, 6] = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 2); +> index: i64[3, 1, 1, 2, 6] = torch.ops.aten.index.Tensor(slice_2, [None, None, arg1]); ``` Here, - `indices = [None, None, arg1]` is equivalent to `indices = [None, None, arg1, None]` - The operation `arg0[0, :, 1:2, tensor([[4,5]])]` is equivalent to `arg0[0, :, 1:2, tensor([[4,5]]), :]` None in `indices` are like fillers for dimensions that cannot be removed in the process. ## Gather op reference - https://github.com/openxla/xla/blob/main/docs/operation_semantics.md?rgh-link-date=2023-07-13T01%3A09%3A16Z#gather - https://www.pathpartnertech.com/gather-scatter-operation-in-deep-learning-framework/ --------- Co-authored-by: BowenBao <bowbaomicrosoft.com> [ghstack-poisoned]
2 parents 7eadba1 + bda7e1c commit 7c26905

1 file changed

Lines changed: 2 additions & 0 deletions

File tree

onnxscript/tests/function_libs/torch_lib/ops_test_common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,8 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args,
472472
sequence_input.append(input)
473473
ort_inputs[input_name] = subarg
474474
else:
475+
# Include non-numpy inputs as-is
476+
# For example, it could be a None value that we want to keep
475477
sequence_input.append(subarg)
476478
onnxscript_args.append(sequence_input)
477479
else:

0 commit comments

Comments
 (0)