diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 59dcec79cb..0620e16e1a 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -3020,10 +3020,123 @@ def aten_imag(self: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-def aten_index(self: TensorType, indices: Optional[Sequence[TensorType]]) -> TensorType:
-    """index.Tensor(Tensor self, Tensor?[] indices) -> Tensor"""
+def _are_consecutive(sorted_list: Sequence[int]) -> bool:
+    """Returns True if a sorted sequence (list or tuple) contains consecutive numbers."""
+    if not sorted_list:
+        return True
 
-    raise NotImplementedError()
+    return list(sorted_list) == list(range(min(sorted_list), max(sorted_list) + 1))
+
+
+def _has_none_in_middle(indices) -> bool:
+    """Returns True if a None separates the non-None entries of the list."""
+    not_none_indices = [i for i, idx in enumerate(indices) if idx is not None]
+    return not _are_consecutive(not_none_indices)
+
+
+def _shape_of_broadcast_tensors(*args: TensorType) -> INT64:
+    """Returns the broadcasted shape of the given tensors."""
+    # Max broadcasts its inputs; only the shape of the result is used, not its values.
+    broadcasted = op.Max(*args)
+    return op.Shape(broadcasted)
+
+
+@torch_op("aten::index", trace_only=True)
+def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorType:
+    """index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
+
+    NOTE: Understanding `aten::index`
+    For `arg0` with shape `[7, 3, 4, 5, 6]`
+    The indexing operation `arg0[0, :, 1:2, tensor([[4,5]])]` will be translated to
+
+    ```
+    +> select: i64[3, 4, 5, 6] = torch.ops.aten.select.int(arg0, 0, 0);
+    +> slice_1: i64[3, 4, 5, 6] = torch.ops.aten.slice.Tensor(select, 0, 0, 9223372036854775807);
+    +> slice_2: i64[3, 1, 5, 6] = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 2);
+    +> index: i64[3, 1, 1, 2, 6] = torch.ops.aten.index.Tensor(slice_2, [None, None, arg1]);
+    ```
+
+    Here,
+    - `indices = [None, None, arg1]` is equivalent to `indices = [None, None, arg1, None]`
+    - The operation `arg0[0, :, 1:2, tensor([[4,5]])]` is equivalent to `arg0[0, :, 1:2, tensor([[4,5]]), :]`
+
+    None in `indices` are like fillers for dimensions that cannot be removed in the process.
+    """
+
+    self_rank = len(self.shape)
+    index_ranks = [len(index.shape) for index in indices if index is not None]
+    advanced_indexing_rank = max(index_ranks)
+
+    # reordered_positions is the permutation of the index positions where
+    # positions with None are moved to the end of the list
+    # For example, if indices = [None, 1, None, 2], then reordered_positions = [1, 3, 0, 2]
+    reordered_positions = sorted(range(len(indices)), key=lambda i: (indices[i] is None, i))
+    # Fill the list with the remaining indices up to the rank of the tensor self.
+    # For example, if indices = [None, 1, None, 2], and the rank of self is 6,
+    # then reordered_positions = [1, 3, 0, 2, 4, 5]
+    reordered_positions = [
+        *reordered_positions,
+        *range(len(reordered_positions), self_rank),
+    ]
+    # Transpose self according to the reordered positions
+    self = op.Transpose(self, perm=reordered_positions)
+
+    # Broadcast the indices to the same shape then concatenate
+    not_none_indices = [idx for idx in indices if idx is not None]
+    broadcast_shape = _shape_of_broadcast_tensors(*not_none_indices)
+    final_index = op.Concat(
+        *(op.Unsqueeze(op.Expand(idx, broadcast_shape), -1) for idx in not_none_indices),
+        axis=-1,
+    )
+
+    self = op.GatherND(self, final_index, batch_dims=0)
+
+    if _has_none_in_middle(indices):
+        # If there is None in the middle, Advanced Indexing cannot decide where to put
+        # the new dimensions. So it places them in the front, like GatherND does.
+        return self
+
+    # When the indices are consecutive, Advanced Indexing will place the new dimensions
+    # (aka. the broadcasted shape) in the middle, replacing the original [x1, ..., xk] axes.
+    #
+    # Input index axes (three parts):
+    # [
+    #   x_None_front_1, ... x_None_front_m,
+    #   x1, ..., xk,
+    #   x_None_back_1, ..., x_None_back_m
+    # ]
+    # GatherND result axes:
+    # [
+    #   *broadcasted_shape(x1, x2, ..., xk),
+    #   x_None_front_1, ... x_None_front_m,
+    #   x_None_back_1, ..., x_None_back_m
+    # ]
+    # (Transpose here)
+    # Advanced indexing result axes:
+    # [
+    #   x_None_front_1, ... x_None_front_m,
+    #   *broadcasted_shape(x1, x2, ..., xk),
+    #   x_None_back_1, ..., x_None_back_m
+    # ]
+    #
+    # Need to transpose the result of GatherND to match this axes ordering.
+    first_not_none_position = reordered_positions[0]  # x_None_front_m + 1
+    starting_position_of_none_in_back = (
+        advanced_indexing_rank + first_not_none_position
+    )  # x_None_back_1
+    result_rank = self_rank - len(not_none_indices) + advanced_indexing_rank
+    perm = [
+        *range(
+            advanced_indexing_rank, starting_position_of_none_in_back
+        ),  # x_None_front_1...x_None_front_m
+        *range(0, advanced_indexing_rank),  # 0...len(broadcasted_shape)
+        *range(
+            starting_position_of_none_in_back,
+            result_rank,
+        ),  # x_None_back_1...x_None_back_m
+    ]
+
+    return op.Transpose(self, perm=perm)
 
 
 def aten_index_add(
diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
index 1995ec1017..38ba7f998f 100644
--- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
+++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
@@ -444,6 +444,30 @@ def sample_inputs_col2im(op_info, device, dtype, requires_grad, **kwargs):
         yield opinfo_core.SampleInput(tensor, args=(output_size, kernel_size), kwargs=kwargs)
 
 
+def sample_inputs_index(op_info, device, dtype, requires_grad, **kwargs):
+    """Yields aten.index.Tensor samples with None at various positions in the index list."""
+    del op_info  # Unused
+    del kwargs  # Unused
+    make_arg = functools.partial(
+        torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+    s = 5
+    index_1d = common_methods_invocations.index_variable(2, s, device=device)
+    test_args = [
+        ([index_1d],),
+        ([None, index_1d],),
+        ([None, None, None, index_1d],),
+        ([index_1d, None],),
+        ([index_1d, None, None],),
+        ([None, index_1d, None, index_1d],),
+        ([index_1d, None, index_1d, None],),
+        ([None, index_1d, index_1d, None],),
+    ]
+
+    for args in test_args:
+        yield opinfo_core.SampleInput(make_arg((s, s, s, s)), args=args)
+
+
 def sample_inputs_native_dropout(
     op_info, device, dtype, requires_grad, *, valid_input_dim=None, **kwargs
 ):
@@ -616,6 +640,15 @@ def sample_inputs_bernoulli_p_deterministic(op_info, device, dtype, requires_gra
         sample_inputs_func=sample_inputs_convolution,
         supports_out=False,
     ),
+    opinfo_core.OpInfo(
+        "aten.index.Tensor",
+        dtypes=common_dtype.all_types_and_complex_and(
+            torch.bool, torch.float16, torch.bfloat16, torch.chalf
+        ),
+        aten_name="index",
+        op=torch.ops.aten.index.Tensor,
+        sample_inputs_func=sample_inputs_index,
+    ),
     opinfo_core.OpInfo(
         "ops.aten.layer_norm",
         aten_name="layer_norm",
diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_common.py b/onnxscript/tests/function_libs/torch_lib/ops_test_common.py
index 63c4bebaca..78968a93e0 100644
--- a/onnxscript/tests/function_libs/torch_lib/ops_test_common.py
+++ b/onnxscript/tests/function_libs/torch_lib/ops_test_common.py
@@ -261,7 +261,8 @@ def convert_tensor_to_numpy(input: Any) -> Any:
     if isinstance(input, (tuple, list)):
         if len(input) == 0:
             return np.array((), dtype=np.int64)
-        if isinstance(input[0], torch.Tensor):
+        if any(isinstance(x, torch.Tensor) for x in input):
+            # The list can be Optional[Tensor], e.g. [None, Tensor, None] etc.
             return [convert_tensor_to_numpy(x) for x in input]
         if isinstance(input[0], bool):
             return np.array(input, dtype=np.bool_)
@@ -276,10 +277,7 @@ def convert_tensor_to_numpy(input: Any) -> Any:
 
 
 def convert_kwargs_for_onnx(kwargs: dict[str, Any]) -> dict[str, Any]:
-    """Converts kwargs to be compatible with ONNX Runtime.
-
-    ONNX Runtime doesn't support torch.bool, so we convert them to torch.uint8.
-    """
+    """Converts kwargs to be compatible with ONNX Runtime."""
     new_kwargs = {}
     for key, value in kwargs.items():
         if key == "device":
@@ -515,7 +513,7 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args,
     # Make sure the model is valid
     try:
         onnx.checker.check_model(onnx_model, full_check=True)
-    except onnx.checker.ValidationError as e:
+    except (onnx.checker.ValidationError, onnx.shape_inference.InferenceError) as e:
         raise AssertionError(
             f"ONNX model is invalid. Model:\n{onnxscript.proto2text(onnx_model)}"
         ) from e
diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
index 598d527010..f5a18d2b47 100644
--- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
+++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -607,6 +607,7 @@ def _where_input_wrangler(
     TorchLibOpInfo("gt", core_ops.aten_gt),
     # TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size),  # no test case in OPS_DB
     # TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero),  # no test case in OPS_DB
+    TorchLibOpInfo("aten.index.Tensor", core_ops.aten_index),
     TorchLibOpInfo(
         "index_put_bool",
         core_ops.aten_index_put_bool,