Commit 2e0aee7
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* __->__ #883
---
**This change implements the logic for `aten::index` and adds tests for
different nd index combinations and permutations.**
## Understanding `aten::index`
For `arg0` with shape `[7, 3, 4, 5, 6]`
The indexing operation `arg0[0, :, 1:2, tensor([[4,5]])]` will be
translated to
```
+> select: i64[3, 4, 5, 6] = torch.ops.aten.select.int(arg0, 0, 0);
+> slice_1: i64[3, 4, 5, 6] = torch.ops.aten.slice.Tensor(select, 0, 0, 9223372036854775807);
+> slice_2: i64[3, 1, 5, 6] = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 2);
+> index: i64[3, 1, 1, 2, 6] = torch.ops.aten.index.Tensor(slice_2, [None, None, arg1]);
```
Here,
- `indices = [None, None, arg1]` is equivalent to `indices = [None,
None, arg1, None]`
- The operation `arg0[0, :, 1:2, tensor([[4,5]])]` is equivalent to
`arg0[0, :, 1:2, tensor([[4,5]]), :]`
None in `indices` are like fillers for dimensions that cannot be removed
in the process.
## Gather op reference
-
https://github.com/openxla/xla/blob/main/docs/operation_semantics.md?rgh-link-date=2023-07-13T01%3A09%3A16Z#gather
-
https://www.pathpartnertech.com/gather-scatter-operation-in-deep-learning-framework/
---------
Co-authored-by: BowenBao <bowbao@microsoft.com>
1 parent bbdbf1e commit 2e0aee7
4 files changed
Lines changed: 181 additions & 9 deletions
File tree
- onnxscript
- function_libs/torch_lib/ops
- tests/function_libs/torch_lib
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
3020 | 3020 | | |
3021 | 3021 | | |
3022 | 3022 | | |
3023 | | - | |
3024 | | - | |
| 3023 | + | |
| 3024 | + | |
| 3025 | + | |
| 3026 | + | |
3025 | 3027 | | |
3026 | | - | |
| 3028 | + | |
| 3029 | + | |
| 3030 | + | |
| 3031 | + | |
| 3032 | + | |
| 3033 | + | |
| 3034 | + | |
| 3035 | + | |
| 3036 | + | |
| 3037 | + | |
| 3038 | + | |
| 3039 | + | |
| 3040 | + | |
| 3041 | + | |
| 3042 | + | |
| 3043 | + | |
| 3044 | + | |
| 3045 | + | |
| 3046 | + | |
| 3047 | + | |
| 3048 | + | |
| 3049 | + | |
| 3050 | + | |
| 3051 | + | |
| 3052 | + | |
| 3053 | + | |
| 3054 | + | |
| 3055 | + | |
| 3056 | + | |
| 3057 | + | |
| 3058 | + | |
| 3059 | + | |
| 3060 | + | |
| 3061 | + | |
| 3062 | + | |
| 3063 | + | |
| 3064 | + | |
| 3065 | + | |
| 3066 | + | |
| 3067 | + | |
| 3068 | + | |
| 3069 | + | |
| 3070 | + | |
| 3071 | + | |
| 3072 | + | |
| 3073 | + | |
| 3074 | + | |
| 3075 | + | |
| 3076 | + | |
| 3077 | + | |
| 3078 | + | |
| 3079 | + | |
| 3080 | + | |
| 3081 | + | |
| 3082 | + | |
| 3083 | + | |
| 3084 | + | |
| 3085 | + | |
| 3086 | + | |
| 3087 | + | |
| 3088 | + | |
| 3089 | + | |
| 3090 | + | |
| 3091 | + | |
| 3092 | + | |
| 3093 | + | |
| 3094 | + | |
| 3095 | + | |
| 3096 | + | |
| 3097 | + | |
| 3098 | + | |
| 3099 | + | |
| 3100 | + | |
| 3101 | + | |
| 3102 | + | |
| 3103 | + | |
| 3104 | + | |
| 3105 | + | |
| 3106 | + | |
| 3107 | + | |
| 3108 | + | |
| 3109 | + | |
| 3110 | + | |
| 3111 | + | |
| 3112 | + | |
| 3113 | + | |
| 3114 | + | |
| 3115 | + | |
| 3116 | + | |
| 3117 | + | |
| 3118 | + | |
| 3119 | + | |
| 3120 | + | |
| 3121 | + | |
| 3122 | + | |
| 3123 | + | |
| 3124 | + | |
| 3125 | + | |
| 3126 | + | |
| 3127 | + | |
| 3128 | + | |
| 3129 | + | |
| 3130 | + | |
| 3131 | + | |
| 3132 | + | |
| 3133 | + | |
| 3134 | + | |
| 3135 | + | |
| 3136 | + | |
| 3137 | + | |
| 3138 | + | |
3027 | 3139 | | |
3028 | 3140 | | |
3029 | 3141 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
444 | 444 | | |
445 | 445 | | |
446 | 446 | | |
| 447 | + | |
| 448 | + | |
| 449 | + | |
| 450 | + | |
| 451 | + | |
| 452 | + | |
| 453 | + | |
| 454 | + | |
| 455 | + | |
| 456 | + | |
| 457 | + | |
| 458 | + | |
| 459 | + | |
| 460 | + | |
| 461 | + | |
| 462 | + | |
| 463 | + | |
| 464 | + | |
| 465 | + | |
| 466 | + | |
| 467 | + | |
| 468 | + | |
| 469 | + | |
| 470 | + | |
| 471 | + | |
| 472 | + | |
| 473 | + | |
| 474 | + | |
| 475 | + | |
| 476 | + | |
| 477 | + | |
| 478 | + | |
| 479 | + | |
| 480 | + | |
| 481 | + | |
| 482 | + | |
| 483 | + | |
| 484 | + | |
| 485 | + | |
| 486 | + | |
| 487 | + | |
| 488 | + | |
| 489 | + | |
| 490 | + | |
| 491 | + | |
| 492 | + | |
| 493 | + | |
| 494 | + | |
| 495 | + | |
447 | 496 | | |
448 | 497 | | |
449 | 498 | | |
| |||
616 | 665 | | |
617 | 666 | | |
618 | 667 | | |
| 668 | + | |
| 669 | + | |
| 670 | + | |
| 671 | + | |
| 672 | + | |
| 673 | + | |
| 674 | + | |
| 675 | + | |
619 | 676 | | |
620 | 677 | | |
621 | 678 | | |
| |||
Lines changed: 8 additions & 6 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
261 | 261 | | |
262 | 262 | | |
263 | 263 | | |
264 | | - | |
| 264 | + | |
| 265 | + | |
265 | 266 | | |
266 | 267 | | |
267 | 268 | | |
| |||
276 | 277 | | |
277 | 278 | | |
278 | 279 | | |
279 | | - | |
280 | | - | |
281 | | - | |
282 | | - | |
| 280 | + | |
283 | 281 | | |
284 | 282 | | |
285 | 283 | | |
| |||
473 | 471 | | |
474 | 472 | | |
475 | 473 | | |
| 474 | + | |
| 475 | + | |
| 476 | + | |
| 477 | + | |
476 | 478 | | |
477 | 479 | | |
478 | 480 | | |
| |||
515 | 517 | | |
516 | 518 | | |
517 | 519 | | |
518 | | - | |
| 520 | + | |
519 | 521 | | |
520 | 522 | | |
521 | 523 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
607 | 607 | | |
608 | 608 | | |
609 | 609 | | |
| 610 | + | |
610 | 611 | | |
611 | 612 | | |
612 | 613 | | |
| |||
0 commit comments