Commit 9d941c5
committed
Update on "Implement aten::index | feat(torchlib) (#862)"
---
**This change implements the logic for `aten::index` and adds tests for different nd index combinations and permutations.**
## Understanding `aten::index`
For `arg0` with shape `[7, 3, 4, 5, 6]`
The indexing operation `arg0[0, :, 1:2, tensor([[4,5]])]` will be translated to
```
+> select: i64[3, 4, 5, 6] = torch.ops.aten.select.int(arg0, 0, 0);
+> slice_1: i64[3, 4, 5, 6] = torch.ops.aten.slice.Tensor(select, 0, 0, 9223372036854775807);
+> slice_2: i64[3, 1, 5, 6] = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 2);
+> index: i64[3, 1, 1, 2, 6] = torch.ops.aten.index.Tensor(slice_2, [None, None, arg1]);
```
Here,
- `indices = [None, None, arg1]` is equivalent to `indices = [None, None, arg1, None]`
- The operation `arg0[0, :, 1:2, tensor([[4,5]])]` is equivalent to `arg0[0, :, 1:2, tensor([[4,5]]), :]`
None in `indices` are like fillers for dimensions that cannot be removed in the process.
## Gather op reference
- https://github.com/openxla/xla/blob/main/docs/operation_semantics.md?rgh-link-date=2023-07-13T01%3A09%3A16Z#gather
- https://www.pathpartnertech.com/gather-scatter-operation-in-deep-learning-framework/
---------
Co-authored-by: BowenBao <bowbaomicrosoft.com>
[ghstack-poisoned]1 file changed
Lines changed: 6 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
460 | 460 | | |
461 | 461 | | |
462 | 462 | | |
| 463 | + | |
463 | 464 | | |
464 | 465 | | |
465 | 466 | | |
| |||
468 | 469 | | |
469 | 470 | | |
470 | 471 | | |
| 472 | + | |
471 | 473 | | |
472 | 474 | | |
473 | 475 | | |
| |||
476 | 478 | | |
477 | 479 | | |
478 | 480 | | |
| 481 | + | |
479 | 482 | | |
480 | 483 | | |
481 | 484 | | |
482 | 485 | | |
483 | 486 | | |
| 487 | + | |
| 488 | + | |
| 489 | + | |
484 | 490 | | |
485 | 491 | | |
486 | 492 | | |
| |||
0 commit comments