Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4035,6 +4035,13 @@ def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorTy
return op.Transpose(self, perm=perm)


@torch_op(("aten::index.Tensor", "aten::_unsafe_index.Tensor"), trace_only=True)
def aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> TensorType:
new_indices = op.Transpose(op.NonZero(indices[0]), perm=[1, 0])
new_indices = op.Squeeze(new_indices, axes=[1])
return op.Gather(self, new_indices, axis=0)


def aten_index_add(
self: TensorType, dim: int, index: TensorType, source: TensorType, alpha: float = 1
) -> TensorType:
Expand Down
34 changes: 34 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,31 @@ def sample_inputs__fft_c2r(self, device, dtype, requires_grad=False, **_):
)


def _index_variable_bool(shape, max_indices, device):
if not isinstance(shape, tuple):
shape = (shape,)
index = (
torch.rand(*shape, dtype=torch.double, device=device).mul_(max_indices).floor_().bool()
)
return index


def sample_inputs_index_bool(op_info, device, dtype, requires_grad, **kwargs):
del op_info # Unused
del kwargs # Unused
make_arg = functools.partial(
torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
)
s = 5
index_bool = _index_variable_bool(s, s, device=device)
test_args = [
([index_bool],),
]

for args in test_args:
yield opinfo_core.SampleInput(make_arg((s, s, s, s)), args=args)


def sample_inputs_index(op_info, device, dtype, requires_grad, **kwargs):
del op_info # Unused
del kwargs # Unused
Expand Down Expand Up @@ -1933,6 +1958,15 @@ def __init__(self):
),
sample_inputs_func=sample_inputs_index,
),
opinfo_core.OpInfo(
"ops.aten.index.Tensor.bool",
aten_name="index.Tensor",
dtypes=common_dtype.all_types_and_complex_and(
torch.bool, torch.float16, torch.bfloat16, torch.chalf
),
sample_inputs_func=sample_inputs_index_bool,
op=torch.ops.aten.index.Tensor,
),
opinfo_core.OpInfo(
"ops.aten.layer_norm",
aten_name="layer_norm",
Expand Down
1 change: 1 addition & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,6 +833,7 @@ def _where_input_wrangler(
# TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB
# TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero), # no test case in OPS_DB
TorchLibOpInfo("ops.aten.index.Tensor", core_ops.aten_index, trace_only=True),
TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index_bool, trace_only=True),
TorchLibOpInfo(
"index_put_bool",
core_ops.aten_index_put_bool,
Expand Down