diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 7787f66da2..be8dd62d09 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4035,6 +4035,13 @@ def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorTy return op.Transpose(self, perm=perm) +@torch_op(("aten::index.Tensor", "aten::_unsafe_index.Tensor"), trace_only=True) +def aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> TensorType: + new_indices = op.Transpose(op.NonZero(indices[0]), perm=[1, 0]) + new_indices = op.Squeeze(new_indices, axes=[1]) + return op.Gather(self, new_indices, axis=0) + + def aten_index_add( self: TensorType, dim: int, index: TensorType, source: TensorType, alpha: float = 1 ) -> TensorType: diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py index 086264e9bf..bfe9f0f0eb 100644 --- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py +++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py @@ -692,6 +692,31 @@ def sample_inputs__fft_c2r(self, device, dtype, requires_grad=False, **_): ) +def _index_variable_bool(shape, max_indices, device): + if not isinstance(shape, tuple): + shape = (shape,) + index = ( + torch.rand(*shape, dtype=torch.double, device=device).mul_(max_indices).floor_().bool() + ) + return index + + +def sample_inputs_index_bool(op_info, device, dtype, requires_grad, **kwargs): + del op_info # Unused + del kwargs # Unused + make_arg = functools.partial( + torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=requires_grad + ) + s = 5 + index_bool = _index_variable_bool(s, s, device=device) + test_args = [ + ([index_bool],), + ] + + for args in test_args: + yield opinfo_core.SampleInput(make_arg((s, s, s, s)), args=args) + + def sample_inputs_index(op_info, device, dtype, requires_grad, **kwargs): del op_info # Unused del kwargs # Unused @@ -1933,6 +1958,15 @@ def __init__(self): ), sample_inputs_func=sample_inputs_index, ), + opinfo_core.OpInfo( + "ops.aten.index.Tensor.bool", + aten_name="index.Tensor", + dtypes=common_dtype.all_types_and_complex_and( + torch.bool, torch.float16, torch.bfloat16, torch.chalf + ), + sample_inputs_func=sample_inputs_index_bool, + op=torch.ops.aten.index.Tensor, + ), opinfo_core.OpInfo( "ops.aten.layer_norm", aten_name="layer_norm", diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index ecf4a606d0..ddca9273e2 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -833,6 +833,7 @@ def _where_input_wrangler( # TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB # TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero), # no test case in OPS_DB TorchLibOpInfo("ops.aten.index.Tensor", core_ops.aten_index, trace_only=True), + TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index_bool, trace_only=True), TorchLibOpInfo( "index_put_bool", core_ops.aten_index_put_bool,