# Implement aten::index | feat(torchlib) (#862) #883
Conversation
**This change implements the logic for `aten::index` and adds tests for different nd index combinations and permutations.**

## Understanding `aten::index`

For `arg0` with shape `[7, 3, 4, 5, 6]`, the indexing operation `arg0[0, :, 1:2, tensor([[4,5]])]` will be translated to

```
+> select: i64[3, 4, 5, 6] = torch.ops.aten.select.int(arg0, 0, 0);
+> slice_1: i64[3, 4, 5, 6] = torch.ops.aten.slice.Tensor(select, 0, 0, 9223372036854775807);
+> slice_2: i64[3, 1, 5, 6] = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 2);
+> index: i64[3, 1, 1, 2, 6] = torch.ops.aten.index.Tensor(slice_2, [None, None, arg1]);
```

Here,

- `indices = [None, None, arg1]` is equivalent to `indices = [None, None, arg1, None]`
- The operation `arg0[0, :, 1:2, tensor([[4,5]])]` is equivalent to `arg0[0, :, 1:2, tensor([[4,5]]), :]`

`None` entries in `indices` act as fillers for dimensions that cannot be removed in the process.

## Gather op reference

- https://github.com/openxla/xla/blob/main/docs/operation_semantics.md?rgh-link-date=2023-07-13T01%3A09%3A16Z#gather
- https://www.pathpartnertech.com/gather-scatter-operation-in-deep-learning-framework/

Co-authored-by: BowenBao <bowbaomicrosoft.com>
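The trailing-`None` filler rule above can be sketched with NumPy, whose advanced indexing behaves the same way for this case. This is a hedged illustration, not the implementation: we start from the post-`select` view (shape `[3, 4, 5, 6]`), and the index values here are arbitrary valid examples rather than the `[[4, 5]]` from the trace.

```python
import numpy as np

# Stand-in for slice_1 above: the view after aten::select removed dim 0.
selected = np.arange(3 * 4 * 5 * 6).reshape(3, 4, 5, 6)
# Hypothetical index tensor (values chosen to be in bounds for a size-5 dim).
arg1 = np.array([[2, 4]])

a = selected[:, 1:2, arg1]      # like indices = [None, None, arg1]
b = selected[:, 1:2, arg1, :]   # like indices = [None, None, arg1, None]

# The omitted trailing dimension is implicitly kept whole, so both forms agree.
assert a.shape == b.shape == (3, 1, 1, 2, 6)
assert np.array_equal(a, b)
```

The `(3, 1, 1, 2, 6)` result matches the `index` node in the trace: the advanced index replaces the size-5 dimension with the broadcast index shape `(1, 2)` in place, while the `None` fillers leave the surrounding dimensions untouched.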
🎉 LGTM. Do you plan to add bool mask indices support in this or another PR?
```diff
@@ -473,6 +471,8 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args,
             input.value = subarg
             sequence_input.append(input)
             ort_inputs[input_name] = subarg
+        else:
+            sequence_input.append(subarg)
```
Let's put a comment here explaining why this is needed, in case we forget?
Done
Should we consider an explicit `elif` for `None`? It's easier to catch what we are not expecting.
Good point. I wonder if there are things we don't expect that can sneak in? Nested lists?
No idea. Just in case.
SG. Created #892 for now.
I plan to do it in another PR to keep this one manageable.
Codecov Report

```
@@            Coverage Diff             @@
##             main     #883      +/-   ##
==========================================
- Coverage   76.59%   76.51%   -0.09%
==========================================
  Files         112      112
  Lines       13408    13408
  Branches     1348     1348
==========================================
- Hits        10270    10259      -11
- Misses       2801     2812      +11
  Partials      337      337
```
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

* #883
* #882
* __->__ #881

The torchscript ONNX graph generator creates numeric value names by default (`0`, `1`). These are not legal ONNX tensor names, since ONNX requires the names to be valid C variable names. This change updates the names by prepending a prefix `_val_` or `_const_` to make them valid ONNX names. It also improves readability by making the names less likely to be confused with shape values. I decided to use the `_` prefix to reduce the chance of name collision with FX names.

After:

```
<
   ir_version: 8,
   opset_import: ["" : 18],
   producer_name: "pytorch",
   producer_version: "2.1.0"
>
torch_jit (float[5,5,5,5] input_0, int64[2] input_1_3) => (float[5,5,5,2] _val_10) {
   _val_2 = Transpose <perm = [0, 1, 2, 3]> (input_0)
   _val_3 = Max (input_1_3)
   _val_4 = Shape <start = 0> (_val_3)
   _val_5 = Expand (input_1_3, _val_4)
   _const_6 = Constant <value = int64 {-1}> ()
   _val_7 = Unsqueeze (_val_5, _const_6)
   _val_8 = Concat <axis = -1> (_val_7)
   _val_9 = GatherND <batch_dims = 0> (_val_2, _val_8)
   _val_10 = Transpose <perm = [0, 1, 2, 3]> (_val_9)
}
```

Before:

```
<
   ir_version: 8,
   opset_import: ["" : 18],
   producer_name: "pytorch",
   producer_version: "2.1.0"
>
torch_jit (float[5,5,5,5] input_0, int64[2] input_1_3) => (float[5,5,5,2] 10) {
   2 = Transpose <perm = [0, 1, 2, 3]> (input_0)
   3 = Max (input_1_3)
   4 = Shape <start = 0> (3)
   5 = Expand (input_1_3, 4)
   6 = Constant <value = int64 {-1}> ()
   7 = Unsqueeze (5, 6)
   8 = Concat <axis = -1> (7)
   9 = GatherND <batch_dims = 0> (2, 8)
   10 = Transpose <perm = [0, 1, 2, 3]> (9)
}
```
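The renaming rule described above can be sketched as follows. This is a hypothetical illustration of the prefixing scheme, not the PR's actual code; the function name and flag are invented for the example.

```python
def rename_value(name: str, is_constant: bool = False) -> str:
    """Hypothetical sketch: make a torchscript value name a legal ONNX name.

    TorchScript assigns purely numeric names like "0" or "10", which are not
    valid C identifiers as ONNX requires; prefix them with _const_ or _val_.
    """
    if name.isdigit():
        return ("_const_" if is_constant else "_val_") + name
    # Already a legal identifier (e.g. a user-facing input name): leave it.
    return name

assert rename_value("10") == "_val_10"
assert rename_value("6", is_constant=True) == "_const_6"
assert rename_value("input_0") == "input_0"
```

The `_` prefix also keeps these generated names visually distinct from FX-produced names, per the rationale above.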
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

* #883
* __->__ #882
* #881

Remove the duplicate check_model message when displaying.
LGTM
```python
index_1d = common_methods_invocations.index_variable(2, s, device=device)
index_2d = common_methods_invocations.index_variable((s + 1, 2), s, device=device)
index_3d = common_methods_invocations.index_variable((s + 2, s + 1, 2), s, device=device)
test_args = [
```
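For readers unfamiliar with the quoted test helper: `index_variable(shape, max_indices, ...)` produces a random integer index tensor of the given shape with values below `max_indices`. A hedged NumPy analogue (the helper's exact torch implementation may differ) makes the shapes concrete:

```python
import numpy as np

rng = np.random.default_rng(0)

def index_variable(shape, max_indices):
    # NumPy analogue (assumption) of torch.testing's helper: random integer
    # indices of the given shape, each in [0, max_indices).
    return rng.integers(0, max_indices, size=shape)

s = 5
index_1d = index_variable(2, s)                      # shape (2,)
index_2d = index_variable((s + 1, 2), s)             # shape (6, 2)
index_3d = index_variable((s + 2, s + 1, 2), s)      # shape (7, 6, 2)

assert index_2d.shape == (6, 2)
assert index_3d.shape == (7, 6, 2)
assert int(index_3d.max()) < s and int(index_3d.min()) >= 0
```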
Would `itertools.{product, permutations, combinations}` help in ensuring all combinations are covered, and make the code shorter?
Good point. Let me try that
So I tested with itertools, but realized some combinations are invalid, so we cannot enumerate them all using itertools. For the sake of clarity, I propose that we keep the current explicit tests. I also added more test cases and comments.
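The trade-off discussed here can be sketched briefly. A raw `itertools.product` over index kinds over-generates: for example, a combination with no tensor index at all is not a meaningful `aten::index` test case, so filtering (or explicit listing, as the PR settled on) is needed. The kind labels below are hypothetical stand-ins, not the test suite's actual values.

```python
import itertools

# Hypothetical stand-ins for the kinds of index a position can take.
kinds = ["none", "int", "slice", "tensor"]

combos = list(itertools.product(kinds, repeat=3))
assert len(combos) == 4 ** 3  # 64 raw combinations

# aten::index needs at least one real tensor index, so e.g.
# ("none", "none", "none") must be filtered out.
valid = [c for c in combos if "tensor" in c]
assert len(valid) == 64 - 3 ** 3  # 27 combinations have no tensor index
```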
```diff
@@ -276,10 +277,7 @@ def convert_tensor_to_numpy(input: Any) -> Any:


 def convert_kwargs_for_onnx(kwargs: dict[str, Any]) -> dict[str, Any]:
     """Converts kwargs to be compatible with ONNX Runtime.

     ONNX Runtime doesn't support torch.bool, so we convert them to torch.uint8.
```
For my knowledge, does ORT support torch.bool now? Or was this docstring already outdated?
I actually don't know if ORT supports bool (it should?), but I think this message was a mistake by copilot because we don't actually have this conversion logic as code. If we see issues with ORT I will make adjustments.
```python
index_1d = common_methods_invocations.index_variable(2, s, device=device)
index_2d = common_methods_invocations.index_variable((s + 1, 2), s, device=device)
index_3d = common_methods_invocations.index_variable((s + 2, s + 1, 2), s, device=device)
test_args = [
```
Could `itertools.product` (or friends) with lengths 1 to 4 shorten this list and ensure no combination is left out?
As above. It turns out some combinations are invalid in PyTorch, so it may be better to specify them explicitly.