Commit 8826d63
committed
Update on "Implement aten::index | feat(torchlib) (#862)"
---
**This change implements the logic for `aten::index` and adds tests for different nd index combinations and permutations.**
## Understanding `aten::index`
For `arg0` with shape `[7, 3, 4, 5, 6]`
The indexing operation `arg0[0, :, 1:2, tensor([[4,5]])]` will be translated to
```
+> select: i64[3, 4, 5, 6] = torch.ops.aten.select.int(arg0, 0, 0);
+> slice_1: i64[3, 4, 5, 6] = torch.ops.aten.slice.Tensor(select, 0, 0, 9223372036854775807);
+> slice_2: i64[3, 1, 5, 6] = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 2);
+> index: i64[3, 1, 1, 2, 6] = torch.ops.aten.index.Tensor(slice_2, [None, None, arg1]);
```
Here,
- `indices = [None, None, arg1]` is equivalent to `indices = [None, None, arg1, None]`
- The operation `arg0[0, :, 1:2, tensor([[4,5]])]` is equivalent to `arg0[0, :, 1:2, tensor([[4,5]]), :]`
None in `indices` are like fillers for dimensions that cannot be removed in the process.
## Gather op reference
- https://github.com/openxla/xla/blob/main/docs/operation_semantics.md?rgh-link-date=2023-07-13T01%3A09%3A16Z#gather
- https://www.pathpartnertech.com/gather-scatter-operation-in-deep-learning-framework/
---------
Co-authored-by: BowenBao <bowbaomicrosoft.com>
[ghstack-poisoned]2 files changed
Lines changed: 3 additions & 4 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
660 | 660 | | |
661 | 661 | | |
662 | 662 | | |
663 | | - | |
| 663 | + | |
| 664 | + | |
664 | 665 | | |
665 | 666 | | |
666 | 667 | | |
667 | | - | |
668 | | - | |
669 | 668 | | |
670 | 669 | | |
671 | 670 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
607 | 607 | | |
608 | 608 | | |
609 | 609 | | |
610 | | - | |
| 610 | + | |
611 | 611 | | |
612 | 612 | | |
613 | 613 | | |
| |||
0 commit comments