Skip to content

Commit 9e88ffd

Browse files
authored
[torchlib] Simplify aten_trunc implementation
Simplify aten_trunc implementation according to onnx/onnx#4588 (comment)
1 parent dfee02e commit 9e88ffd

File tree

1 file changed

+2
-5
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+2
-5
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8385,11 +8385,8 @@ def aten_triu_indices(row: int, col: int, offset: int = 0) -> TensorType:
83858385
@torch_op("aten::trunc")
83868386
def aten_trunc(self: TFloat) -> TFloat:
83878387
"""trunc(Tensor self) -> Tensor"""
8388-
8389-
# Reference https://github.com/onnx/onnx/issues/4588#issuecomment-1463970126
8390-
integer_parts = op.Floor(op.Abs(self))
8391-
is_negative = op.Less(self, 0.0)
8392-
return op.Where(is_negative, op.Neg(integer_parts), integer_parts)
8388+
# Reference https://github.com/onnx/onnx/issues/4588#issuecomment-2658170591
8389+
return op.Floor(op.Abs(self)) * op.Sign(self)
83938390

83948391

83958392
@torch_op("aten::type_as", trace_only=True)

0 commit comments

Comments
 (0)