Skip to content

Implement aten::index | feat(torchlib) (#862) #883

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jul 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 115 additions & 3 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3020,10 +3020,122 @@ def aten_imag(self: TensorType) -> TensorType:
raise NotImplementedError()


def aten_index(self: TensorType, indices: Optional[Sequence[TensorType]]) -> TensorType:
"""index.Tensor(Tensor self, Tensor?[] indices) -> Tensor"""
def _are_consecutive(sorted_list: Sequence[int]) -> bool:
"""Returns True if a sorted list contains consecutive numbers."""
if not sorted_list:
return True

raise NotImplementedError()
return sorted_list == list(range(min(sorted_list), max(sorted_list) + 1))


def _has_none_in_middle(indices) -> bool:
"""Returns True if there is a None in the middle of the list."""
not_none_indices = [i for i, idx in enumerate(indices) if idx is not None]
return not _are_consecutive(not_none_indices)


def _shape_of_broadcast_tensors(*args: TensorType) -> INT64:
"""Returns the broadcasted shape of the given tensors."""
broadcasted = op.Max(*args)
return op.Shape(broadcasted)


@torch_op(("aten::index.Tensor", "aten::_unsafe_index.Tensor"), trace_only=True)
def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorType:
"""index.Tensor(Tensor self, Tensor?[] indices) -> Tensor

NOTE: Understanding `aten::index`
For `arg0` with shape `[7, 3, 4, 5, 6]`
The indexing operation `arg0[0, :, 1:2, tensor([[4,5]])]` will be translated to

```
+> select: i64[3, 4, 5, 6] = torch.ops.aten.select.int(arg0, 0, 0);
+> slice_1: i64[3, 4, 5, 6] = torch.ops.aten.slice.Tensor(select, 0, 0, 9223372036854775807);
+> slice_2: i64[3, 1, 5, 6] = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 2);
+> index: i64[3, 1, 1, 2, 6] = torch.ops.aten.index.Tensor(slice_2, [None, None, arg1]);
```

Here,
- `indices = [None, None, arg1]` is equivalent to `indices = [None, None, arg1, None]`
- The operation `arg0[0, :, 1:2, tensor([[4,5]])]` is equivalent to `arg0[0, :, 1:2, tensor([[4,5]]), :]`

None in `indices` are like fillers for dimensions that cannot be removed in the process.
"""

self_rank = len(self.shape)
index_ranks = [len(index.shape) for index in indices if index is not None]
advanced_indexing_rank = max(index_ranks)

# reordered_positions is the permutation of the index positions where
# positions with None are move to the end of the list
# For example, if indices = [None, 1, None, 2], then reordered_positions = [1, 3, 0, 2]
reordered_positions = sorted(range(len(indices)), key=lambda i: (indices[i] is None, i))
# Fill the list with the remaining indices up to the rank of the tensor self.
# For example, if indices = [None, 1, None, 2], and the rank of self is 6,
# then reordered_positions = [1, 3, 0, 2, 4, 5]
reordered_positions = [
*reordered_positions,
*range(len(reordered_positions), self_rank),
]
# Transpose self according to the reordered positions
self = op.Transpose(self, perm=reordered_positions)

# Broadcast the indices to the same shape then concatenate
not_none_indices = [idx for idx in indices if idx is not None]
broadcast_shape = _shape_of_broadcast_tensors(*not_none_indices)
final_index = op.Concat(
*(op.Unsqueeze(op.Expand(idx, broadcast_shape), -1) for idx in not_none_indices),
axis=-1,
)

self = op.GatherND(self, final_index, batch_dims=0)

if _has_none_in_middle(indices):
# If there is None in the middle, Advanced Indexing cannot decide where to put
# the new dimensions. So it places them in the front, like GatherND does.
return self

# When the indices are consecutive, Advanced Indexing will place the new dimensions
# (aka. the broadcasted shape) in the middle, replacing the original [x1, ..., xk] axes.
#
# Input index axes (three parts):
# [
# x_None_front_1, ... x_None_front_m,
# x1, ..., xk,
# x_None_back_1, ..., x_None_back_m
# ]
# GatherND result axes:
# [
# *broadcasted_shape(x1, x2, ..., xk),
# x_None_front_1, ... x_None_front_m,
# x_None_back_1, ..., x_None_back_m
# ]
# (Transpose here)
# Advanced indexing result axes:
# [
# x_None_front_1, ... x_None_front_m,
# *brocasted_shape(x1, x2, ..., xk),
# x_None_back_1, ..., x_None_back_m
# ]
#
# Need to transpose the result of GatherND to match this axes ordering.
first_not_none_position = reordered_positions[0] # x_None_front_m + 1
starting_position_of_none_in_back = (
advanced_indexing_rank + first_not_none_position
) # x_None_back_1
result_rank = self_rank - len(not_none_indices) + advanced_indexing_rank
perm = [
*range(
advanced_indexing_rank, starting_position_of_none_in_back
), # None_front_1...x_None_back_1
*range(0, advanced_indexing_rank), # 0...len(broadcasted_shape)
*range(
starting_position_of_none_in_back,
result_rank,
), # None_back_1...None_back_m
]

return op.Transpose(self, perm=perm)


def aten_index_add(
Expand Down
57 changes: 57 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,55 @@ def sample_inputs_col2im(op_info, device, dtype, requires_grad, **kwargs):
yield opinfo_core.SampleInput(tensor, args=(output_size, kernel_size), kwargs=kwargs)


def sample_inputs_index(op_info, device, dtype, requires_grad, **kwargs):
del op_info # Unused
del kwargs # Unused
make_arg = functools.partial(
torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
)
s = 5
index_1d = common_methods_invocations.index_variable(2, s, device=device)
index_2d = common_methods_invocations.index_variable((s + 1, 2), s, device=device)
index_3d = common_methods_invocations.index_variable((s + 2, s + 1, 2), s, device=device)
test_args = [
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would itertools.{product,permutation,combination} help in ensuring all combinations are covered and make the code shorter?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Let me try that

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I tested with itertools but realized some combinations are invalid so we cannot enumerate them all using itertools. For the sake of clarity I propose that we keep the current explicit tests. I also added more test cases and comments

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could itertools.product (or friends) with length 1 to 4 could shorten this listen and ensure no combination is left out?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As above. Turns out some combinations are invalid to torch and it may be better to specify explicitly.

([index_1d],),
([None, index_1d],),
([None, None, None, index_1d],),
([index_1d, None],),
([index_1d, None, None],),
# Extra index
([None, index_1d, None, index_1d],),
([index_1d, None, index_1d, None],),
([None, index_1d, index_1d, None],),
([index_2d],),
([None, index_2d],),
([None, None, None, index_2d],),
([index_2d, None],),
([index_2d, None, None],),
# Extra index
([None, index_2d, None, index_2d],),
([index_2d, None, index_2d, None],),
([None, index_2d, index_2d, None],),
([index_3d],),
([None, index_3d],),
([None, None, None, index_3d],),
([index_3d, None],),
([index_3d, None, None],),
# Extra index
([None, index_3d, None, index_3d],),
([index_3d, None, index_3d, None],),
([None, index_3d, index_3d, None],),
# Mixed indices
([None, index_3d, index_1d, index_2d],),
# All indices are not None
([index_2d, index_3d, index_1d],),
([index_2d, index_3d, index_1d, index_2d],),
]

for args in test_args:
yield opinfo_core.SampleInput(make_arg((s, s, s, s)), args=args)


def sample_inputs_native_dropout(
op_info, device, dtype, requires_grad, *, valid_input_dim=None, **kwargs
):
Expand Down Expand Up @@ -616,6 +665,14 @@ def sample_inputs_bernoulli_p_deterministic(op_info, device, dtype, requires_gra
sample_inputs_func=sample_inputs_convolution,
supports_out=False,
),
opinfo_core.OpInfo(
"ops.aten.index.Tensor",
aten_name="index.Tensor",
dtypes=common_dtype.all_types_and_complex_and(
torch.bool, torch.float16, torch.bfloat16, torch.chalf
),
sample_inputs_func=sample_inputs_index,
),
opinfo_core.OpInfo(
"ops.aten.layer_norm",
aten_name="layer_norm",
Expand Down
14 changes: 8 additions & 6 deletions onnxscript/tests/function_libs/torch_lib/ops_test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,8 @@ def convert_tensor_to_numpy(input: Any) -> Any:
if isinstance(input, (tuple, list)):
if len(input) == 0:
return np.array((), dtype=np.int64)
if isinstance(input[0], torch.Tensor):
if any(isinstance(x, torch.Tensor) for x in input):
# The list can be Optional[Tensor], e.g. [None, Tensor, None] etc.
return [convert_tensor_to_numpy(x) for x in input]
if isinstance(input[0], bool):
return np.array(input, dtype=np.bool_)
Expand All @@ -276,10 +277,7 @@ def convert_tensor_to_numpy(input: Any) -> Any:


def convert_kwargs_for_onnx(kwargs: dict[str, Any]) -> dict[str, Any]:
"""Converts kwargs to be compatible with ONNX Runtime.

ONNX Runtime doesn't support torch.bool, so we convert them to torch.uint8.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for my knowledge, does ort support torch.bool now? or was this docstring outdated already?

Copy link
Collaborator Author

@justinchuby justinchuby Jul 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually don't know if ORT supports bool (it should?), but I think this message was a mistake by copilot because we don't actually have this conversion logic as code. If we see issues with ORT I will make adjustments.

"""
"""Converts kwargs to be compatible with ONNX Runtime."""
new_kwargs = {}
for key, value in kwargs.items():
if key == "device":
Expand Down Expand Up @@ -473,6 +471,10 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args,
input.value = subarg
sequence_input.append(input)
ort_inputs[input_name] = subarg
else:
# Include non-numpy inputs as-is
# For example, it could be a None value that we want to keep
sequence_input.append(subarg)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's put comment here explaining why this is needed in case we forget?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we consider elif None? It's easier to catch what we are not expecting.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I wonder if there are things we don't expect that can sneak in? Nested lists?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No idea. Just in case.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SG. Created #892 for now.

onnxscript_args.append(sequence_input)
else:
onnxscript_args.append(arg)
Expand Down Expand Up @@ -515,7 +517,7 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args,
# Make sure the model is valid
try:
onnx.checker.check_model(onnx_model, full_check=True)
except onnx.checker.ValidationError as e:
except (onnx.checker.ValidationError, onnx.shape_inference.InferenceError) as e:
raise AssertionError(
f"ONNX model is invalid. Model:\n{onnxscript.proto2text(onnx_model)}"
) from e
Expand Down
1 change: 1 addition & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,7 @@ def _where_input_wrangler(
TorchLibOpInfo("gt", core_ops.aten_gt),
# TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB
# TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero), # no test case in OPS_DB
TorchLibOpInfo("ops.aten.index.Tensor", core_ops.aten_index, trace_only=True),
TorchLibOpInfo(
"index_put_bool",
core_ops.aten_index_put_bool,
Expand Down