@@ -8591,16 +8591,84 @@ def aten_unique_consecutive(
8591
8591
raise NotImplementedError ()
8592
8592
8593
8593
8594
+ @torch_op ("aten::_unique" , trace_only = True )
8595
+ def aten__unique (
8596
+ self : TensorType ,
8597
+ sorted : bool = True , # pylint: disable=unused-argument
8598
+ return_inverse : bool = False ,
8599
+ ) -> tuple [TensorType , TensorType ]:
8600
+ """_unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)"""
8601
+
8602
+ unique_values , _ , inverse_indices , _ = op .Unique (self , axis = None , sorted = True )
8603
+ input_size = op .Shape (self )
8604
+ if return_inverse :
8605
+ inverse_indices = op .Reshape (inverse_indices , input_size )
8606
+ else :
8607
+ input_numel = op .ReduceProd (input_size , keepdims = False )
8608
+ if input_numel == 0 :
8609
+ inverse_indices = op .Reshape (inverse_indices , input_size )
8610
+ else :
8611
+ inverse_indices = op .ConstantOfShape ([0 ])
8612
+ inverse_indices = op .Cast (inverse_indices , to = INT64 .dtype )
8613
+ return unique_values , inverse_indices
8614
+
8615
+
8616
+ @torch_op ("aten::_unique2" , trace_only = True )
8617
+ def aten__unique2 (
8618
+ self : TensorType ,
8619
+ sorted : bool = True , # pylint: disable=unused-argument
8620
+ return_inverse : bool = False ,
8621
+ return_counts : bool = False ,
8622
+ ) -> tuple [TensorType , TensorType , TensorType ]:
8623
+ """_unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""
8624
+
8625
+ unique_values , _ , inverse_indices , counts = op .Unique (self , axis = None , sorted = True )
8626
+ input_size = op .Shape (self )
8627
+ if return_inverse :
8628
+ inverse_indices = op .Reshape (inverse_indices , input_size )
8629
+ else :
8630
+ input_numel = op .ReduceProd (input_size , keepdims = False )
8631
+ if input_numel == 0 :
8632
+ inverse_indices = op .Reshape (inverse_indices , input_size )
8633
+ else :
8634
+ inverse_indices = op .ConstantOfShape ([0 ])
8635
+ inverse_indices = op .Cast (inverse_indices , to = INT64 .dtype )
8636
+ if not return_counts :
8637
+ counts = op .ConstantOfShape ([0 ])
8638
+ counts = op .Cast (counts , to = INT64 .dtype )
8639
+ return unique_values , inverse_indices , counts
8640
+
8641
+
8642
+ @torch_op ("aten::unique_dim" , trace_only = True )
8594
8643
def aten_unique_dim (
8595
8644
self : TensorType ,
8596
8645
dim : int ,
8597
- sorted : bool = True ,
8646
+ sorted : bool = True , # pylint: disable=unused-argument
8598
8647
return_inverse : bool = False ,
8599
8648
return_counts : bool = False ,
8600
8649
) -> tuple [TensorType , TensorType , TensorType ]:
8601
8650
"""unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""
8602
8651
8603
- raise NotImplementedError ()
8652
+ unique_values , _ , inverse_indices , counts = op .Unique (self , axis = dim , sorted = True )
8653
+ input_size = op .Shape (self )
8654
+ # Normalize dim to be non-negative
8655
+ input_ndim = op .Max (op .Size (input_size ), op .Constant (value_ints = [1 ]))
8656
+ dim = op .Mod (dim , input_ndim )
8657
+ if return_inverse :
8658
+ inverse_indices = op .Reshape (
8659
+ inverse_indices ,
8660
+ op .Reshape (op .Slice (input_size , dim , dim + 1 ), op .Constant (value_ints = [- 1 ])),
8661
+ )
8662
+ else :
8663
+ inverse_indices = op .ConstantOfShape ([0 ])
8664
+ inverse_indices = op .Cast (inverse_indices , to = INT64 .dtype )
8665
+ if return_counts :
8666
+ output_size = op .Shape (unique_values )
8667
+ counts = op .Reshape (counts , op .Reshape (op .Slice (output_size , dim , dim + 1 ), [- 1 ]))
8668
+ else :
8669
+ counts = op .ConstantOfShape ([0 ])
8670
+ counts = op .Cast (counts , to = INT64 .dtype )
8671
+ return unique_values , inverse_indices , counts
8604
8672
8605
8673
8606
8674
def aten_unique_dim_consecutive (
0 commit comments