Add unique op #1547

Merged
merged 18 commits on Mar 7, 2025
72 changes: 70 additions & 2 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -8591,16 +8591,84 @@
    raise NotImplementedError()


@torch_op("aten::_unique", trace_only=True)
def aten__unique(
self: TensorType,
sorted: bool = True, # pylint: disable=unused-argument
return_inverse: bool = False,
) -> tuple[TensorType, TensorType]:
"""_unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)"""

unique_values, _, inverse_indices, _ = op.Unique(self, axis=None, sorted=True)
input_size = op.Shape(self)
if return_inverse:
inverse_indices = op.Reshape(inverse_indices, input_size)
else:
input_numel = op.ReduceProd(input_size, keepdims=False)
if input_numel == 0:
inverse_indices = op.Reshape(inverse_indices, input_size)

        else:
            inverse_indices = op.ConstantOfShape([0])
            inverse_indices = op.Cast(inverse_indices, to=INT64.dtype)
    return unique_values, inverse_indices


@torch_op("aten::_unique2", trace_only=True)
def aten__unique2(
self: TensorType,
sorted: bool = True, # pylint: disable=unused-argument
return_inverse: bool = False,
return_counts: bool = False,
) -> tuple[TensorType, TensorType, TensorType]:
"""_unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""

unique_values, _, inverse_indices, counts = op.Unique(self, axis=None, sorted=True)
input_size = op.Shape(self)
if return_inverse:
inverse_indices = op.Reshape(inverse_indices, input_size)
else:
input_numel = op.ReduceProd(input_size, keepdims=False)
if input_numel == 0:
inverse_indices = op.Reshape(inverse_indices, input_size)

        else:
            inverse_indices = op.ConstantOfShape([0])
            inverse_indices = op.Cast(inverse_indices, to=INT64.dtype)
    if not return_counts:
        counts = op.ConstantOfShape([0])
        counts = op.Cast(counts, to=INT64.dtype)
    return unique_values, inverse_indices, counts
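
A quick eager-mode reference for the two functions above (a hedged sketch, not part of this diff; it only assumes a local PyTorch install): the implementations always call ONNX Unique with sorted=True and then reshape or stub out the optional outputs depending on the flags. The raw ATen overload can be called positionally following its schema.

import torch

x = torch.tensor([2, 0, 2, 1, 0])
# Positional args follow the schema: (self, sorted, return_inverse, return_counts).
values, inverse, counts = torch.ops.aten._unique2(x, True, True, True)
print(values)   # tensor([0, 1, 2])
print(inverse)  # tensor([2, 0, 2, 1, 0]), same shape as the input
print(counts)   # tensor([2, 1, 2])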


@torch_op("aten::unique_dim", trace_only=True)
def aten_unique_dim(
self: TensorType,
dim: int,
sorted: bool = True,
sorted: bool = True, # pylint: disable=unused-argument
return_inverse: bool = False,
return_counts: bool = False,
) -> tuple[TensorType, TensorType, TensorType]:
"""unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""

raise NotImplementedError()
unique_values, _, inverse_indices, counts = op.Unique(self, axis=dim, sorted=True)
input_size = op.Shape(self)

    # Normalize dim to be non-negative
    input_ndim = op.Max(op.Size(input_size), op.Constant(value_ints=[1]))
    dim = op.Mod(dim, input_ndim)
    if return_inverse:
        inverse_indices = op.Reshape(
            inverse_indices,
            op.Reshape(op.Slice(input_size, dim, dim + 1), op.Constant(value_ints=[-1])),
        )
    else:
        inverse_indices = op.ConstantOfShape([0])
        inverse_indices = op.Cast(inverse_indices, to=INT64.dtype)
    if return_counts:
        output_size = op.Shape(unique_values)
        counts = op.Reshape(counts, op.Reshape(op.Slice(output_size, dim, dim + 1), [-1]))
    else:
        counts = op.ConstantOfShape([0])
        counts = op.Cast(counts, to=INT64.dtype)
    return unique_values, inverse_indices, counts
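
For the dim-wise variant above, a hedged eager-mode reference (again not part of this diff, assuming PyTorch is installed): with a dim argument, unique deduplicates whole slices along that dimension, the inverse indices have length input_size[dim], and the counts have length output_size[dim], which is what the Slice/Reshape calls above extract.

import torch

x = torch.tensor([[1, 2], [1, 2], [3, 4]])
values, inverse, counts = torch.unique(x, dim=0, return_inverse=True, return_counts=True)
print(values)   # tensor([[1, 2], [3, 4]])
print(inverse)  # tensor([0, 0, 1]), length == x.shape[0]
print(counts)   # tensor([2, 1]), length == values.shape[0]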


def aten_unique_dim_consecutive(
53 changes: 53 additions & 0 deletions tests/function_libs/torch_lib/extra_opinfo.py
@@ -1950,6 +1950,35 @@ def shape(size, rank, with_batch_channel=True):
)


def sample_inputs__unique(op_info, device, dtype, requires_grad, **kwargs):
    for sample in common_methods_invocations.sample_inputs_unique(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        return_counts = sample.kwargs.pop("return_counts", None)
        dim = sample.kwargs.pop("dim", None)
        # take only those samples that do not ask for counts or a dim
        if not return_counts and dim is None:
            yield sample


def sample_inputs__unique2(op_info, device, dtype, requires_grad, **kwargs):
    for sample in common_methods_invocations.sample_inputs_unique(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        # take only those samples that do not ask for a dim
        if sample.kwargs.pop("dim", None) is None:
            yield sample


def sample_inputs_unique_dim(op_info, device, dtype, requires_grad, **kwargs):
    for sample in common_methods_invocations.sample_inputs_unique(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        # take only those samples that ask for a dim
        if sample.kwargs.get("dim") is not None:
            yield sample


def sample_inputs_upsample_trilinear3d_vec(op_info, device, dtype, requires_grad, **kwargs):
    del op_info
    del kwargs
@@ -2504,6 +2533,30 @@ def __init__(self):
        sample_inputs_func=sample_inputs_unfold,
        supports_out=False,
    ),
    opinfo_core.OpInfo(
        "ops.aten._unique.default",
        aten_name="_unique.default",
        dtypes=common_dtype.floating_types_and(torch.float16, torch.int64, torch.int8),
        sample_inputs_func=sample_inputs__unique,
        supports_out=False,
        supports_autograd=False,
    ),
    opinfo_core.OpInfo(
        "ops.aten._unique2.default",
        aten_name="_unique2.default",
        dtypes=common_dtype.floating_types_and(torch.float16, torch.int64, torch.int8),
        sample_inputs_func=sample_inputs__unique2,
        supports_out=False,
        supports_autograd=False,
    ),
    opinfo_core.OpInfo(
        "ops.aten.unique_dim.default",
        aten_name="unique_dim.default",
        dtypes=common_dtype.floating_types_and(torch.float16, torch.int64, torch.int8),
        sample_inputs_func=sample_inputs_unique_dim,
        supports_out=False,
        supports_autograd=False,
    ),
    opinfo_core.OpInfo(
        "ops.aten.upsample_bicubic2d.default",
        aten_name="upsample_bicubic2d",
9 changes: 9 additions & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
@@ -2068,6 +2068,15 @@ def _where_input_wrangler(
    ),  # Custom from extra_opinfo
    TorchLibOpInfo("transpose", core_ops.aten_transpose),
    TorchLibOpInfo("transpose", core_ops.aten_transpose_complex, complex=True),
TorchLibOpInfo("ops.aten._unique.default", core_ops.aten__unique),
TorchLibOpInfo("ops.aten._unique2.default", core_ops.aten__unique2),
TorchLibOpInfo("ops.aten.unique_dim.default", core_ops.aten_unique_dim).skip(
device_type="cpu",
reason=(
"ops.aten.unique_dim.default returns different shapes for optional outputs on CPU/CUDA. "
"Our implementation is based on that for CUDA"
),
),
    TorchLibOpInfo(
        "ops.prims.var.default", prims_ops.prims_var, tolerance={torch.float16: (1e-3, 5e-2)}
    ),
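
The CPU skip recorded above for ops.aten.unique_dim.default can be checked with a small probe script (a hypothetical illustration, not part of the PR; it assumes a CUDA device is available): it calls the raw ATen overload on both devices and prints the shapes of the optional inverse/counts outputs, which is where the CPU and CUDA kernels are reported to differ.

import torch

x = torch.tensor([[1, 1], [2, 1], [1, 1]])
for device in ("cpu", "cuda"):
    # Schema: unique_dim(Tensor self, int dim, bool sorted=True,
    #                    bool return_inverse=False, bool return_counts=False)
    values, inverse, counts = torch.ops.aten.unique_dim(x.to(device), 0, True, False, False)
    print(device, values.shape, inverse.shape, counts.shape)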