
Commit 4c1cda2

Add unique op (#1547)
Add support for exporting `torch.unique` following the conclusion of pytorch/pytorch#113118.

Co-authored-by: Justin Chu <[email protected]>
1 parent ddce766 commit 4c1cda2
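
For context, a hypothetical usage sketch (not part of this commit), assuming a PyTorch build where `torch.onnx.dynamo_export` is available: a module that calls `torch.unique` can now be exported, with the call lowered through the torch_lib functions added below.

    # Hypothetical sketch, not part of this commit: export a module that calls
    # torch.unique through the dynamo-based ONNX exporter, which dispatches to
    # torch_lib functions such as the ones registered in this change.
    import torch

    class UniqueModel(torch.nn.Module):
        def forward(self, x):
            return torch.unique(x, sorted=True, return_inverse=True)

    x = torch.tensor([2, 1, 2, 3, 1])
    onnx_program = torch.onnx.dynamo_export(UniqueModel(), x)
    onnx_program.save("unique.onnx")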

File tree

3 files changed, +132 -2 lines changed


onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 70 additions & 2 deletions
@@ -8591,16 +8591,84 @@ def aten_unique_consecutive(
     raise NotImplementedError()


+@torch_op("aten::_unique", trace_only=True)
+def aten__unique(
+    self: TensorType,
+    sorted: bool = True,  # pylint: disable=unused-argument
+    return_inverse: bool = False,
+) -> tuple[TensorType, TensorType]:
+    """_unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)"""
+
+    unique_values, _, inverse_indices, _ = op.Unique(self, axis=None, sorted=True)
+    input_size = op.Shape(self)
+    if return_inverse:
+        inverse_indices = op.Reshape(inverse_indices, input_size)
+    else:
+        input_numel = op.ReduceProd(input_size, keepdims=False)
+        if input_numel == 0:
+            inverse_indices = op.Reshape(inverse_indices, input_size)
+        else:
+            inverse_indices = op.ConstantOfShape([0])
+            inverse_indices = op.Cast(inverse_indices, to=INT64.dtype)
+    return unique_values, inverse_indices
+
+
+@torch_op("aten::_unique2", trace_only=True)
+def aten__unique2(
+    self: TensorType,
+    sorted: bool = True,  # pylint: disable=unused-argument
+    return_inverse: bool = False,
+    return_counts: bool = False,
+) -> tuple[TensorType, TensorType, TensorType]:
+    """_unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""
+
+    unique_values, _, inverse_indices, counts = op.Unique(self, axis=None, sorted=True)
+    input_size = op.Shape(self)
+    if return_inverse:
+        inverse_indices = op.Reshape(inverse_indices, input_size)
+    else:
+        input_numel = op.ReduceProd(input_size, keepdims=False)
+        if input_numel == 0:
+            inverse_indices = op.Reshape(inverse_indices, input_size)
+        else:
+            inverse_indices = op.ConstantOfShape([0])
+            inverse_indices = op.Cast(inverse_indices, to=INT64.dtype)
+    if not return_counts:
+        counts = op.ConstantOfShape([0])
+        counts = op.Cast(counts, to=INT64.dtype)
+    return unique_values, inverse_indices, counts
+
+
+@torch_op("aten::unique_dim", trace_only=True)
 def aten_unique_dim(
     self: TensorType,
     dim: int,
-    sorted: bool = True,
+    sorted: bool = True,  # pylint: disable=unused-argument
     return_inverse: bool = False,
     return_counts: bool = False,
 ) -> tuple[TensorType, TensorType, TensorType]:
     """unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""

-    raise NotImplementedError()
+    unique_values, _, inverse_indices, counts = op.Unique(self, axis=dim, sorted=True)
+    input_size = op.Shape(self)
+    # Normalize dim to be non-negative
+    input_ndim = op.Max(op.Size(input_size), op.Constant(value_ints=[1]))
+    dim = op.Mod(dim, input_ndim)
+    if return_inverse:
+        inverse_indices = op.Reshape(
+            inverse_indices,
+            op.Reshape(op.Slice(input_size, dim, dim + 1), op.Constant(value_ints=[-1])),
+        )
+    else:
+        inverse_indices = op.ConstantOfShape([0])
+        inverse_indices = op.Cast(inverse_indices, to=INT64.dtype)
+    if return_counts:
+        output_size = op.Shape(unique_values)
+        counts = op.Reshape(counts, op.Reshape(op.Slice(output_size, dim, dim + 1), [-1]))
+    else:
+        counts = op.ConstantOfShape([0])
+        counts = op.Cast(counts, to=INT64.dtype)
+    return unique_values, inverse_indices, counts


 def aten_unique_dim_consecutive(
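
A minimal sketch (not part of the commit) of the semantics these functions rely on: ONNX Unique with sorted=True returns the sorted unique values plus inverse indices over the flattened input, which `aten__unique` reshapes back to the input shape; NumPy's `np.unique` behaves the same way and can stand in for a quick check against `torch.unique`.

    # Minimal sketch, not part of the commit: compare torch.unique with the
    # sorted-unique semantics (as in ONNX Unique / np.unique) assumed above.
    import numpy as np
    import torch

    x = torch.tensor([2, 1, 2, 3, 1])
    values, inverse = torch.unique(x, sorted=True, return_inverse=True)

    np_values, np_inverse = np.unique(x.numpy(), return_inverse=True)
    assert np.array_equal(values.numpy(), np_values)
    assert np.array_equal(inverse.numpy(), np_inverse.reshape(x.shape))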

tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 53 additions & 0 deletions
@@ -1950,6 +1950,35 @@ def shape(size, rank, with_batch_channel=True):
     )


+def sample_inputs__unique(op_info, device, dtype, requires_grad, **kwargs):
+    for sample in common_methods_invocations.sample_inputs_unique(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        return_counts = sample.kwargs.pop("return_counts", None)
+        dim = sample.kwargs.pop("dim", None)
+        # take only those samples that do not ask for counts or a dim
+        if not return_counts and dim is None:
+            yield sample
+
+
+def sample_inputs__unique2(op_info, device, dtype, requires_grad, **kwargs):
+    for sample in common_methods_invocations.sample_inputs_unique(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        # take only those samples that do not ask for a dim
+        if sample.kwargs.pop("dim", None) is None:
+            yield sample
+
+
+def sample_inputs_unique_dim(op_info, device, dtype, requires_grad, **kwargs):
+    for sample in common_methods_invocations.sample_inputs_unique(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        # take only those samples that ask for a dim
+        if sample.kwargs.get("dim") is not None:
+            yield sample
+
+
 def sample_inputs_upsample_trilinear3d_vec(op_info, device, dtype, requires_grad, **kwargs):
     del op_info
     del kwargs
@@ -2504,6 +2533,30 @@ def __init__(self):
         sample_inputs_func=sample_inputs_unfold,
         supports_out=False,
     ),
+    opinfo_core.OpInfo(
+        "ops.aten._unique.default",
+        aten_name="_unique.default",
+        dtypes=common_dtype.floating_types_and(torch.float16, torch.int64, torch.int8),
+        sample_inputs_func=sample_inputs__unique,
+        supports_out=False,
+        supports_autograd=False,
+    ),
+    opinfo_core.OpInfo(
+        "ops.aten._unique2.default",
+        aten_name="_unique2.default",
+        dtypes=common_dtype.floating_types_and(torch.float16, torch.int64, torch.int8),
+        sample_inputs_func=sample_inputs__unique2,
+        supports_out=False,
+        supports_autograd=False,
+    ),
+    opinfo_core.OpInfo(
+        "ops.aten.unique_dim.default",
+        aten_name="unique_dim.default",
+        dtypes=common_dtype.floating_types_and(torch.float16, torch.int64, torch.int8),
+        sample_inputs_func=sample_inputs_unique_dim,
+        supports_out=False,
+        supports_autograd=False,
+    ),
     opinfo_core.OpInfo(
         "ops.aten.upsample_bicubic2d.default",
         aten_name="upsample_bicubic2d",
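
An illustrative sketch (not from the commit) of the filtering pattern above: `_unique` and `_unique2` keep only samples without a `dim` kwarg (those ops flatten their input), while `unique_dim` keeps only samples that specify one. The mock samples below are hypothetical.

    # Illustrative sketch, not from the commit: the same keep/drop logic applied
    # to a few hypothetical mock samples.
    from types import SimpleNamespace

    samples = [
        SimpleNamespace(kwargs={"sorted": True}),
        SimpleNamespace(kwargs={"sorted": True, "return_inverse": True}),
        SimpleNamespace(kwargs={"sorted": True, "dim": 0}),
    ]

    without_dim = [s for s in samples if s.kwargs.get("dim") is None]
    with_dim = [s for s in samples if s.kwargs.get("dim") is not None]
    print(len(without_dim), len(with_dim))  # -> 2 1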

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 9 additions & 0 deletions
@@ -2068,6 +2068,15 @@ def _where_input_wrangler(
     ),  # Custom from extra_opinfo
     TorchLibOpInfo("transpose", core_ops.aten_transpose),
     TorchLibOpInfo("transpose", core_ops.aten_transpose_complex, complex=True),
+    TorchLibOpInfo("ops.aten._unique.default", core_ops.aten__unique),
+    TorchLibOpInfo("ops.aten._unique2.default", core_ops.aten__unique2),
+    TorchLibOpInfo("ops.aten.unique_dim.default", core_ops.aten_unique_dim).skip(
+        device_type="cpu",
+        reason=(
+            "ops.aten.unique_dim.default returns different shapes for optional outputs on CPU/CUDA. "
+            "Our implementation is based on that for CUDA"
+        ),
+    ),
     TorchLibOpInfo(
         "ops.prims.var.default", prims_ops.prims_var, tolerance={torch.float16: (1e-3, 5e-2)}
     ),
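
A hedged sketch (not part of the commit) for inspecting the discrepancy the skip refers to: call the ATen op directly and look at the shapes of the optional outputs on the current device; the shapes are device dependent, which is why the CPU run is skipped.

    # Hedged sketch, not part of the commit: inspect what aten::unique_dim
    # returns for the optional outputs when they are not requested.
    import torch

    x = torch.tensor([[1.0, 2.0], [1.0, 2.0]])
    # Positional args: self, dim, sorted, return_inverse, return_counts
    values, inverse, counts = torch.ops.aten.unique_dim(x, 0, True, False, False)
    # The shapes of `inverse` and `counts` differ between CPU and CUDA.
    print(values.shape, inverse.shape, counts.shape)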
