From 576baac90fe16513a8896a58b88ff439cefd6eb5 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 11 Jul 2023 18:58:03 +0000
Subject: [PATCH 01/12] wip

---
 onnxscript/function_libs/torch_lib/ops/core.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 36a90d8ac6..f6bf8aa4d6 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -3026,10 +3026,10 @@ def aten_imag(self: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-def aten_index(self: TensorType, indices: Optional[Sequence[TensorType]]) -> TensorType:
+def aten_index(self: TTensor, indices: Sequence[INT64]) -> TTensor:
     """index.Tensor(Tensor self, Tensor?[] indices) -> Tensor"""
 
-    raise NotImplementedError()
+    return op.Gather(self, indices)
 
 
 def aten_index_add(

From fbfcf6697f1f81416efd5723586681f9f4a40dc4 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 12 Jul 2023 00:38:44 +0000
Subject: [PATCH 02/12] wip

---
 .../function_libs/torch_lib/ops/core.py       |  7 ++++-
 .../function_libs/torch_lib/extra_opinfo.py   | 26 ++++++++++++++++++-
 .../function_libs/torch_lib/ops_test_data.py  |  1 +
 3 files changed, 32 insertions(+), 2 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index f6bf8aa4d6..0cb3f00dde 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -3026,10 +3026,15 @@ def aten_imag(self: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
+@torch_op("aten::index.Tensor")
 def aten_index(self: TTensor, indices: Sequence[INT64]) -> TTensor:
     """index.Tensor(Tensor self, Tensor?[] indices) -> Tensor"""
 
-    return op.Gather(self, indices)
+    result = self
+    for i in range(op.SequenceLength(indices)):
+        result = op.Gather(result, op.SequenceAt(indices, i))
+
+    return result
 
 
 def aten_index_add(
diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
index 85b66bacd9..8334168203 100644
--- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
+++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
@@ -446,6 +446,22 @@ def sample_inputs_col2im(op_info, device, dtype, requires_grad, **kwargs):
         yield opinfo_core.SampleInput(tensor, args=(output_size, kernel_size), kwargs=kwargs)
 
 
+def sample_inputs_index(op_info, device, dtype, requires_grad, **kwargs):
+    del op_info  # Unused
+    del kwargs  # Unused
+    make_arg = functools.partial(
+        torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+    s = 5
+    test_args = [
+        ([common_methods_invocations.index_variable(2, s, device=device)],),
+        # ([torch.tensor()],)
+    ]
+
+    for args in test_args:
+        yield opinfo_core.SampleInput(make_arg((s, s, s)), args=args)
+
+
 def sample_inputs_stft(op_info, device, dtype, requires_grad, **kwargs):
     del op_info
     del kwargs
@@ -581,9 +597,17 @@ def sample_inputs_bernoulli_p_deterministic(op_info, device, dtype, requires_gra
         skips=(),
         supports_out=False,
     ),
+    opinfo_core.OpInfo(
+        "aten.index.Tensor",
+        dtypes=common_dtype.all_types_and_complex_and(
+            torch.bool, torch.float16, torch.bfloat16, torch.chalf
+        ),
+        aten_name="index",
+        op=torch.ops.aten.index.Tensor,
+        sample_inputs_func=sample_inputs_index,
+    ),
     opinfo_core.OpInfo(
         "layer_norm",
-        aliases=("layer_norm",),
         aten_name="layer_norm",
         dtypes=common_dtype.floating_and_complex_types_and(torch.int64, torch.bfloat16),
         sample_inputs_func=sample_inputs_layer_norm,
diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
index 3362ec2745..28dd34585e 100644
--- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
+++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -601,6 +601,7 @@ def _where_input_wrangler(
     TorchLibOpInfo("gt", core_ops.aten_gt),
     # TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size),  # no test case in OPS_DB
     # TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero),  # no test case in OPS_DB
+    TorchLibOpInfo("aten.index.Tensor", core_ops.aten_index),
    TorchLibOpInfo(
        "index_put_bool",
        core_ops.aten_index_put_bool,
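Note on the approach in PATCH 02 above: chaining one Gather per index tensor indexes the *result* of the previous Gather, which is not what `aten::index` computes when several index tensors are applied jointly; the later patches replace this with a GatherND-based lowering. A minimal sketch of the difference in plain PyTorch (the tensors here are illustrative, not from the patch):

```python
import torch

x = torch.arange(27).reshape(3, 3, 3)
i = torch.tensor([0, 1])
j = torch.tensor([1, 0])

joint = x[i, j]    # advanced indexing: pairs (0, 1) and (1, 0) -> shape (2, 3)
chained = x[i][j]  # two gathers along axis 0: equals x[i[j]] -> shape (2, 3, 3)
assert joint.shape == (2, 3)
assert chained.shape == (2, 3, 3)
```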
From 9560e7eec5428d2d68c29edda17b64f55f105697 Mon Sep 17 00:00:00 2001
From: BowenBao
Date: Tue, 11 Jul 2023 19:41:34 -0700
Subject: [PATCH 03/12] hack up index

---
 .../function_libs/torch_lib/ops/core.py       | 41 ++++++++++++++++---
 1 file changed, 35 insertions(+), 6 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 0cb3f00dde..e753a7e63c 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -33,6 +33,7 @@
 )
 from onnxscript.onnx_opset import opset18 as op
 from onnxscript.onnx_types import TensorType
+import numpy as np
 
 _INT64_MAX = 9223372036854775807
 _INT64_MIN = -9223372036854775808
@@ -3025,16 +3026,44 @@ def aten_imag(self: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
+def are_consecutive(lst):
+    if len(lst) == 0:
+        return True
+    else:
+        return sorted(lst) == list(range(min(lst), max(lst)+1))
+
+def _is_none_in_middle(indices):
+    not_none_indices = [i for i, idx in enumerate(indices) if idx is not None]
+    return not are_consecutive(not_none_indices)
 
-@torch_op("aten::index.Tensor")
-def aten_index(self: TTensor, indices: Sequence[INT64]) -> TTensor:
+@torch_op("aten::index", trace_only=True)
+def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorType:
     """index.Tensor(Tensor self, Tensor?[] indices) -> Tensor"""
 
-    result = self
-    for i in range(op.SequenceLength(indices)):
-        result = op.Gather(result, op.SequenceAt(indices, i))
+    ordered_indices = sorted(range(len(indices)), key=lambda i: (indices[i] is None, i))
+    ordered_indices += [i for i in range(len(ordered_indices), len(self.shape))]
 
-    return result
+    self = op.Transpose(self, perm=ordered_indices)
+
+    # Need broadcast concat lol.
+    not_none_indices = [idx for idx in indices if idx is not None]
+    broadcast_shape = np.broadcast_shapes(*[idx.shape for idx in not_none_indices])
+    final_index = op.Concat(*(op.Unsqueeze(op.Expand(idx, broadcast_shape), -1) for idx in not_none_indices), axis=-1)
+
+    self = op.GatherND(self, final_index, batch_dims=0)
+    if _is_none_in_middle(indices):
+        return self
+
+    # Need to transpose back.
+    adv_res_rank = len(broadcast_shape)
+    # [adv1, adv2, ..., adv_res_rank, x1, x2, ..., xk, ..., xn] -> [x1, x2, ..., xk, adv1, ..., adv_res_rank, ..., xn]
+    perm = (
+        list(range(adv_res_rank, adv_res_rank + ordered_indices[0])) +
+        list(range(0, adv_res_rank)) +
+        list(range(adv_res_rank + ordered_indices[0], len(ordered_indices) + adv_res_rank - len(not_none_indices)))
+    )
+
+    return op.Transpose(self, perm=perm)
 
 
 def aten_index_add(
From 64137855b7521c1e8e25483e59bde86077acff7c Mon Sep 17 00:00:00 2001
From: BowenBao
Date: Tue, 11 Jul 2023 19:44:05 -0700
Subject: [PATCH 04/12] linter

---
 .../function_libs/torch_lib/ops/core.py       | 26 ++++++++++++++-----
 1 file changed, 19 insertions(+), 7 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index e753a7e63c..b6f9460976 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -14,6 +14,8 @@
 import math
 from typing import Any, Optional, Sequence, Tuple, Union
 
+import numpy as np
+
 from onnxscript import BOOL, DOUBLE, FLOAT, INT8, INT16, INT32, INT64, graph
 from onnxscript.function_libs.torch_lib.registration import torch_op
 from onnxscript.function_libs.torch_lib.tensor_typing import (
@@ -33,7 +35,6 @@
 )
 from onnxscript.onnx_opset import opset18 as op
 from onnxscript.onnx_types import TensorType
-import numpy as np
 
 _INT64_MAX = 9223372036854775807
 _INT64_MIN = -9223372036854775808
@@ -3026,29 +3027,35 @@ def aten_imag(self: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
+
 def are_consecutive(lst):
     if len(lst) == 0:
         return True
     else:
-        return sorted(lst) == list(range(min(lst), max(lst)+1))
+        return sorted(lst) == list(range(min(lst), max(lst) + 1))
+
 
 def _is_none_in_middle(indices):
     not_none_indices = [i for i, idx in enumerate(indices) if idx is not None]
     return not are_consecutive(not_none_indices)
 
+
 @torch_op("aten::index", trace_only=True)
 def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorType:
     """index.Tensor(Tensor self, Tensor?[] indices) -> Tensor"""
 
     ordered_indices = sorted(range(len(indices)), key=lambda i: (indices[i] is None, i))
-    ordered_indices += [i for i in range(len(ordered_indices), len(self.shape))]
+    ordered_indices += list(range(len(ordered_indices), len(self.shape)))
 
     self = op.Transpose(self, perm=ordered_indices)
 
     # Need broadcast concat lol.
     not_none_indices = [idx for idx in indices if idx is not None]
     broadcast_shape = np.broadcast_shapes(*[idx.shape for idx in not_none_indices])
-    final_index = op.Concat(*(op.Unsqueeze(op.Expand(idx, broadcast_shape), -1) for idx in not_none_indices), axis=-1)
+    final_index = op.Concat(
+        *(op.Unsqueeze(op.Expand(idx, broadcast_shape), -1) for idx in not_none_indices),
+        axis=-1,
+    )
 
     self = op.GatherND(self, final_index, batch_dims=0)
     if _is_none_in_middle(indices):
@@ -3058,9 +3065,14 @@ def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorTy
     adv_res_rank = len(broadcast_shape)
     # [adv1, adv2, ..., adv_res_rank, x1, x2, ..., xk, ..., xn] -> [x1, x2, ..., xk, adv1, ..., adv_res_rank, ..., xn]
     perm = (
-        list(range(adv_res_rank, adv_res_rank + ordered_indices[0])) +
-        list(range(0, adv_res_rank)) +
-        list(range(adv_res_rank + ordered_indices[0], len(ordered_indices) + adv_res_rank - len(not_none_indices)))
+        list(range(adv_res_rank, adv_res_rank + ordered_indices[0]))
+        + list(range(0, adv_res_rank))
+        + list(
+            range(
+                adv_res_rank + ordered_indices[0],
+                len(ordered_indices) + adv_res_rank - len(not_none_indices),
+            )
+        )
     )
 
     return op.Transpose(self, perm=perm)
From d1acb440b75ab90bde647e6fe2612e8381a3bb1b Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 12 Jul 2023 23:19:17 +0000
Subject: [PATCH 05/12] Clean up

---
 .../function_libs/torch_lib/ops/core.py       | 48 ++++++++++---------
 1 file changed, 26 insertions(+), 22 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index b6f9460976..21bb4643cd 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -3028,52 +3028,56 @@ def aten_imag(self: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-def are_consecutive(lst):
-    if len(lst) == 0:
+def _are_consecutive(sorted_list: Sequence[int]) -> bool:
+    """Returns True if a sorted list contains consecutive numbers."""
+    if not sorted_list:
         return True
-    else:
-        return sorted(lst) == list(range(min(lst), max(lst) + 1))
+
+    return sorted_list == list(range(min(sorted_list), max(sorted_list) + 1))
 
 
-def _is_none_in_middle(indices):
+def _has_none_in_middle(indices) -> bool:
+    """Returns True if there is a None in the middle of the list."""
     not_none_indices = [i for i, idx in enumerate(indices) if idx is not None]
-    return not are_consecutive(not_none_indices)
+    return not _are_consecutive(not_none_indices)
 
 
 @torch_op("aten::index", trace_only=True)
 def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorType:
     """index.Tensor(Tensor self, Tensor?[] indices) -> Tensor"""
 
+    # Move the None indices to the end
     ordered_indices = sorted(range(len(indices)), key=lambda i: (indices[i] is None, i))
-    ordered_indices += list(range(len(ordered_indices), len(self.shape)))
-
+    # Fill the list with the remaining indices up to the rank of the tensor
+    ordered_indices = [*ordered_indices, *range(len(ordered_indices), len(self.shape))]
+    # Transpose the tensor to the axis in the order of [provided, None, not provided]
     self = op.Transpose(self, perm=ordered_indices)
 
-    # Need broadcast concat lol.
+    # Broadcast the indices to the same shape then concatenate
     not_none_indices = [idx for idx in indices if idx is not None]
-    broadcast_shape = np.broadcast_shapes(*[idx.shape for idx in not_none_indices])
+    broadcast_shape = list(np.broadcast_shapes(*[idx.shape for idx in not_none_indices]))
+    broadcast_shape = op.Constant(value_ints=broadcast_shape)
+    print(type(broadcast_shape), broadcast_shape)
     final_index = op.Concat(
         *(op.Unsqueeze(op.Expand(idx, broadcast_shape), -1) for idx in not_none_indices),
         axis=-1,
     )
 
     self = op.GatherND(self, final_index, batch_dims=0)
-    if _is_none_in_middle(indices):
+
+    if _has_none_in_middle(indices):
         return self
 
     # Need to transpose back.
-    adv_res_rank = len(broadcast_shape)
+    adv_res_rank = len(not_none_indices)
     # [adv1, adv2, ..., adv_res_rank, x1, x2, ..., xk, ..., xn] -> [x1, x2, ..., xk, adv1, ..., adv_res_rank, ..., xn]
-    perm = (
-        list(range(adv_res_rank, adv_res_rank + ordered_indices[0]))
-        + list(range(0, adv_res_rank))
-        + list(
-            range(
-                adv_res_rank + ordered_indices[0],
-                len(ordered_indices) + adv_res_rank - len(not_none_indices),
-            )
-        )
-    )
+    perm = [
+        *range(adv_res_rank, adv_res_rank + ordered_indices[0]),
+        *range(0, adv_res_rank),
+        *range(
+            adv_res_rank + ordered_indices[0],
+            len(ordered_indices) + adv_res_rank - len(not_none_indices),
+        ),
+    ]
 
     return op.Transpose(self, perm=perm)
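Note on the lowering in PATCH 05 above: the numpy analogue below sketches the same Transpose -> GatherND -> Transpose scheme, assuming a single index tensor applied to axis 1 and no None in the middle; `np.moveaxis` and integer-array indexing stand in for the ONNX Transpose and GatherND ops (a sketch, not part of the patch):

```python
import numpy as np

x = np.arange(2 * 3 * 4).reshape(2, 3, 4)
idx = np.array([[0, 1], [1, 0]])  # 2-D index on axis 1; axes 0 and 2 stay whole

expected = x[:, idx, :]  # advanced indexing -> shape (2, 2, 2, 4)

moved = np.moveaxis(x, 1, 0)                    # indexed axis first (the Transpose)
gathered = moved[idx]                           # GatherND with a single index column
result = np.moveaxis(gathered, (0, 1), (1, 2))  # move broadcast dims back into place
np.testing.assert_array_equal(result, expected)
```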
From 964296f32d4f6e12f2af7a9cdbefd401b91f031d Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 12 Jul 2023 23:56:21 +0000
Subject: [PATCH 06/12] snapshot

---
 .../function_libs/torch_lib/ops/core.py       | 57 +++++++++++++++----
 1 file changed, 45 insertions(+), 12 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 21bb4643cd..e6fd6eb225 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -3046,12 +3046,23 @@ def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorTy
     """index.Tensor(Tensor self, Tensor?[] indices) -> Tensor"""
 
-    # Move the None indices to the end
-    ordered_indices = sorted(range(len(indices)), key=lambda i: (indices[i] is None, i))
+    # NOTE: Understanding aten::index
+    # The indexing operation x[0, :, 1:2, tensor([[4,5]])] will be translated to
+    # A bunch of Slice operations, followed by aten::index with
+    # self = rank ? tensor
+    # indices = [None, None, None, tensor([[4,5]])]
+    # TODO(justinchuby): Clarify what happens with 0
+
+    # reordered_positions is the permutation of the index positions where
+    # positions with None are moved to the end
+    # For example, if indices = [None, 1, None, 2], then reordered_positions = [1, 3, 0, 2]
+    reordered_positions = sorted(range(len(indices)), key=lambda i: (indices[i] is None, i))
     # Fill the list with the remaining indices up to the rank of the tensor
-    ordered_indices = [*ordered_indices, *range(len(ordered_indices), len(self.shape))]
-    # Transpose the tensor to the axis in the order of [provided, None, not provided]
-    self = op.Transpose(self, perm=ordered_indices)
+    # For example, if indices = [None, 1, None, 2], and the rank of self is 6,
+    # then reordered_positions = [1, 3, 0, 2, 4, 5]
+    reordered_positions = [*reordered_positions, *range(len(reordered_positions), len(self.shape))]
+    # Transpose self according to the reordered positions
+    self = op.Transpose(self, perm=reordered_positions)
 
     # Broadcast the indices to the same shape then concatenate
     not_none_indices = [idx for idx in indices if idx is not None]
     broadcast_shape = list(np.broadcast_shapes(*[idx.shape for idx in not_none_indices]))
     broadcast_shape = op.Constant(value_ints=broadcast_shape)
     print(type(broadcast_shape), broadcast_shape)
     final_index = op.Concat(
         *(op.Unsqueeze(op.Expand(idx, broadcast_shape), -1) for idx in not_none_indices),
         axis=-1,
     )
 
     self = op.GatherND(self, final_index, batch_dims=0)
 
     if _has_none_in_middle(indices):
+        # If there is None in the middle, Advanced Indexing cannot decide where to put
+        # the new dimensions. So it places them in the front, like GatherND does.
         return self
 
-    # Need to transpose back.
-    adv_res_rank = len(not_none_indices)
-    # [adv1, adv2, ..., adv_res_rank, x1, x2, ..., xk, ..., xn] -> [x1, x2, ..., xk, adv1, ..., adv_res_rank, ..., xn]
+    # When the indices are consecutive, Advanced Indexing will place the new dimensions
+    # (aka. the broadcasted shape) in the middle, replacing the original [x1, ..., xk] axes.
+    #
+    # Input index axes (three parts):
+    # [
+    #     x_None_front_1, ... x_None_front_m,
+    #     x1, ..., xk,
+    #     x_None_back_1, ..., x_None_back_m
+    # ]
+    # GatherND result axes:
+    # [
+    #     *broadcasted_shape(x1, x2, ..., xk),
+    #     x_None_front_1, ... x_None_front_m,
+    #     x_None_back_1, ..., x_None_back_m
+    # ]
+    # (Transpose here)
+    # Advanced indexing result axes: [x_None_front_1, ... x_None_front_m, *broadcasted_shape(x1, x2, ..., xk), x_None_back_1, ..., x_None_back_m]
+    #
+    # Need to transpose the result of GatherND to match this axes ordering.
+    advanced_indexing_rank = len(not_none_indices)
+    first_not_none_index = reordered_positions[0]  # x_None_front_m + 1
+    x_none_back_1_position = advanced_indexing_rank + first_not_none_index
     perm = [
-        *range(adv_res_rank, adv_res_rank + ordered_indices[0]),
-        *range(0, adv_res_rank),
+        *range(advanced_indexing_rank, x_none_back_1_position),
+        *range(0, advanced_indexing_rank),
         *range(
-            adv_res_rank + ordered_indices[0],
-            len(ordered_indices) + adv_res_rank - len(not_none_indices),
+            x_none_back_1_position,
+            len(reordered_positions) + advanced_indexing_rank - len(not_none_indices),
         ),
     ]
 
     return op.Transpose(self, perm=perm)
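Note on the `_has_none_in_middle` branch in PATCH 06 above: it mirrors the advanced-indexing rule that the broadcast dimensions move to the front when index tensors are separated by a slice. A small numpy illustration (not part of the patch):

```python
import numpy as np

x = np.arange(2 * 3 * 4).reshape(2, 3, 4)
i = np.array([0, 1])

# Consecutive index axes: the broadcast dims replace the indexed axis in place.
assert x[:, i, :].shape == (2, 2, 4)

# Index axes separated by a full slice ("None in the middle" here):
# the broadcast dim moves to the front, exactly like GatherND's output.
assert x[i, :, i].shape == (2, 3)
```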
From 6fd9c964e9456945e3e0bbba5af7dbcc442316c3 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Thu, 13 Jul 2023 00:09:43 +0000
Subject: [PATCH 07/12] notes

---
 .../function_libs/torch_lib/ops/core.py       | 34 +++++++++++++------
 1 file changed, 23 insertions(+), 11 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index e6fd6eb225..efb7604d3d 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -3044,23 +3044,35 @@ def _has_none_in_middle(indices) -> bool:
 
 @torch_op("aten::index", trace_only=True)
 def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorType:
-    """index.Tensor(Tensor self, Tensor?[] indices) -> Tensor"""
-
-    # NOTE: Understanding aten::index
-    # The indexing operation x[0, :, 1:2, tensor([[4,5]])] will be translated to
-    # A bunch of Slice operations, followed by aten::index with
-    # self = rank ? tensor
-    # indices = [None, None, None, tensor([[4,5]])]
-    # TODO(justinchuby): Clarify what happens with 0
+    """index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
+
+    NOTE: Understanding `aten::index`
+    For `arg0` with shape `[7, 3, 4, 5, 6]`
+    The indexing operation `arg0[0, :, 1:2, tensor([[4,5]])]` will be translated to
+
+    ```
+    +> select: i64[3, 4, 5, 6] = torch.ops.aten.select.int(arg0, 0, 0);
+    +> slice_1: i64[3, 4, 5, 6] = torch.ops.aten.slice.Tensor(select, 0, 0, 9223372036854775807);
+    +> slice_2: i64[3, 1, 5, 6] = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 2);
+    +> index: i64[3, 1, 1, 2, 6] = torch.ops.aten.index.Tensor(slice_2, [None, None, arg1]);
+    ```
+
+    Here,
+    - `indices = [None, None, arg1]` is equivalent to `[None, None, arg1, None]`
+    - The operation `arg0[0, :, 1:2, tensor([[4,5]])]` is equivalent to `arg0[0, :, 1:2, tensor([[4,5]]), :]`
+    """
 
     # reordered_positions is the permutation of the index positions where
-    # positions with None are moved to the end
+    # positions with None are moved to the end of the list
     # For example, if indices = [None, 1, None, 2], then reordered_positions = [1, 3, 0, 2]
     reordered_positions = sorted(range(len(indices)), key=lambda i: (indices[i] is None, i))
-    # Fill the list with the remaining indices up to the rank of the tensor
+    # Fill the list with the remaining indices up to the rank of the tensor self.
     # For example, if indices = [None, 1, None, 2], and the rank of self is 6,
     # then reordered_positions = [1, 3, 0, 2, 4, 5]
-    reordered_positions = [*reordered_positions, *range(len(reordered_positions), len(self.shape))]
+    reordered_positions = [
+        *reordered_positions,
+        *range(len(reordered_positions), len(self.shape)),
+    ]
     # Transpose self according to the reordered positions
     self = op.Transpose(self, perm=reordered_positions)
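Note on the docstring added in PATCH 07 above: the described decomposition can be checked in eager PyTorch. A sketch under the assumption that `torch.ops.aten.index.Tensor` is callable with a list containing None entries, as in recent PyTorch builds:

```python
import torch

arg0 = torch.arange(7 * 3 * 4 * 5 * 6).reshape(7, 3, 4, 5, 6)
arg1 = torch.tensor([[4, 5]])

direct = arg0[0, :, 1:2, arg1]
sliced = arg0[0][:, 1:2]  # select + slice, as in the traced graph
decomposed = torch.ops.aten.index.Tensor(sliced, [None, None, arg1])

assert direct.shape == (3, 1, 1, 2, 6)  # matches the docstring's i64[3, 1, 1, 2, 6]
assert torch.equal(direct, decomposed)
```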
From 0a327d10769142c7660fc5f35d2452c4908258ab Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Thu, 13 Jul 2023 00:24:46 +0000
Subject: [PATCH 08/12] Clean up

---
 .../function_libs/torch_lib/ops/core.py       | 45 ++++++++++++-------
 1 file changed, 29 insertions(+), 16 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index efb7604d3d..982c9851dc 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -3051,17 +3051,21 @@ def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorTy
     The indexing operation `arg0[0, :, 1:2, tensor([[4,5]])]` will be translated to
 
     ```
-    +> select: i64[3, 4, 5, 6] = torch.ops.aten.select.int(arg0, 0, 0);
-    +> slice_1: i64[3, 4, 5, 6] = torch.ops.aten.slice.Tensor(select, 0, 0, 9223372036854775807);
-    +> slice_2: i64[3, 1, 5, 6] = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 2);
-    +> index: i64[3, 1, 1, 2, 6] = torch.ops.aten.index.Tensor(slice_2, [None, None, arg1]);
+        +> select: i64[3, 4, 5, 6] = torch.ops.aten.select.int(arg0, 0, 0);
+        +> slice_1: i64[3, 4, 5, 6] = torch.ops.aten.slice.Tensor(select, 0, 0, 9223372036854775807);
+        +> slice_2: i64[3, 1, 5, 6] = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 2);
+        +> index: i64[3, 1, 1, 2, 6] = torch.ops.aten.index.Tensor(slice_2, [None, None, arg1]);
     ```
 
     Here,
-    - `indices = [None, None, arg1]` is equivalent to `[None, None, arg1, None]`
+    - `indices = [None, None, arg1]` is equivalent to `indices = [None, None, arg1, None]`
     - The operation `arg0[0, :, 1:2, tensor([[4,5]])]` is equivalent to `arg0[0, :, 1:2, tensor([[4,5]]), :]`
+
+    `None` in `indices` acts as a filler for dimensions that cannot be removed in the process.
     """
 
+    self_rank = len(self.shape)
+
     # reordered_positions is the permutation of the index positions where
     # positions with None are moved to the end of the list
     # For example, if indices = [None, 1, None, 2], then reordered_positions = [1, 3, 0, 2]
     reordered_positions = sorted(range(len(indices)), key=lambda i: (indices[i] is None, i))
     # Fill the list with the remaining indices up to the rank of the tensor self.
     # For example, if indices = [None, 1, None, 2], and the rank of self is 6,
     # then reordered_positions = [1, 3, 0, 2, 4, 5]
     reordered_positions = [
         *reordered_positions,
-        *range(len(reordered_positions), len(self.shape)),
+        *range(len(reordered_positions), self_rank),
     ]
     # Transpose self according to the reordered positions
     self = op.Transpose(self, perm=reordered_positions)
 
     # Broadcast the indices to the same shape then concatenate
     not_none_indices = [idx for idx in indices if idx is not None]
     broadcast_shape = list(np.broadcast_shapes(*[idx.shape for idx in not_none_indices]))
+    advanced_indexing_rank = len(broadcast_shape)
     broadcast_shape = op.Constant(value_ints=broadcast_shape)
-    print(type(broadcast_shape), broadcast_shape)
     final_index = op.Concat(
         *(op.Unsqueeze(op.Expand(idx, broadcast_shape), -1) for idx in not_none_indices),
         axis=-1,
     )
 
     self = op.GatherND(self, final_index, batch_dims=0)
 
     if _has_none_in_middle(indices):
         # If there is None in the middle, Advanced Indexing cannot decide where to put
         # the new dimensions. So it places them in the front, like GatherND does.
         return self
 
     # When the indices are consecutive, Advanced Indexing will place the new dimensions
     # (aka. the broadcasted shape) in the middle, replacing the original [x1, ..., xk] axes.
     #
     # Input index axes (three parts):
     # [
     #     x_None_front_1, ... x_None_front_m,
     #     x1, ..., xk,
     #     x_None_back_1, ..., x_None_back_m
     # ]
     # GatherND result axes:
     # [
     #     *broadcasted_shape(x1, x2, ..., xk),
     #     x_None_front_1, ... x_None_front_m,
     #     x_None_back_1, ..., x_None_back_m
     # ]
     # (Transpose here)
-    # Advanced indexing result axes: [x_None_front_1, ... x_None_front_m, *broadcasted_shape(x1, x2, ..., xk), x_None_back_1, ..., x_None_back_m]
+    # Advanced indexing result axes:
+    # [
+    #     x_None_front_1, ... x_None_front_m,
+    #     *broadcasted_shape(x1, x2, ..., xk),
+    #     x_None_back_1, ..., x_None_back_m
+    # ]
     #
     # Need to transpose the result of GatherND to match this axes ordering.
-    advanced_indexing_rank = len(not_none_indices)
-    first_not_none_index = reordered_positions[0]  # x_None_front_m + 1
-    x_none_back_1_position = advanced_indexing_rank + first_not_none_index
+    first_not_none_position = reordered_positions[0]  # x_None_front_m + 1
+    starting_position_of_none_in_back = (
+        advanced_indexing_rank + first_not_none_position
+    )  # x_None_back_1
+    result_rank = self_rank - len(not_none_indices) + advanced_indexing_rank
     perm = [
-        *range(advanced_indexing_rank, x_none_back_1_position),
-        *range(0, advanced_indexing_rank),
-        *range(
-            x_none_back_1_position,
-            len(reordered_positions) + advanced_indexing_rank - len(not_none_indices),
-        ),
+        *range(
+            advanced_indexing_rank, starting_position_of_none_in_back
+        ),  # None_front_1...x_None_back_1
+        *range(0, advanced_indexing_rank),  # 0...len(broadcasted_shape)
+        *range(
+            starting_position_of_none_in_back,
+            result_rank,
+        ),  # None_back_1...None_back_m
     ]
 
     return op.Transpose(self, perm=perm)
From 38e8814d1f6d765fd0e905240bac24c94d8f136d Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Thu, 13 Jul 2023 00:45:28 +0000
Subject: [PATCH 09/12] rankrank

---
 .../function_libs/torch_lib/ops/core.py       | 24 ++++++++++++++++---
 .../function_libs/torch_lib/extra_opinfo.py   |  2 +-
 2 files changed, 22 insertions(+), 4 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 982c9851dc..ffce83a63d 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -3042,6 +3042,12 @@ def _has_none_in_middle(indices) -> bool:
     return not _are_consecutive(not_none_indices)
 
 
+def _shape_of_broadcast_tensors(*args: TensorType) -> INT64:
+    """Returns the broadcasted shape of the given tensors."""
+    broadcasted = op.Max(*args)
+    return op.Shape(broadcasted)
+
+
 @torch_op("aten::index", trace_only=True)
 def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorType:
     """index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
 
     self_rank = len(self.shape)
+    advanced_indexing_rank = max(len(index.shape) for index in indices if index is not None)
 
     # reordered_positions is the permutation of the index positions where
 
     # Broadcast the indices to the same shape then concatenate
     not_none_indices = [idx for idx in indices if idx is not None]
-    broadcast_shape = list(np.broadcast_shapes(*[idx.shape for idx in not_none_indices]))
-    advanced_indexing_rank = len(broadcast_shape)
-    broadcast_shape = op.Constant(value_ints=broadcast_shape)
+
+
+    # --- DEBUG
+    broadcast_shape_comp = list(np.broadcast_shapes(*[idx.shape for idx in not_none_indices]))
+    advanced_indexing_rank_comp = len(broadcast_shape_comp)
+
+    print("broadcast_shape_comp", broadcast_shape_comp)
+    print("advanced_indexing_rank", advanced_indexing_rank)
+    print("advanced_indexing_rank_comp", advanced_indexing_rank_comp)
+    assert advanced_indexing_rank == advanced_indexing_rank_comp
+    # --- DEBUG
+
+
+    broadcast_shape = _shape_of_broadcast_tensors(*not_none_indices)
     final_index = op.Concat(
         *(op.Unsqueeze(op.Expand(idx, broadcast_shape), -1) for idx in not_none_indices),
         axis=-1,
     )
diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
index 8334168203..7f5785507f 100644
--- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
+++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
@@ -454,7 +454,7 @@ def sample_inputs_index(op_info, device, dtype, requires_grad, **kwargs):
     )
     s = 5
     test_args = [
-        ([common_methods_invocations.index_variable(2, s, device=device)],),
+        ([common_methods_invocations.index_variable(2, 4, device=device)],),
         # ([torch.tensor()],)
     ]
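Note on `_shape_of_broadcast_tensors` in PATCH 09 above: it computes the broadcast shape in-graph by exploiting Max's multidirectional broadcasting, so the shape is available even when the index tensors have dynamic dimensions. The numpy analogue (a sketch, not part of the patch):

```python
import numpy as np

a = np.zeros((2, 1, 3))
b = np.zeros((4, 1))

# np.maximum broadcasts like ONNX Max, so its output shape *is* the broadcast shape.
assert np.maximum(a, b).shape == np.broadcast_shapes(a.shape, b.shape) == (2, 4, 3)
```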
""" + self_rank = len(self.shape) + # reordered_positions is the permutation of the index positions where # positions with None are move to the end of the list # For example, if indices = [None, 1, None, 2], then reordered_positions = [1, 3, 0, 2] @@ -3071,7 +3075,7 @@ def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorTy # then reordered_positions = [1, 3, 0, 2, 4, 5] reordered_positions = [ *reordered_positions, - *range(len(reordered_positions), len(self.shape)), + *range(len(reordered_positions), self_rank), ] # Transpose self according to the reordered positions self = op.Transpose(self, perm=reordered_positions) @@ -3079,8 +3083,8 @@ def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorTy # Broadcast the indices to the same shape then concatenate not_none_indices = [idx for idx in indices if idx is not None] broadcast_shape = list(np.broadcast_shapes(*[idx.shape for idx in not_none_indices])) + advanced_indexing_rank = len(broadcast_shape) broadcast_shape = op.Constant(value_ints=broadcast_shape) - print(type(broadcast_shape), broadcast_shape) final_index = op.Concat( *(op.Unsqueeze(op.Expand(idx, broadcast_shape), -1) for idx in not_none_indices), axis=-1, @@ -3109,19 +3113,28 @@ def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorTy # x_None_back_1, ..., x_None_back_m # ] # (Transpose here) - # Advanced indexing result axes: [x_None_front_1, ... x_None_front_m, *brocasted_shape(x1, x2, ..., xk), x_None_back_1, ..., x_None_back_m] + # Advanced indexing result axes: + # [ + # x_None_front_1, ... x_None_front_m, + # *brocasted_shape(x1, x2, ..., xk), + # x_None_back_1, ..., x_None_back_m + # ] # # Need to transpose the result of GatherND to match this axes ordering. - advanced_indexing_rank = len(not_none_indices) - first_not_none_index = reordered_positions[0] # x_None_front_m + 1 - x_none_back_1_position = advanced_indexing_rank + first_not_none_index + first_not_none_position = reordered_positions[0] # x_None_front_m + 1 + starting_position_of_none_in_back = ( + advanced_indexing_rank + first_not_none_position + ) # x_None_back_1 + result_rank = self_rank - len(not_none_indices) + advanced_indexing_rank perm = [ - *range(advanced_indexing_rank, x_none_back_1_position), - *range(0, advanced_indexing_rank), *range( - x_none_back_1_position, - len(reordered_positions) + advanced_indexing_rank - len(not_none_indices), - ), + advanced_indexing_rank, starting_position_of_none_in_back + ), # None_front_1...x_None_back_1 + *range(0, advanced_indexing_rank), # 0...len(broadcasted_shape) + *range( + starting_position_of_none_in_back, + result_rank, + ), # None_back_1...None_back_m ] return op.Transpose(self, perm=perm) From 38e8814d1f6d765fd0e905240bac24c94d8f136d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 13 Jul 2023 00:45:28 +0000 Subject: [PATCH 09/12] rankrank --- .../function_libs/torch_lib/ops/core.py | 24 ++++++++++++++++--- .../function_libs/torch_lib/extra_opinfo.py | 2 +- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 982c9851dc..ffce83a63d 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3042,6 +3042,12 @@ def _has_none_in_middle(indices) -> bool: return not _are_consecutive(not_none_indices) +def _shape_of_broadcast_tensors(*args: TensorType) -> INT64: + """Returns the broadcasted shape of the given tensors.""" 
From 2d0923ae70f25657f5c5462fd751725fe709cf07 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Thu, 13 Jul 2023 00:58:18 +0000
Subject: [PATCH 11/12] snap

---
 onnxscript/function_libs/torch_lib/ops/core.py    |  5 ++++-
 .../tests/function_libs/torch_lib/extra_opinfo.py | 15 ++++++++++++---
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 8a4417029a..a777d131ba 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -3071,7 +3071,10 @@ def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorTy
     """
 
     self_rank = len(self.shape)
-    advanced_indexing_rank = max(len(index.shape) for index in indices if index is not None)
+    index_ranks = [len(index.shape) for index in indices if index is not None]
+    print("index_ranks: ", index_ranks)
+    print("indices: ", indices)
+    advanced_indexing_rank = max(index_ranks)
 
diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
index 7f5785507f..eece81c736 100644
--- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
+++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
@@ -453,13 +453,22 @@ def sample_inputs_index(op_info, device, dtype, requires_grad, **kwargs):
         torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
     )
     s = 5
+    index_1d = common_methods_invocations.index_variable(2, s, device=device)
+    index_2d = common_methods_invocations.index_variable((2, s+1), s, device=device)
+    index_3d = common_methods_invocations.index_variable((2, s+1, s+2), s, device=device)
     test_args = [
-        ([common_methods_invocations.index_variable(2, 4, device=device)],),
-        # ([torch.tensor()],)
+        ([index_1d],),
+        ([None, index_1d],),
+        ([None, None, None, index_1d],),
+        ([index_1d, None],),
+        ([index_1d, None, None],),
+        ([None, index_1d, None, index_1d],),
+        ([index_1d, None, index_1d, None],),
+        ([None, index_1d, index_1d, None],),
     ]
 
     for args in test_args:
-        yield opinfo_core.SampleInput(make_arg((s, s, s)), args=args)
+        yield opinfo_core.SampleInput(make_arg((s, s, s, s)), args=args)
From f7073604b9fc25beffaa46648817bb9ceca4a35a Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Mon, 17 Jul 2023 16:55:15 +0000
Subject: [PATCH 12/12] test

---
 onnxscript/function_libs/torch_lib/ops/core.py     |  2 --
 .../function_libs/torch_lib/ops_test_common.py     | 13 +++++--------
 2 files changed, 5 insertions(+), 10 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index a777d131ba..db2ea158aa 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -14,8 +14,6 @@
 import math
 from typing import Any, Optional, Sequence, Tuple, Union
 
-import numpy as np
-
 from onnxscript import BOOL, DOUBLE, FLOAT, INT8, INT16, INT32, INT64, graph
 from onnxscript.function_libs.torch_lib.registration import torch_op
 from onnxscript.function_libs.torch_lib.tensor_typing import (
diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_common.py b/onnxscript/tests/function_libs/torch_lib/ops_test_common.py
index fae2f709ad..9cc21de90b 100644
--- a/onnxscript/tests/function_libs/torch_lib/ops_test_common.py
+++ b/onnxscript/tests/function_libs/torch_lib/ops_test_common.py
@@ -261,7 +261,8 @@ def convert_tensor_to_numpy(input: Any) -> Any:
     if isinstance(input, (tuple, list)):
         if len(input) == 0:
             return np.array((), dtype=np.int64)
-        if isinstance(input[0], torch.Tensor):
+        if any(isinstance(x, torch.Tensor) for x in input):
+            # The list can be Optional[Tensor], e.g. [None, Tensor, None] etc.
             return [convert_tensor_to_numpy(x) for x in input]
         if isinstance(input[0], bool):
             return np.array(input, dtype=np.bool_)
@@ -276,10 +277,7 @@ def convert_tensor_to_numpy(input: Any) -> Any:
 
 
 def convert_kwargs_for_onnx(kwargs: dict[str, Any]) -> dict[str, Any]:
-    """Converts kwargs to be compatible with ONNX Runtime.
-
-    ONNX Runtime doesn't support torch.bool, so we convert them to torch.uint8.
-    """
+    """Converts kwargs to be compatible with ONNX Runtime."""
     new_kwargs = {}
     for key, value in kwargs.items():
         if key == "device":
@@ -515,10 +513,9 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args,
     # Make sure the model is valid
     try:
         onnx.checker.check_model(onnx_model, full_check=True)
-    except onnx.checker.ValidationError as e:
+    except (onnx.checker.ValidationError, onnx.shape_inference.InferenceError) as e:
         raise AssertionError(
-            f"ONNX model is invalid: {e}. "
-            f"Model:\n"
+            f"ONNX model is invalid. Model:\n"
             f"{onnxscript.proto2text(onnx_model)}"
         ) from e
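Note on the sample inputs added in PATCH 11: the eager-mode shapes that the exported function has to reproduce follow from the advanced-indexing rules discussed in the patches. A sketch in plain PyTorch (assuming `torch.ops.aten.index.Tensor` accepts None entries in the index list, as in recent PyTorch builds):

```python
import torch

x = torch.arange(5**4).reshape(5, 5, 5, 5)
idx = torch.randint(0, 5, (2, 6))

# Index on axis 1 only: the broadcast dims replace that axis in place.
assert torch.ops.aten.index.Tensor(x, [None, idx]).shape == (5, 2, 6, 5, 5)

# Index tensors separated by a None: the broadcast dims move to the front,
# which is the `_has_none_in_middle` branch of the lowering.
assert torch.ops.aten.index.Tensor(x, [idx, None, idx]).shape == (2, 6, 5, 5)
```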