Skip to content

Commit b8b4cb1

Browse files
committed
Remove hacks
1 parent 176cb57 commit b8b4cb1

File tree

1 file changed

+4
-74
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+4
-74
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 4 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -8380,19 +8380,8 @@ def aten__unique(
83808380
) -> tuple[TensorType, TensorType]:
83818381
"""_unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)"""
83828382

8383-
unique_values, indices, inverse_indices, _ = op.Unique(self, axis=None, sorted=True)
8384-
# HACK: force indices to be in the graph so that it gets a name during optimization
8385-
# Otherwise an error will be raised in `onnxscript.Scope.lookup_or_create`
8386-
# We don't need to worry about unique_values since it is a required output.
8387-
indices_size = op.Shape(indices)
8388-
indices_numel = op.ReduceProd(indices_size, keepdims=False)
8383+
unique_values, _, inverse_indices, _ = op.Unique(self, axis=None, sorted=True)
83898384
input_size = op.Shape(self)
8390-
# force inverse_indices to depend on indices through input_size
8391-
if indices_numel != 0:
8392-
input_size = input_size * indices_numel
8393-
input_size = input_size / indices_numel
8394-
else:
8395-
input_size = input_size + indices_numel
83968385
if return_inverse:
83978386
inverse_indices = op.Reshape(inverse_indices, input_size)
83988387
else:
@@ -8413,24 +8402,8 @@ def aten__unique2(
84138402
) -> tuple[TensorType, TensorType, TensorType]:
84148403
"""_unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""
84158404

8416-
unique_values, indices, inverse_indices, counts = op.Unique(self, axis=None, sorted=True)
8417-
# HACK: force indices and inverse_indices to be in the graph so
8418-
# that they get names during optimization.
8419-
# counts must depend on indices and inverse_indices,
8420-
# and inverse_indices must depend on indices
8421-
# Otherwise an error will be raised in `onnxscript.Scope.lookup_or_create`
8422-
# We don't have to worry about unique_values because it is a required output.
8423-
indices_size = op.Shape(indices)
8424-
indices_numel = op.ReduceProd(indices_size, keepdims=False)
8425-
inverse_indices_size = op.Shape(inverse_indices)
8426-
inverse_indices_numel = op.ReduceProd(inverse_indices_size, keepdims=False)
8405+
unique_values, _, inverse_indices, counts = op.Unique(self, axis=None, sorted=True)
84278406
input_size = op.Shape(self)
8428-
# force inverse_indices to depend on indices through input_size
8429-
if indices_numel != 0:
8430-
input_size = input_size * indices_numel
8431-
input_size = input_size / indices_numel
8432-
else:
8433-
input_size = input_size + indices_numel
84348407
if return_inverse:
84358408
inverse_indices = op.Reshape(inverse_indices, input_size)
84368409
else:
@@ -8439,21 +8412,8 @@ def aten__unique2(
84398412
inverse_indices = op.Reshape(inverse_indices, input_size)
84408413
else:
84418414
inverse_indices = op.ConstantOfShape([0], value=[0])
8442-
if return_counts:
8443-
# force counts to depend on inverse_indices through indices_size
8444-
if inverse_indices_numel != 0:
8445-
indices_size = indices_size * inverse_indices_numel
8446-
indices_size = indices_size / inverse_indices_numel
8447-
else:
8448-
indices_size = indices_size + inverse_indices_numel
8449-
# force counts to depend on indices
8450-
counts = op.Reshape(counts, indices_size)
8451-
else:
8415+
if not return_counts:
84528416
counts = op.ConstantOfShape([0], value=[0])
8453-
# force counts to depend on indices
8454-
counts = counts * indices_numel
8455-
# force counts to depend on inverse_indices
8456-
counts = counts * inverse_indices_numel
84578417
return unique_values, inverse_indices, counts
84588418

84598419

@@ -8467,47 +8427,17 @@ def aten_unique_dim(
84678427
) -> tuple[TensorType, TensorType, TensorType]:
84688428
"""unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""
84698429

8470-
unique_values, indices, inverse_indices, counts = op.Unique(self, axis=dim, sorted=True)
8471-
# HACK: force indices and inverse_indices to be in the graph so
8472-
# that they get names during optimization.
8473-
# counts must depend on indices and inverse_indices,
8474-
# and inverse_indices must depend on indices
8475-
# Otherwise an error will be raised in `onnxscript.Scope.lookup_or_create`
8476-
# We don't have to worry about unique_values because it is a required output.
8477-
indices_size = op.Shape(indices)
8478-
indices_numel = op.ReduceProd(indices_size, keepdims=False)
8479-
inverse_indices_size = op.Shape(inverse_indices)
8480-
inverse_indices_numel = op.ReduceProd(inverse_indices_size, keepdims=False)
8430+
unique_values, _, inverse_indices, counts = op.Unique(self, axis=dim, sorted=True)
84818431
if return_inverse:
84828432
input_size = op.Shape(self)
8483-
# force inverse_indices to depend on indices through input_size
8484-
if indices_numel != 0:
8485-
input_size = input_size * indices_numel
8486-
input_size = input_size / indices_numel
8487-
else:
8488-
input_size = input_size + indices_numel
84898433
inverse_indices = op.Reshape(inverse_indices, op.Reshape(input_size[dim], [-1]))
84908434
else:
84918435
inverse_indices = op.ConstantOfShape([0], value=[0])
8492-
# force inverse_indices to depend on indices
8493-
inverse_indices = inverse_indices * indices_numel
84948436
if return_counts:
8495-
# force dependence on inverse_indices through indices_size
8496-
if inverse_indices_numel != 0:
8497-
indices_size = indices_size * inverse_indices_numel
8498-
indices_size = indices_size / inverse_indices_numel
8499-
else:
8500-
indices_size = indices_size + inverse_indices_numel
8501-
# force dependence on indices
8502-
counts = op.Reshape(counts, indices_size)
85038437
output_size = op.Shape(unique_values)
85048438
counts = op.Reshape(counts, op.Reshape(output_size[dim], [-1]))
85058439
else:
85068440
counts = op.ConstantOfShape([0], value=[0])
8507-
# force dependence on indices
8508-
counts = counts * indices_numel
8509-
# force dependence on inverse_indices
8510-
counts = counts * inverse_indices_numel
85118441
return unique_values, inverse_indices, counts
85128442

85138443

0 commit comments

Comments
 (0)