Skip to content

Commit aa93917

Browse files
authored
feat(atenlib): op(trunc) (#515)
aten::trunc ``` < domain: "onnxscript.atenlib", opset_import: ["" : 18] > aten_trunc (self) => (return_val) { tmp = Abs (self) integer_parts = Floor (tmp) tmp_0 = Constant <value = float tmp_0 {0}> () tmp_0_cast = CastLike (tmp_0, self) is_negative = Less (self, tmp_0_cast) tmp_1 = Neg (integer_parts) return_val = Where (is_negative, tmp_1, integer_parts) } ```
1 parent 510db96 commit aa93917

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5573,10 +5573,14 @@ def aten_true_divide(self: TensorType, other: TensorType) -> TensorType:
55735573
raise NotImplementedError()
55745574

55755575

5576-
def aten_trunc(self: TensorType) -> TensorType:
5576+
@torch_op("aten::trunc")
5577+
def aten_trunc(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
55775578
"""trunc(Tensor self) -> Tensor"""
55785579

5579-
raise NotImplementedError()
5580+
# Reference https://github.com/onnx/onnx/issues/4588#issuecomment-1463970126
5581+
integer_parts = op.Floor(op.Abs(self))
5582+
is_negative = op.Less(self, 0.0)
5583+
return op.Where(is_negative, op.Neg(integer_parts), integer_parts)
55805584

55815585

55825586
def aten_type_as(self: TensorType, other: TensorType) -> TensorType:

onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,7 @@ def _where_input_wrangler(
389389
"tan": core_ops.aten_tan,
390390
"tanh": core_ops.aten_tanh,
391391
"topk": core_ops.aten_topk,
392+
"trunc": core_ops.aten_trunc,
392393
"unsqueeze": core_ops.aten_unsqueeze,
393394
"view": core_ops.aten_view,
394395
"where": (core_ops.aten_where, _where_input_wrangler),

0 commit comments

Comments
 (0)