@@ -8380,19 +8380,8 @@ def aten__unique(
8380
8380
) -> tuple [TensorType , TensorType ]:
8381
8381
"""_unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)"""
8382
8382
8383
- unique_values , indices , inverse_indices , _ = op .Unique (self , axis = None , sorted = True )
8384
- # HACK: force indices to be in the graph so that it gets a name during optimization
8385
- # Otherwise an error will be raised in `onnxscript.Scope.lookup_or_create`
8386
- # We don't need to worry about unique_values since it is a required output.
8387
- indices_size = op .Shape (indices )
8388
- indices_numel = op .ReduceProd (indices_size , keepdims = False )
8383
+ unique_values , _ , inverse_indices , _ = op .Unique (self , axis = None , sorted = True )
8389
8384
input_size = op .Shape (self )
8390
- # force inverse_indices to depend on indices through input_size
8391
- if indices_numel != 0 :
8392
- input_size = input_size * indices_numel
8393
- input_size = input_size / indices_numel
8394
- else :
8395
- input_size = input_size + indices_numel
8396
8385
if return_inverse :
8397
8386
inverse_indices = op .Reshape (inverse_indices , input_size )
8398
8387
else :
@@ -8413,24 +8402,8 @@ def aten__unique2(
8413
8402
) -> tuple [TensorType , TensorType , TensorType ]:
8414
8403
"""_unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""
8415
8404
8416
- unique_values , indices , inverse_indices , counts = op .Unique (self , axis = None , sorted = True )
8417
- # HACK: force indices and inverse_indices to be in the graph so
8418
- # that they get names during optimization.
8419
- # counts must depend on indices and inverse_indices,
8420
- # and inverse_indices must depend on indices
8421
- # Otherwise an error will be raised in `onnxscript.Scope.lookup_or_create`
8422
- # We don't have to worry about unique_values because it is a required output.
8423
- indices_size = op .Shape (indices )
8424
- indices_numel = op .ReduceProd (indices_size , keepdims = False )
8425
- inverse_indices_size = op .Shape (inverse_indices )
8426
- inverse_indices_numel = op .ReduceProd (inverse_indices_size , keepdims = False )
8405
+ unique_values , _ , inverse_indices , counts = op .Unique (self , axis = None , sorted = True )
8427
8406
input_size = op .Shape (self )
8428
- # force inverse_indices to depend on indices through input_size
8429
- if indices_numel != 0 :
8430
- input_size = input_size * indices_numel
8431
- input_size = input_size / indices_numel
8432
- else :
8433
- input_size = input_size + indices_numel
8434
8407
if return_inverse :
8435
8408
inverse_indices = op .Reshape (inverse_indices , input_size )
8436
8409
else :
@@ -8439,21 +8412,8 @@ def aten__unique2(
8439
8412
inverse_indices = op .Reshape (inverse_indices , input_size )
8440
8413
else :
8441
8414
inverse_indices = op .ConstantOfShape ([0 ], value = [0 ])
8442
- if return_counts :
8443
- # force counts to depend on inverse_indices through indices_size
8444
- if inverse_indices_numel != 0 :
8445
- indices_size = indices_size * inverse_indices_numel
8446
- indices_size = indices_size / inverse_indices_numel
8447
- else :
8448
- indices_size = indices_size + inverse_indices_numel
8449
- # force counts to depend on indices
8450
- counts = op .Reshape (counts , indices_size )
8451
- else :
8415
+ if not return_counts :
8452
8416
counts = op .ConstantOfShape ([0 ], value = [0 ])
8453
- # force counts to depend on indices
8454
- counts = counts * indices_numel
8455
- # force counts to depend on inverse_indices
8456
- counts = counts * inverse_indices_numel
8457
8417
return unique_values , inverse_indices , counts
8458
8418
8459
8419
@@ -8467,47 +8427,17 @@ def aten_unique_dim(
8467
8427
) -> tuple [TensorType , TensorType , TensorType ]:
8468
8428
"""unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""
8469
8429
8470
- unique_values , indices , inverse_indices , counts = op .Unique (self , axis = dim , sorted = True )
8471
- # HACK: force indices and inverse_indices to be in the graph so
8472
- # that they get names during optimization.
8473
- # counts must depend on indices and inverse_indices,
8474
- # and inverse_indices must depend on indices
8475
- # Otherwise an error will be raised in `onnxscript.Scope.lookup_or_create`
8476
- # We don't have to worry about unique_values because it is a required output.
8477
- indices_size = op .Shape (indices )
8478
- indices_numel = op .ReduceProd (indices_size , keepdims = False )
8479
- inverse_indices_size = op .Shape (inverse_indices )
8480
- inverse_indices_numel = op .ReduceProd (inverse_indices_size , keepdims = False )
8430
+ unique_values , _ , inverse_indices , counts = op .Unique (self , axis = dim , sorted = True )
8481
8431
if return_inverse :
8482
8432
input_size = op .Shape (self )
8483
- # force inverse_indices to depend on indices through input_size
8484
- if indices_numel != 0 :
8485
- input_size = input_size * indices_numel
8486
- input_size = input_size / indices_numel
8487
- else :
8488
- input_size = input_size + indices_numel
8489
8433
inverse_indices = op .Reshape (inverse_indices , op .Reshape (input_size [dim ], [- 1 ]))
8490
8434
else :
8491
8435
inverse_indices = op .ConstantOfShape ([0 ], value = [0 ])
8492
- # force inverse_indices to depend on indices
8493
- inverse_indices = inverse_indices * indices_numel
8494
8436
if return_counts :
8495
- # force dependence on inverse_indices through indices_size
8496
- if inverse_indices_numel != 0 :
8497
- indices_size = indices_size * inverse_indices_numel
8498
- indices_size = indices_size / inverse_indices_numel
8499
- else :
8500
- indices_size = indices_size + inverse_indices_numel
8501
- # force dependence on indices
8502
- counts = op .Reshape (counts , indices_size )
8503
8437
output_size = op .Shape (unique_values )
8504
8438
counts = op .Reshape (counts , op .Reshape (output_size [dim ], [- 1 ]))
8505
8439
else :
8506
8440
counts = op .ConstantOfShape ([0 ], value = [0 ])
8507
- # force dependence on indices
8508
- counts = counts * indices_numel
8509
- # force dependence on inverse_indices
8510
- counts = counts * inverse_indices_numel
8511
8441
return unique_values , inverse_indices , counts
8512
8442
8513
8443
0 commit comments