Implement aten::index | feat(torchlib) (#862) #883


Merged (15 commits) Jul 17, 2023
Conversation


@justinchuby justinchuby commented Jul 17, 2023

Stack from ghstack (oldest at bottom):


This change implements the logic for aten::index and adds tests for different nd index combinations and permutations.

## Understanding `aten::index`

For `arg0` with shape `[7, 3, 4, 5, 6]`, the indexing operation `arg0[0, :, 1:2, tensor([[4,5]])]` is translated to

```
+>  select: i64[3, 4, 5, 6] = torch.ops.aten.select.int(arg0, 0, 0);
+>  slice_1: i64[3, 4, 5, 6] = torch.ops.aten.slice.Tensor(select, 0, 0, 9223372036854775807);
+>  slice_2: i64[3, 1, 5, 6] = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 2);
+>  index: i64[3, 1, 1, 2, 6] = torch.ops.aten.index.Tensor(slice_2, [None, None, arg1]);
```

Here,

- `indices = [None, None, arg1]` is equivalent to `indices = [None, None, arg1, None]`
- The operation `arg0[0, :, 1:2, tensor([[4,5]])]` is equivalent to `arg0[0, :, 1:2, tensor([[4,5]]), :]`

`None` entries in `indices` are fillers for dimensions that are not indexed and therefore remain in the result.

## Gather op reference

- https://github.com/openxla/xla/blob/main/docs/operation_semantics.md?rgh-link-date=2023-07-13T01%3A09%3A16Z#gather
- https://www.pathpartnertech.com/gather-scatter-operation-in-deep-learning-framework/
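For orientation, the basic gather semantics discussed in this section can be sketched in numpy (illustrative only):

```python
import numpy as np

# Gather along axis 1: out[i, j] = data[i, indices[i, j]]
data = np.array([[10, 20, 30],
                 [40, 50, 60]])
indices = np.array([[2, 0],
                    [1, 1]])
out = np.take_along_axis(data, indices, axis=1)
assert out.tolist() == [[30, 10], [50, 50]]
```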
Co-authored-by: BowenBao [email protected]

@justinchuby justinchuby marked this pull request as ready for review July 17, 2023 17:47
justinchuby added a commit that referenced this pull request Jul 17, 2023
justinchuby added a commit that referenced this pull request Jul 17, 2023
@justinchuby justinchuby added module: torchlib Related to the torch/aten function lib in development change base before merge Remember to change the merge base to main when the PR is ready to merge labels Jul 17, 2023
@justinchuby justinchuby requested review from xiaowuhu and fatcat-z July 17, 2023 18:00
justinchuby added a commit that referenced this pull request Jul 17, 2023
@BowenBao (Contributor) left a comment:
🎉 LGTM. Do you plan to add bool mask indices support in this or another PR?

```
@@ -473,6 +471,8 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args,
                input.value = subarg
                sequence_input.append(input)
                ort_inputs[input_name] = subarg
            else:
                sequence_input.append(subarg)
```
Contributor: Let's put a comment here explaining why this is needed, in case we forget?

Collaborator (author): Done

Contributor: Should we consider an explicit `elif` for None? It makes what we are not expecting easier to catch.

Collaborator (author): Good point. I wonder if there are things we don't expect that can sneak in. Nested lists?

Contributor: No idea. Just in case.

Collaborator (author): SG. Created #892 for now.

@justinchuby (Collaborator, author):

> 🎉 LGTM. Do you plan to add bool mask indices support in this or another PR?

I plan to do it in another PR to keep this one manageable.

codecov bot commented Jul 17, 2023

Codecov Report

Merging #883 (169debc) into main (bbdbf1e) will decrease coverage by 0.09%.
The diff coverage is n/a.

❗ Current head 169debc differs from the pull request's most recent head 9d941c5. Consider uploading reports for the commit 9d941c5 to get more accurate results.

```
@@            Coverage Diff             @@
##             main     #883      +/-   ##
==========================================
- Coverage   76.59%   76.51%   -0.09%     
==========================================
  Files         112      112              
  Lines       13408    13408              
  Branches     1348     1348              
==========================================
- Hits        10270    10259      -11     
- Misses       2801     2812      +11     
  Partials      337      337              
```
see 1 file with indirect coverage changes

justinchuby added a commit that referenced this pull request Jul 17, 2023

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* #883
* #882
* __->__ #881

The torchscript ONNX graph generator creates numeric value names by
default (`0`, `1`). These are not legal ONNX tensor names, since ONNX
requires the names to be valid C variable names. This change updates the
names by prepending a prefix `_val_` or `_const_` to make them valid
ONNX names. It also improves readability by making the names less likely
to be confused with shape values.

I decided to use the `_` prefix to reduce the chance of name collision
with FX names.

After:

```
<
   ir_version: 8,
   opset_import: ["" : 18],
   producer_name: "pytorch",
   producer_version: "2.1.0"
>
torch_jit (float[5,5,5,5] input_0, int64[2] input_1_3) => (float[5,5,5,2] _val_10) {
   _val_2 = Transpose <perm = [0, 1, 2, 3]> (input_0)
   _val_3 = Max (input_1_3)
   _val_4 = Shape <start = 0> (_val_3)
   _val_5 = Expand (input_1_3, _val_4)
   _const_6 = Constant <value = int64 {-1}> ()
   _val_7 = Unsqueeze (_val_5, _const_6)
   _val_8 = Concat <axis = -1> (_val_7)
   _val_9 = GatherND <batch_dims = 0> (_val_2, _val_8)
   _val_10 = Transpose <perm = [0, 1, 2, 3]> (_val_9)
}
```

Before:

```
<
   ir_version: 8,
   opset_import: ["" : 18],
   producer_name: "pytorch",
   producer_version: "2.1.0"
>
torch_jit (float[5,5,5,5] input_0, int64[2] input_1_3) => (float[5,5,5,2] 10) {
   2 = Transpose <perm = [0, 1, 2, 3]> (input_0)
   3 = Max (input_1_3)
   4 = Shape <start = 0> (3)
   5 = Expand (input_1_3, 4)
   6 = Constant <value = int64 {-1}> ()
   7 = Unsqueeze (5, 6)
   8 = Concat <axis = -1> (7)
   9 = GatherND <batch_dims = 0> (2, 8)
   10 = Transpose <perm = [0, 1, 2, 3]> (9)
}
```
justinchuby added a commit that referenced this pull request Jul 17, 2023
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* #883
* __->__ #882
* #881

Remove the duplicate check_model message when displaying
@justinchuby justinchuby changed the base branch from gh/justinchuby/33/base to main July 17, 2023 18:11
justinchuby added a commit that referenced this pull request Jul 17, 2023
justinchuby added a commit that referenced this pull request Jul 17, 2023
justinchuby added a commit that referenced this pull request Jul 17, 2023
@thiagocrepaldi (Contributor) left a comment:
LGTM

```python
index_1d = common_methods_invocations.index_variable(2, s, device=device)
index_2d = common_methods_invocations.index_variable((s + 1, 2), s, device=device)
index_3d = common_methods_invocations.index_variable((s + 2, s + 1, 2), s, device=device)
test_args = [
```
Contributor: Would itertools.{product, permutations, combinations} help ensure all combinations are covered and make the code shorter?

Collaborator (author): Good point. Let me try that.

Collaborator (author): I tested with itertools but realized some combinations are invalid, so we cannot enumerate them all that way. For the sake of clarity I propose that we keep the current explicit tests. I also added more test cases and comments.

```
@@ -276,10 +277,7 @@ def convert_tensor_to_numpy(input: Any) -> Any:


def convert_kwargs_for_onnx(kwargs: dict[str, Any]) -> dict[str, Any]:
    """Converts kwargs to be compatible with ONNX Runtime.

    ONNX Runtime doesn't support torch.bool, so we convert them to torch.uint8.
```
Contributor: For my knowledge, does ORT support torch.bool now, or was this docstring already outdated?

Collaborator (author), Jul 17, 2023: I actually don't know whether ORT supports bool (it should?), but I think this message was a mistake by Copilot, because we don't actually have this conversion logic in the code. If we see issues with ORT I will make adjustments.

```python
index_1d = common_methods_invocations.index_variable(2, s, device=device)
index_2d = common_methods_invocations.index_variable((s + 1, 2), s, device=device)
index_3d = common_methods_invocations.index_variable((s + 2, s + 1, 2), s, device=device)
test_args = [
```
Contributor: Could itertools.product (or friends) with lengths 1 to 4 shorten this list and ensure no combination is left out?

Collaborator (author): As above. It turns out some combinations are invalid to torch, so it may be better to specify them explicitly.
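For reference, the enumeration the reviewer suggests could look like the sketch below. The choice values are placeholders, and as noted in the discussion, many of the generated tuples are not valid torch index arguments, which is why the explicit test list was kept:

```python
import itertools

# Placeholder stand-ins for None and 1-d / 2-d / 3-d index tensors
choices = [None, "index_1d", "index_2d", "index_3d"]

# All index tuples of length 1 to 4
combos = [c for r in range(1, 5) for c in itertools.product(choices, repeat=r)]
assert len(combos) == 4 + 16 + 64 + 256  # 340 candidates

# e.g. (None,) or (None, None, None, None) would not be valid torch
# indices, so blind enumeration over-generates.
```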

@justinchuby justinchuby mentioned this pull request Jul 17, 2023
justinchuby added a commit that referenced this pull request Jul 17, 2023
@justinchuby justinchuby merged commit 2e0aee7 into main Jul 17, 2023
@justinchuby justinchuby deleted the gh/justinchuby/33/head branch July 17, 2023 21:52