diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 36f2a70f8c..eb276a239c 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8533,6 +8533,14 @@ def aten_trunc(self: TFloat) -> TFloat: return op.Floor(op.Abs(self)) * op.Sign(self) +@torch_op("math::trunc", trace_only=True) +def python_math_trunc(self: TFloat) -> TInt: + """trunc(Tensor self) -> Tensor""" + # NOTE: This is used in SymInt/SymBool/SymFloat context, so + # we don't expect overflow to happen here. + return op.Cast(self, to=INT64.dtype) + + @torch_op("aten::type_as", trace_only=True) def aten_type_as(self: TTensor, other: TTensor2) -> TTensor2: """type_as(Tensor self, Tensor other) -> Tensor"""