Skip to content

Commit f49d83e

Browse files
committed
Implement aten::index | feat(torchlib) (#862)
WIP: need more tests Gather op: - https://github.com/openxla/xla/blob/main/docs/operation_semantics.md?rgh-link-date=2023-07-13T01%3A09%3A16Z#gather - https://www.pathpartnertech.com/gather-scatter-operation-in-deep-learning-framework/ --------- Co-authored-by: BowenBao <bowbaomicrosoft.com> ghstack-source-id: 078cf56 Pull Request resolved: #883 Signed-off-by: Justin Chu <[email protected]>
1 parent 8d3967d commit f49d83e

File tree

4 files changed

+156
-9
lines changed

4 files changed

+156
-9
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

+115-3
Original file line numberDiff line numberDiff line change
@@ -3020,10 +3020,122 @@ def aten_imag(self: TensorType) -> TensorType:
30203020
raise NotImplementedError()
30213021

30223022

3023-
def aten_index(self: TensorType, indices: Optional[Sequence[TensorType]]) -> TensorType:
3024-
"""index.Tensor(Tensor self, Tensor?[] indices) -> Tensor"""
3023+
def _are_consecutive(sorted_list: Sequence[int]) -> bool:
3024+
"""Returns True if a sorted list contains consecutive numbers."""
3025+
if not sorted_list:
3026+
return True
30253027

3026-
raise NotImplementedError()
3028+
return sorted_list == list(range(min(sorted_list), max(sorted_list) + 1))
3029+
3030+
3031+
def _has_none_in_middle(indices) -> bool:
    """Returns True if a None separates two non-None entries of ``indices``.

    The positions of all non-None entries are collected in ascending order; if
    those positions do not form one contiguous run, a None lies between them.
    """
    tensor_positions = [
        position for position, entry in enumerate(indices) if entry is not None
    ]
    is_contiguous = _are_consecutive(tensor_positions)
    return not is_contiguous
3035+
3036+
3037+
def _shape_of_broadcast_tensors(*args: TensorType) -> INT64:
    """Returns the broadcasted shape of the given tensors.

    The tensors are combined with an element-wise ``Max`` purely so that the
    result carries their common broadcast shape; only that shape is returned.
    """
    return op.Shape(op.Max(*args))
3041+
3042+
3043+
@torch_op("aten::index", trace_only=True)
def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorType:
    """index.Tensor(Tensor self, Tensor?[] indices) -> Tensor

    NOTE: Understanding `aten::index`
    For `arg0` with shape `[7, 3, 4, 5, 6]`
    The indexing operation `arg0[0, :, 1:2, tensor([[4,5]])]` will be translated to

    ```
    +>  select: i64[3, 4, 5, 6] = torch.ops.aten.select.int(arg0, 0, 0);
    +>  slice_1: i64[3, 4, 5, 6] = torch.ops.aten.slice.Tensor(select, 0, 0, 9223372036854775807);
    +>  slice_2: i64[3, 1, 5, 6] = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 2);
    +>  index: i64[3, 1, 1, 2, 6] = torch.ops.aten.index.Tensor(slice_2, [None, None, arg1]);
    ```

    Here,
    - `indices = [None, None, arg1]` is equivalent to `indices = [None, None, arg1, None]`
    - The operation `arg0[0, :, 1:2, tensor([[4,5]])]` is equivalent to
      `arg0[0, :, 1:2, tensor([[4,5]]), :]`

    None in `indices` are like fillers for dimensions that cannot be removed in the process.

    Implementation outline:
    1. Transpose `self` so that the indexed axes come first.
    2. Broadcast the index tensors to a common shape, stack them along a new
       last axis and apply GatherND.
    3. If the indexed axes were adjacent, transpose the GatherND result so the
       broadcast dimensions sit where the indexed axes used to be.
    """

    self_rank = len(self.shape)
    tensor_indices = [idx for idx in indices if idx is not None]
    advanced_indexing_rank = max([len(idx.shape) for idx in tensor_indices])

    # Permutation that moves every indexed (non-None) axis to the front while
    # keeping relative order, followed by the None axes, followed by any trailing
    # axes of `self` that `indices` does not mention at all.
    # Example: indices = [None, 1, None, 2] and rank 6 -> [1, 3, 0, 2, 4, 5]
    tensor_positions = [pos for pos, idx in enumerate(indices) if idx is not None]
    none_positions = [pos for pos, idx in enumerate(indices) if idx is None]
    reordered_positions = [
        *tensor_positions,
        *none_positions,
        *range(len(indices), self_rank),
    ]
    transposed = op.Transpose(self, perm=reordered_positions)

    # Expand every index tensor to the common broadcast shape, append a unit
    # axis, and concatenate along it to form the GatherND index tuples.
    broadcast_shape = _shape_of_broadcast_tensors(*tensor_indices)
    expanded_indices = [
        op.Unsqueeze(op.Expand(idx, broadcast_shape), -1) for idx in tensor_indices
    ]
    final_index = op.Concat(*expanded_indices, axis=-1)

    gathered = op.GatherND(transposed, final_index, batch_dims=0)

    if _has_none_in_middle(indices):
        # With a None between indexed axes, advanced indexing cannot decide where
        # to put the new dimensions, so it places them in the front -- which is
        # exactly the layout GatherND already produced.
        return gathered

    # The indexed axes are adjacent, so advanced indexing places the broadcast
    # dimensions in the middle, replacing the original [x1, ..., xk] axes.
    #
    # Input index axes (three parts):
    #   [front_None_1, ..., front_None_m,  x1, ..., xk,  back_None_1, ..., back_None_n]
    # GatherND result axes:
    #   [*broadcasted_shape(x1, ..., xk),  front_None_1, ..., front_None_m,
    #    back_None_1, ..., back_None_n]
    # Advanced indexing result axes (what must be returned):
    #   [front_None_1, ..., front_None_m,  *broadcasted_shape(x1, ..., xk),
    #    back_None_1, ..., back_None_n]
    #
    # The number of leading None axes equals the position of the first index tensor.
    first_not_none_position = reordered_positions[0]
    starting_position_of_none_in_back = advanced_indexing_rank + first_not_none_position
    result_rank = self_rank - len(tensor_indices) + advanced_indexing_rank

    front_none_axes = range(advanced_indexing_rank, starting_position_of_none_in_back)
    broadcast_axes = range(0, advanced_indexing_rank)
    back_none_axes = range(starting_position_of_none_in_back, result_rank)
    perm = [*front_none_axes, *broadcast_axes, *back_none_axes]

    return op.Transpose(gathered, perm=perm)
30273139

30283140

30293141
def aten_index_add(

onnxscript/tests/function_libs/torch_lib/extra_opinfo.py

+34
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,31 @@ def sample_inputs_col2im(op_info, device, dtype, requires_grad, **kwargs):
444444
yield opinfo_core.SampleInput(tensor, args=(output_size, kernel_size), kwargs=kwargs)
445445

446446

447+
def sample_inputs_index(op_info, device, dtype, requires_grad, **kwargs):
    """Yields sample inputs for `aten.index.Tensor` on a rank-4 input tensor.

    Each sample pairs a freshly made `(s, s, s, s)` tensor with one list of
    optional index tensors, covering None entries in front of, behind and
    between the index tensors.
    """
    del op_info  # Unused
    del kwargs  # Unused

    s = 5
    make_arg = functools.partial(
        torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
    )

    index_1d = common_methods_invocations.index_variable(2, s, device=device)
    # NOTE(review): index_2d and index_3d are built but not used by any case
    # below. They are kept because creating them may consume random state --
    # confirm before removing or wiring them into the cases.
    index_2d = common_methods_invocations.index_variable((2, s + 1), s, device=device)
    index_3d = common_methods_invocations.index_variable(
        (2, s + 1, s + 2), s, device=device
    )

    index_lists = [
        [index_1d],
        [None, index_1d],
        [None, None, None, index_1d],
        [index_1d, None],
        [index_1d, None, None],
        [None, index_1d, None, index_1d],
        [index_1d, None, index_1d, None],
        [None, index_1d, index_1d, None],
    ]

    for index_list in index_lists:
        # The index list is the single positional argument after the input tensor.
        yield opinfo_core.SampleInput(make_arg((s, s, s, s)), args=(index_list,))
470+
471+
447472
def sample_inputs_native_dropout(
448473
op_info, device, dtype, requires_grad, *, valid_input_dim=None, **kwargs
449474
):
@@ -616,6 +641,15 @@ def sample_inputs_bernoulli_p_deterministic(op_info, device, dtype, requires_gra
616641
sample_inputs_func=sample_inputs_convolution,
617642
supports_out=False,
618643
),
644+
opinfo_core.OpInfo(
645+
"aten.index.Tensor",
646+
dtypes=common_dtype.all_types_and_complex_and(
647+
torch.bool, torch.float16, torch.bfloat16, torch.chalf
648+
),
649+
aten_name="index",
650+
op=torch.ops.aten.index.Tensor,
651+
sample_inputs_func=sample_inputs_index,
652+
),
619653
opinfo_core.OpInfo(
620654
"ops.aten.layer_norm",
621655
aten_name="layer_norm",

onnxscript/tests/function_libs/torch_lib/ops_test_common.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,8 @@ def convert_tensor_to_numpy(input: Any) -> Any:
261261
if isinstance(input, (tuple, list)):
262262
if len(input) == 0:
263263
return np.array((), dtype=np.int64)
264-
if isinstance(input[0], torch.Tensor):
264+
if any(isinstance(x, torch.Tensor) for x in input):
265+
# The list can be Optional[Tensor], e.g. [None, Tensor, None] etc.
265266
return [convert_tensor_to_numpy(x) for x in input]
266267
if isinstance(input[0], bool):
267268
return np.array(input, dtype=np.bool_)
@@ -276,10 +277,7 @@ def convert_tensor_to_numpy(input: Any) -> Any:
276277

277278

278279
def convert_kwargs_for_onnx(kwargs: dict[str, Any]) -> dict[str, Any]:
279-
"""Converts kwargs to be compatible with ONNX Runtime.
280-
281-
ONNX Runtime doesn't support torch.bool, so we convert them to torch.uint8.
282-
"""
280+
"""Converts kwargs to be compatible with ONNX Runtime."""
283281
new_kwargs = {}
284282
for key, value in kwargs.items():
285283
if key == "device":
@@ -473,6 +471,8 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args,
473471
input.value = subarg
474472
sequence_input.append(input)
475473
ort_inputs[input_name] = subarg
474+
else:
475+
sequence_input.append(subarg)
476476
onnxscript_args.append(sequence_input)
477477
else:
478478
onnxscript_args.append(arg)
@@ -515,7 +515,7 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args,
515515
# Make sure the model is valid
516516
try:
517517
onnx.checker.check_model(onnx_model, full_check=True)
518-
except onnx.checker.ValidationError as e:
518+
except (onnx.checker.ValidationError, onnx.shape_inference.InferenceError) as e:
519519
raise AssertionError(
520520
f"ONNX model is invalid. Model:\n{onnxscript.proto2text(onnx_model)}"
521521
) from e

onnxscript/tests/function_libs/torch_lib/ops_test_data.py

+1
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,7 @@ def _where_input_wrangler(
607607
TorchLibOpInfo("gt", core_ops.aten_gt),
608608
# TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB
609609
# TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero), # no test case in OPS_DB
610+
TorchLibOpInfo("aten.index.Tensor", core_ops.aten_index),
610611
TorchLibOpInfo(
611612
"index_put_bool",
612613
core_ops.aten_index_put_bool,

0 commit comments

Comments
 (0)