Skip to content

Commit ad6faf2

Browse files
authored
Revert "Add support for aten:index op when index is boolean | feat(torchlib)" (#1307)
Reverting because this causes the dispatcher in PyTorch to choose the wrong overload. Reverts #1285
1 parent 9b1f2c6 commit ad6faf2

File tree

3 files changed

+0
-42
lines changed

3 files changed

+0
-42
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4035,13 +4035,6 @@ def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorTy
40354035
return op.Transpose(self, perm=perm)
40364036

40374037

4038-
@torch_op(("aten::index.Tensor", "aten::_unsafe_index.Tensor"), trace_only=True)
4039-
def aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> TensorType:
4040-
new_indices = op.Transpose(op.NonZero(indices[0]), perm=[1, 0])
4041-
new_indices = op.Squeeze(new_indices, axes=[1])
4042-
return op.Gather(self, new_indices, axis=0)
4043-
4044-
40454038
def aten_index_add(
40464039
self: TensorType, dim: int, index: TensorType, source: TensorType, alpha: float = 1
40474040
) -> TensorType:

onnxscript/tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -692,31 +692,6 @@ def sample_inputs__fft_c2r(self, device, dtype, requires_grad=False, **_):
692692
)
693693

694694

695-
def _index_variable_bool(shape, max_indices, device):
696-
if not isinstance(shape, tuple):
697-
shape = (shape,)
698-
index = (
699-
torch.rand(*shape, dtype=torch.double, device=device).mul_(max_indices).floor_().bool()
700-
)
701-
return index
702-
703-
704-
def sample_inputs_index_bool(op_info, device, dtype, requires_grad, **kwargs):
705-
del op_info # Unused
706-
del kwargs # Unused
707-
make_arg = functools.partial(
708-
torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
709-
)
710-
s = 5
711-
index_bool = _index_variable_bool(s, s, device=device)
712-
test_args = [
713-
([index_bool],),
714-
]
715-
716-
for args in test_args:
717-
yield opinfo_core.SampleInput(make_arg((s, s, s, s)), args=args)
718-
719-
720695
def sample_inputs_index(op_info, device, dtype, requires_grad, **kwargs):
721696
del op_info # Unused
722697
del kwargs # Unused
@@ -1961,15 +1936,6 @@ def __init__(self):
19611936
),
19621937
sample_inputs_func=sample_inputs_index,
19631938
),
1964-
opinfo_core.OpInfo(
1965-
"ops.aten.index.Tensor.bool",
1966-
aten_name="index.Tensor",
1967-
dtypes=common_dtype.all_types_and_complex_and(
1968-
torch.bool, torch.float16, torch.bfloat16, torch.chalf
1969-
),
1970-
sample_inputs_func=sample_inputs_index_bool,
1971-
op=torch.ops.aten.index.Tensor,
1972-
),
19731939
opinfo_core.OpInfo(
19741940
"ops.aten.layer_norm",
19751941
aten_name="layer_norm",

onnxscript/tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -848,7 +848,6 @@ def _where_input_wrangler(
848848
# TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB
849849
# TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero), # no test case in OPS_DB
850850
TorchLibOpInfo("ops.aten.index.Tensor", core_ops.aten_index, trace_only=True),
851-
TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index_bool, trace_only=True),
852851
TorchLibOpInfo(
853852
"index_put_bool",
854853
core_ops.aten_index_put_bool,

0 commit comments

Comments
 (0)