Skip to content

[torchlib] Fix aten_div rounding_mode #2147

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 17 additions & 17 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2784,10 +2784,6 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2.0) -> TensorType
(
"aten::div.Tensor",
"aten::div.Scalar",
# When rounding_mode is None, performs a true division
# https://pytorch.org/docs/stable/generated/torch.div.html
"aten::div.Tensor_mode",
"aten::div.Scalar_mode",
"aten::divide.Tensor",
"aten::divide.Scalar",
"aten::true_divide.Tensor",
Expand Down Expand Up @@ -2842,41 +2838,45 @@ def aten_div_complex(self: TFloat, other: TFloat) -> TFloat:


@torch_op(("aten::div.Tensor_mode", "aten::div.Scalar_mode"), trace_only=True)
def aten_div_mode(self: TFloat, other: TFloat, rounding_mode: str) -> TFloat:
def aten_div_mode(self: TFloat, other: TFloat, rounding_mode: Optional[str] = None) -> TFloat:
"""div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor"""

# TODO(justinchuby): trace_only=False when we use opset19 which supports string comparison
assert rounding_mode in {"trunc", "floor"}
assert rounding_mode in {"trunc", "floor", None}

if rounding_mode == "trunc":
# Rounds the results of the division towards zero.
# Equivalent to C-style integer division
result = aten_trunc(op.Div(self, other))
else: # rounding_mode == "floor"
result = op.Floor(op.Div(self, other))
return aten_trunc(op.Div(self, other))
if rounding_mode == "floor":
return op.Floor(op.Div(self, other))

return result
return op.Div(self, other)


@torch_op(("aten::div.Tensor_mode", "aten::div.Scalar_mode"), trace_only=True)
def aten_div_mode_int(self: TInt, other: TInt, rounding_mode: str) -> TInt:
def aten_div_mode_int(
self: TInt, other: TInt, rounding_mode: Optional[str] = None
) -> TensorType:
"""div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor

Variant for integer inputs.
"""
# TODO(justinchuby): trace_only=False when we use opset19 which supports string comparison
assert rounding_mode in {"trunc", "floor"}
assert rounding_mode in {"trunc", "floor", None}

quotient = op.Div(op.Cast(self, to=FLOAT.dtype), op.Cast(other, to=FLOAT.dtype))

if rounding_mode == "trunc":
# Rounds the results of the division towards zero.
# Equivalent to C-style integer division
result = aten_trunc(quotient)
else: # rounding_mode == "floor"
return op.CastLike(result, self)
if rounding_mode == "floor":
result = op.Floor(quotient)
return op.CastLike(result, self)

return op.CastLike(result, self)
assert rounding_mode is None
# When rounding_mode is None, the return type is float32
return quotient


@torch_op("aten::dot", trace_only=True)
Expand Down Expand Up @@ -8511,7 +8511,7 @@ def aten_triu_indices(row: int, col: int, offset: int = 0) -> TensorType:
raise NotImplementedError()


@torch_op("aten::trunc")
@torch_op("aten::trunc", trace_only=True)
def aten_trunc(self: TFloat) -> TFloat:
"""trunc(Tensor self) -> Tensor"""
# Reference https://github.com/onnx/onnx/issues/4588#issuecomment-2658170591
Expand Down
4 changes: 4 additions & 0 deletions tests/function_libs/torch_lib/ops_test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

import onnxscript
import onnxscript.evaluator
import onnxscript.ir.passes.common.unused_removal
from onnxscript import ir
from onnxscript.function_libs.torch_lib.ops import common as common_ops
from tests.function_libs.torch_lib import error_reproduction
Expand Down Expand Up @@ -419,6 +420,9 @@ def add_torchlib_common_imports(model: ir.Model) -> None:
is_scalar_func = ir.serde.deserialize_function(common_ops.IsScalar.to_function_proto())
model.functions[rank_func.identifier()] = rank_func
model.functions[is_scalar_func.identifier()] = is_scalar_func
removal_pass = onnxscript.ir.passes.common.unused_removal.RemoveUnusedFunctionsPass()
assert removal_pass.in_place
removal_pass(model)


def dtype_op_schema_compatible(dtype: torch.dtype, schema: onnx.defs.OpSchema) -> bool:
Expand Down
5 changes: 1 addition & 4 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,10 +760,7 @@ def _where_input_wrangler(
# Numbers match sometimes but not other times
reason="fixme: off-by-one. https://github.com/microsoft/onnxscript/issues/990",
),
TorchLibOpInfo("div_mode_int", core_ops.aten_div_mode_int).skip(
variant_name="no_rounding_mode",
reason="this variation requires the rounding_mode argument",
),
TorchLibOpInfo("div_mode_int", core_ops.aten_div_mode_int),
TorchLibOpInfo("dot", core_ops.aten_dot),
TorchLibOpInfo(
"empty",
Expand Down
Loading