Skip to content

Commit df7d588

Browse files
justinchubybmehta001
authored andcommitted
[torchlib] Fix aten_div rounding_mode (microsoft#2147)
Fix microsoft#2144
1 parent b8a7671 commit df7d588

File tree

3 files changed

+32
-21
lines changed

3 files changed

+32
-21
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2742,10 +2742,6 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2.0) -> TensorType
27422742
(
27432743
"aten::div.Tensor",
27442744
"aten::div.Scalar",
2745-
# When rounding_mode is None, performs a true division
2746-
# https://pytorch.org/docs/stable/generated/torch.div.html
2747-
"aten::div.Tensor_mode",
2748-
"aten::div.Scalar_mode",
27492745
"aten::divide.Tensor",
27502746
"aten::divide.Scalar",
27512747
"aten::true_divide.Tensor",
@@ -2799,41 +2795,45 @@ def aten_div_complex(self: TFloat, other: TFloat) -> TFloat:
27992795

28002796

28012797
@torch_op(("aten::div.Tensor_mode", "aten::div.Scalar_mode"), trace_only=True)
2802-
def aten_div_mode(self: TFloat, other: TFloat, rounding_mode: str) -> TFloat:
2798+
def aten_div_mode(self: TFloat, other: TFloat, rounding_mode: Optional[str] = None) -> TFloat:
28032799
"""div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor"""
28042800

2805-
# TODO(justinchuby): trace_only=False when we use opset19 which supports string comparison
2806-
assert rounding_mode in {"trunc", "floor"}
2801+
assert rounding_mode in {"trunc", "floor", None}
28072802

28082803
if rounding_mode == "trunc":
28092804
# Rounds the results of the division towards zero.
28102805
# Equivalent to C-style integer division
2811-
result = aten_trunc(op.Div(self, other))
2812-
else: # rounding_mode == "floor"
2813-
result = op.Floor(op.Div(self, other))
2806+
return aten_trunc(op.Div(self, other))
2807+
if rounding_mode == "floor":
2808+
return op.Floor(op.Div(self, other))
28142809

2815-
return result
2810+
return op.Div(self, other)
28162811

28172812

28182813
@torch_op(("aten::div.Tensor_mode", "aten::div.Scalar_mode"), trace_only=True)
2819-
def aten_div_mode_int(self: TInt, other: TInt, rounding_mode: str) -> TInt:
2814+
def aten_div_mode_int(
2815+
self: TInt, other: TInt, rounding_mode: Optional[str] = None
2816+
) -> TensorType:
28202817
"""div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor
28212818
28222819
Variant for integer inputs.
28232820
"""
2824-
# TODO(justinchuby): trace_only=False when we use opset19 which supports string comparison
2825-
assert rounding_mode in {"trunc", "floor"}
2821+
assert rounding_mode in {"trunc", "floor", None}
28262822

28272823
quotient = op.Div(op.Cast(self, to=FLOAT.dtype), op.Cast(other, to=FLOAT.dtype))
28282824

28292825
if rounding_mode == "trunc":
28302826
# Rounds the results of the division towards zero.
28312827
# Equivalent to C-style integer division
28322828
result = aten_trunc(quotient)
2833-
else: # rounding_mode == "floor"
2829+
return op.CastLike(result, self)
2830+
if rounding_mode == "floor":
28342831
result = op.Floor(quotient)
2832+
return op.CastLike(result, self)
28352833

2836-
return op.CastLike(result, self)
2834+
assert rounding_mode is None
2835+
# When rounding_mode is None, the return type is float32
2836+
return quotient
28372837

28382838

28392839
@torch_op("aten::dot")
@@ -8465,7 +8465,7 @@ def aten_triu_indices(row: int, col: int, offset: int = 0) -> TensorType:
84658465
raise NotImplementedError()
84668466

84678467

8468-
@torch_op("aten::trunc")
8468+
@torch_op("aten::trunc", trace_only=True)
84698469
def aten_trunc(self: TFloat) -> TFloat:
84708470
"""trunc(Tensor self) -> Tensor"""
84718471

tests/function_libs/torch_lib/ops_test_common.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
import onnxscript
3636
import onnxscript.evaluator
37+
import onnxscript.ir.passes.common.unused_removal
3738
from onnxscript import ir
3839
from onnxscript.function_libs.torch_lib import graph_building
3940
from tests.function_libs.torch_lib import error_reproduction
@@ -389,6 +390,19 @@ def _format_model_and_input_information(onnx_model, inputs):
389390
}
390391

391392

393+
def add_torchlib_common_imports(model: ir.Model) -> None:
394+
"""Hack to add torchlib common imports to the model."""
395+
396+
model.opset_imports["pkg.onnxscript.torch_lib.common"] = 1
397+
rank_func = ir.serde.deserialize_function(common_ops.Rank.to_function_proto())
398+
is_scalar_func = ir.serde.deserialize_function(common_ops.IsScalar.to_function_proto())
399+
model.functions[rank_func.identifier()] = rank_func
400+
model.functions[is_scalar_func.identifier()] = is_scalar_func
401+
removal_pass = onnxscript.ir.passes.common.unused_removal.RemoveUnusedFunctionsPass()
402+
assert removal_pass.in_place
403+
removal_pass(model)
404+
405+
392406
def dtype_op_schema_compatible(dtype: torch.dtype, schema: onnx.defs.OpSchema) -> bool:
393407
"""Checks if the dtype is compatible with the schema.
394408

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -778,10 +778,7 @@ def _where_input_wrangler(
778778
test_class_name="TestOutputConsistencyEager",
779779
reason="fixme: off-by-one and inverted inf. https://github.com/microsoft/onnxscript/issues/989",
780780
),
781-
TorchLibOpInfo("div_mode_int", core_ops.aten_div_mode_int).skip(
782-
variant_name="no_rounding_mode",
783-
reason="this variation requires the rounding_mode argument",
784-
),
781+
TorchLibOpInfo("div_mode_int", core_ops.aten_div_mode_int),
785782
TorchLibOpInfo("dot", core_ops.aten_dot),
786783
TorchLibOpInfo(
787784
"empty",

0 commit comments

Comments
 (0)