
Add support for aten::index op when index is boolean | feat(torchlib) #1285

Merged: 9 commits, Mar 5, 2024

7 changes: 7 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -4035,6 +4035,13 @@ def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorType:
    return op.Transpose(self, perm=perm)


@torch_op(("aten::index.Tensor", "aten::_unsafe_index.Tensor"), trace_only=True)
def aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> TensorType:
    new_indices = op.Transpose(op.NonZero(indices[0]), perm=[1, 0])
    new_indices = op.Squeeze(new_indices, axes=[1])
    return op.Gather(self, new_indices, axis=0)


def aten_index_add(
    self: TensorType, dim: int, index: TensorType, source: TensorType, alpha: float = 1
) -> TensorType:
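
As a side note, here is a minimal NumPy sketch (not part of the diff) of what aten_index_bool emits for a single 1-D boolean index: ONNX NonZero yields the True positions with shape [rank, count], the Transpose/Squeeze pair flattens that into an index vector, and Gather along axis 0 then reproduces PyTorch's x[mask].

import numpy as np

# Sketch only: mirrors the NonZero -> Transpose -> Squeeze -> Gather sequence
# emitted by aten_index_bool, assuming a single 1-D boolean index.
x = np.arange(24.0).reshape(4, 6)
mask = np.array([True, False, True, False])

nonzero = np.stack(np.nonzero(mask))        # ONNX NonZero: shape [rank, count] = [1, 2]
new_indices = nonzero.transpose(1, 0)       # Transpose perm=[1, 0]: shape [2, 1]
new_indices = new_indices.squeeze(axis=1)   # Squeeze axes=[1]: shape [2]
gathered = np.take(x, new_indices, axis=0)  # Gather along axis 0

assert np.array_equal(gathered, x[mask])    # same result as boolean indexing
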
34 changes: 34 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
@@ -692,6 +692,31 @@ def sample_inputs__fft_c2r(self, device, dtype, requires_grad=False, **_):
)


def _index_variable_bool(shape, max_indices, device):
    if not isinstance(shape, tuple):
        shape = (shape,)
    index = (
        torch.rand(*shape, dtype=torch.double, device=device).mul_(max_indices).floor_().bool()
    )
    return index


def sample_inputs_index_bool(op_info, device, dtype, requires_grad, **kwargs):
    del op_info  # Unused
    del kwargs  # Unused
    make_arg = functools.partial(
        torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
    )
    s = 5
    index_bool = _index_variable_bool(s, s, device=device)
    test_args = [
        ([index_bool],),
    ]

    for args in test_args:
        yield opinfo_core.SampleInput(make_arg((s, s, s, s)), args=args)


def sample_inputs_index(op_info, device, dtype, requires_grad, **kwargs):
    del op_info  # Unused
    del kwargs  # Unused
@@ -1933,6 +1958,15 @@ def __init__(self):
        ),
        sample_inputs_func=sample_inputs_index,
    ),
    opinfo_core.OpInfo(
        "ops.aten.index.Tensor.bool",
        aten_name="index.Tensor",
        dtypes=common_dtype.all_types_and_complex_and(
            torch.bool, torch.float16, torch.bfloat16, torch.chalf
        ),
        sample_inputs_func=sample_inputs_index_bool,
        op=torch.ops.aten.index.Tensor,
    ),
    opinfo_core.OpInfo(
        "ops.aten.layer_norm",
        aten_name="layer_norm",
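
For illustration only (shapes taken from the sample function above, not additional test code): each generated sample pairs a (5, 5, 5, 5) tensor with one 1-D boolean mask, and the op under test, torch.ops.aten.index.Tensor, behaves like boolean indexing on the first dimension.

import torch

# Illustrative only: reproduces the kind of sample sample_inputs_index_bool yields.
s = 5
x = torch.rand(s, s, s, s)
mask = torch.rand(s, dtype=torch.double).mul_(s).floor_().bool()  # as in _index_variable_bool(s, s)

result = torch.ops.aten.index.Tensor(x, [mask])  # the op registered in the new OpInfo
assert torch.equal(result, x[mask])              # boolean indexing on dim 0
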
1 change: 1 addition & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -833,6 +833,7 @@ def _where_input_wrangler(
# TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB
# TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero), # no test case in OPS_DB
TorchLibOpInfo("ops.aten.index.Tensor", core_ops.aten_index, trace_only=True),
TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index_bool, trace_only=True),
TorchLibOpInfo(
"index_put_bool",
core_ops.aten_index_put_bool,
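
With this registration, the torch_lib op tests compare aten_index_bool against eager PyTorch results for the boolean samples. As a hypothetical end-to-end sketch (not part of this PR, and dependent on a torch build whose dynamo-based ONNX exporter dispatches aten::index.Tensor through torchlib), boolean indexing inside an exported module would now be translatable:

import torch

class BoolIndex(torch.nn.Module):
    def forward(self, x, mask):
        # Boolean advanced indexing typically lowers to aten::index.Tensor with a BOOL index.
        return x[mask]

x = torch.rand(5, 5)
mask = torch.tensor([True, False, True, False, True])

# torch.onnx.dynamo_export was the exporter entry point around the time of this PR.
onnx_program = torch.onnx.dynamo_export(BoolIndex(), x, mask)
onnx_program.save("bool_index.onnx")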