
Implement aten::index | feat(torchlib) #862


Merged · 13 commits · Jul 17, 2023
120 changes: 117 additions & 3 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -3020,10 +3020,124 @@ def aten_imag(self: TensorType) -> TensorType:
    raise NotImplementedError()


def _are_consecutive(sorted_list: Sequence[int]) -> bool:
    """Returns True if a sorted list contains consecutive numbers."""
    if not sorted_list:
        return True

    return sorted_list == list(range(min(sorted_list), max(sorted_list) + 1))


def _has_none_in_middle(indices) -> bool:
    """Returns True if there is a None in the middle of the list."""
    not_none_indices = [i for i, idx in enumerate(indices) if idx is not None]
    return not _are_consecutive(not_none_indices)


def _shape_of_broadcast_tensors(*args: TensorType) -> INT64:
    """Returns the broadcasted shape of the given tensors."""
    broadcasted = op.Max(*args)
    return op.Shape(broadcasted)
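
This helper leans on ONNX `Max`'s numpy-style broadcasting: the shape of `Max(*args)` is the broadcast shape of its inputs. A rough eager-mode analogue, shown only for intuition (not part of the PR):

```python
import numpy as np

def broadcast_shape(*arrays: np.ndarray) -> tuple[int, ...]:
    # np.broadcast_shapes computes the same result without materializing Max(*args).
    return np.broadcast_shapes(*(a.shape for a in arrays))
```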


@torch_op("aten::index", trace_only=True)
def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorType:
Contributor

This should work for all cases except for bool mask index, if possible.

Let me know if you can find a bug! @justinchuby

Contributor

We can try op.NonZero to convert bool mask to integer index.
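
For reference, a minimal sketch of that suggestion (illustrative only, not part of this PR; the helper name is hypothetical): `op.NonZero` returns the coordinates of the True elements with shape `[rank, num_true]`, so transposing it yields one integer index row per selected element, suitable for `GatherND`.

```python
def _bool_mask_to_index(mask):
    # NonZero returns the coordinates of True elements, shaped [rank, num_true].
    coords = op.NonZero(mask)
    # Transpose to [num_true, rank] so each row is a full integer index.
    return op.Transpose(coords, perm=[1, 0])
```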

"""index.Tensor(Tensor self, Tensor?[] indices) -> Tensor

NOTE: Understanding `aten::index`
For `arg0` with shape `[7, 3, 4, 5, 6]`
The indexing operation `arg0[0, :, 1:2, tensor([[4,5]])]` will be translated to

```
+> select: i64[3, 4, 5, 6] = torch.ops.aten.select.int(arg0, 0, 0);
+> slice_1: i64[3, 4, 5, 6] = torch.ops.aten.slice.Tensor(select, 0, 0, 9223372036854775807);
+> slice_2: i64[3, 1, 5, 6] = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 2);
+> index: i64[3, 1, 1, 2, 6] = torch.ops.aten.index.Tensor(slice_2, [None, None, arg1]);
```

Here,
- `indices = [None, None, arg1]` is equivalent to `indices = [None, None, arg1, None]`
- The operation `arg0[0, :, 1:2, tensor([[4,5]])]` is equivalent to `arg0[0, :, 1:2, tensor([[4,5]]), :]`

None in `indices` are like fillers for dimensions that cannot be removed in the process.
"""

self_rank = len(self.shape)
index_ranks = [len(index.shape) for index in indices if index is not None]
print("index_ranks: ", index_ranks)
print("indices: ", indices)
advanced_indexing_rank = max(index_ranks)

    # reordered_positions is the permutation of the index positions where
    # positions with None are moved to the end of the list
    # For example, if indices = [None, 1, None, 2], then reordered_positions = [1, 3, 0, 2]
    reordered_positions = sorted(range(len(indices)), key=lambda i: (indices[i] is None, i))
    # Fill the list with the remaining indices up to the rank of the tensor self.
    # For example, if indices = [None, 1, None, 2], and the rank of self is 6,
    # then reordered_positions = [1, 3, 0, 2, 4, 5]
    reordered_positions = [
        *reordered_positions,
        *range(len(reordered_positions), self_rank),
    ]
    # Transpose self according to the reordered positions
    self = op.Transpose(self, perm=reordered_positions)

    # Broadcast the indices to the same shape then concatenate
    not_none_indices = [idx for idx in indices if idx is not None]
    broadcast_shape = _shape_of_broadcast_tensors(*not_none_indices)
    final_index = op.Concat(
        *(op.Unsqueeze(op.Expand(idx, broadcast_shape), -1) for idx in not_none_indices),
        axis=-1,
    )

    self = op.GatherND(self, final_index, batch_dims=0)

    if _has_none_in_middle(indices):
        # If there is None in the middle, Advanced Indexing cannot decide where to put
        # the new dimensions. So it places them in the front, like GatherND does.
        return self

    # When the indices are consecutive, Advanced Indexing will place the new dimensions
    # (i.e., the broadcasted shape) in the middle, replacing the original [x1, ..., xk] axes.
    #
    # Input index axes (three parts):
    # [
    #   x_None_front_1, ..., x_None_front_m,
    #   x1, ..., xk,
    #   x_None_back_1, ..., x_None_back_m
    # ]
    # GatherND result axes:
    # [
    #   *broadcasted_shape(x1, x2, ..., xk),
    #   x_None_front_1, ..., x_None_front_m,
    #   x_None_back_1, ..., x_None_back_m
    # ]
    # (Transpose here)
    # Advanced indexing result axes:
    # [
    #   x_None_front_1, ..., x_None_front_m,
    #   *broadcasted_shape(x1, x2, ..., xk),
    #   x_None_back_1, ..., x_None_back_m
    # ]
    #
    # Need to transpose the result of GatherND to match this axes ordering.
    first_not_none_position = reordered_positions[0]  # x_None_front_m + 1
    starting_position_of_none_in_back = (
        advanced_indexing_rank + first_not_none_position
    )  # x_None_back_1
    result_rank = self_rank - len(not_none_indices) + advanced_indexing_rank
    perm = [
        *range(
            advanced_indexing_rank, starting_position_of_none_in_back
        ),  # None_front_1...x_None_back_1
        *range(0, advanced_indexing_rank),  # 0...len(broadcasted_shape)
        *range(
            starting_position_of_none_in_back,
            result_rank,
        ),  # None_back_1...None_back_m
    ]

    return op.Transpose(self, perm=perm)
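
A small eager-mode PyTorch check of the placement rule described in the comments above (illustrative only, not part of the diff): with a single block of consecutive advanced indices, the broadcasted dimensions replace the indexed axes in place, while a slice (None) between advanced indices pushes the broadcasted dimensions to the front.

```python
import torch

x = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
idx = torch.tensor([[0, 1]])  # advanced index of rank 2

# Consecutive advanced indexing: broadcasted dims land where the indexed axis was.
print(x[:, idx, :, :].shape)  # torch.Size([2, 1, 2, 4, 5])

# A slice between two advanced indices: broadcasted dims move to the front.
print(x[:, idx, :, idx].shape)  # torch.Size([1, 2, 2, 4])
```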


def aten_index_add(
34 changes: 34 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
@@ -444,6 +444,31 @@ def sample_inputs_col2im(op_info, device, dtype, requires_grad, **kwargs):
    yield opinfo_core.SampleInput(tensor, args=(output_size, kernel_size), kwargs=kwargs)


def sample_inputs_index(op_info, device, dtype, requires_grad, **kwargs):
    del op_info  # Unused
    del kwargs  # Unused
    make_arg = functools.partial(
        torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
    )
    s = 5
    index_1d = common_methods_invocations.index_variable(2, s, device=device)
    index_2d = common_methods_invocations.index_variable((2, s + 1), s, device=device)
    index_3d = common_methods_invocations.index_variable((2, s + 1, s + 2), s, device=device)
    test_args = [
        ([index_1d],),
        ([None, index_1d],),
        ([None, None, None, index_1d],),
        ([index_1d, None],),
        ([index_1d, None, None],),
        ([None, index_1d, None, index_1d],),
        ([index_1d, None, index_1d, None],),
        ([None, index_1d, index_1d, None],),
    ]

    for args in test_args:
        yield opinfo_core.SampleInput(make_arg((s, s, s, s)), args=args)


def sample_inputs_native_dropout(
    op_info, device, dtype, requires_grad, *, valid_input_dim=None, **kwargs
):
@@ -616,6 +641,15 @@ def sample_inputs_bernoulli_p_deterministic(op_info, device, dtype, requires_grad
        sample_inputs_func=sample_inputs_convolution,
        supports_out=False,
    ),
    opinfo_core.OpInfo(
        "aten.index.Tensor",
        dtypes=common_dtype.all_types_and_complex_and(
            torch.bool, torch.float16, torch.bfloat16, torch.chalf
        ),
        aten_name="index",
        op=torch.ops.aten.index.Tensor,
        sample_inputs_func=sample_inputs_index,
    ),
    opinfo_core.OpInfo(
        "ops.aten.layer_norm",
        aten_name="layer_norm",
10 changes: 4 additions & 6 deletions onnxscript/tests/function_libs/torch_lib/ops_test_common.py
@@ -261,7 +261,8 @@ def convert_tensor_to_numpy(input: Any) -> Any:
    if isinstance(input, (tuple, list)):
        if len(input) == 0:
            return np.array((), dtype=np.int64)
        if any(isinstance(x, torch.Tensor) for x in input):
            # The list can be Optional[Tensor], e.g. [None, Tensor, None] etc.
            return [convert_tensor_to_numpy(x) for x in input]
        if isinstance(input[0], bool):
            return np.array(input, dtype=np.bool_)
@@ -276,10 +277,7 @@ def convert_tensor_to_numpy(input: Any) -> Any:


def convert_kwargs_for_onnx(kwargs: dict[str, Any]) -> dict[str, Any]:
    """Converts kwargs to be compatible with ONNX Runtime."""
    new_kwargs = {}
    for key, value in kwargs.items():
        if key == "device":
@@ -515,7 +513,7 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args,
    # Make sure the model is valid
    try:
        onnx.checker.check_model(onnx_model, full_check=True)
    except (onnx.checker.ValidationError, onnx.shape_inference.InferenceError) as e:
        raise AssertionError(
            f"ONNX model is invalid. Model:\n{onnxscript.proto2text(onnx_model)}"
        ) from e
1 change: 1 addition & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -607,6 +607,7 @@ def _where_input_wrangler(
TorchLibOpInfo("gt", core_ops.aten_gt),
# TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB
# TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero), # no test case in OPS_DB
TorchLibOpInfo("aten.index.Tensor", core_ops.aten_index),
TorchLibOpInfo(
"index_put_bool",
core_ops.aten_index_put_bool,