Commit ba255f7

Implement aten::div.Tensor_mode | feat(torchlib) (#988)
`aten::div.Tensor_mode` is implemented with two ONNX functions. When `rounding_mode` is `None`, we use `aten_div`; when it is not `None`, we use `aten_div_mode`. This way `aten_div_mode` never has to handle the `rounding_mode=None` case. For `float16` inputs we need to cast to `float32` to preserve precision; otherwise `-inf` sometimes becomes `inf` in the output.

- Additionally registers the aliases `aten::divide` and `aten::true_divide` to `aten_div`.
- Supports saving mismatches to error reports.
- xfails and documents off-by-one errors with `float16` (#990, #989).

Fixes #980
Parent commit: c6e216e

4 files changed: +172, -22 lines
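For reference, these are the `torch.div` semantics the two functions split between them: `rounding_mode=None` is true division, `"trunc"` rounds the quotient toward zero, and `"floor"` rounds toward negative infinity. A minimal sketch in plain PyTorch (not part of the commit):

```python
import torch

a = torch.tensor([7.0, -7.0])
b = torch.tensor([2.0, 2.0])
print(torch.div(a, b))                         # tensor([ 3.5000, -3.5000])
print(torch.div(a, b, rounding_mode="trunc"))  # tensor([ 3., -3.])  toward zero
print(torch.div(a, b, rounding_mode="floor"))  # tensor([ 3., -4.])  toward -inf
```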

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 30 additions, 12 deletions

```diff
@@ -2190,18 +2190,41 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2.0) -> TensorType
     raise NotImplementedError()
 
 
-@torch_op(("aten::div", "aten::div.Tensor"))
+@torch_op(
+    (
+        "aten::div",
+        "aten::div.Tensor",
+        "aten::div.Scalar",
+        # When rounding_mode is None, performs a true division
+        # https://pytorch.org/docs/stable/generated/torch.div.html
+        "aten::div.Tensor_mode",
+        "aten::div.Scalar_mode",
+        "aten::divide",
+        "aten::true_divide",
+    )
+)
 def aten_div(self: TFloat, other: TFloat) -> TFloat:
     """div.Tensor(Tensor self, Tensor other) -> Tensor"""
 
     # Int inputs will be promoted to float by PyTorch
     return op.Div(self, other)
 
 
-def aten_divide(self: TensorType, other: TensorType) -> TensorType:
-    """divide.Tensor(Tensor self, Tensor other) -> Tensor"""
+@torch_op(("aten::div.Tensor_mode", "aten::div.Scalar_mode"), trace_only=True)
+def aten_div_mode(self: TFloat, other: TFloat, rounding_mode: str) -> TFloat:
+    """div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor"""
 
-    raise NotImplementedError()
+    # TODO(justinchuby): trace_only=False when we use opset19 which supports string comparison
+    assert rounding_mode in {"trunc", "floor"}
+
+    if rounding_mode == "trunc":
+        # Rounds the results of the division towards zero.
+        # Equivalent to C-style integer division
+        result = aten_trunc(op.Div(self, other))
+    else:  # rounding_mode == "floor"
+        result = op.Floor(op.Div(self, other))
+
+    return result
 
 
 @torch_op("aten::dot")
@@ -2746,10 +2769,11 @@ def aten_floor(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
     return op.Floor(self)
 
 
-def aten_floor_divide(self: TensorType, other: TensorType) -> TensorType:
+@torch_op("aten::floor_divide")
+def aten_floor_divide(self: TFloat, other: TFloat) -> TFloat:
     """floor_divide(Tensor self, Tensor other) -> Tensor"""
 
-    raise NotImplementedError()
+    return op.Floor(op.Div(self, other))
 
 
 def aten_fmax(self: TensorType, other: TensorType) -> TensorType:
@@ -6918,12 +6942,6 @@ def aten_triu_indices(row: int, col: int, offset: int = 0) -> TensorType:
     raise NotImplementedError()
 
 
-def aten_true_divide(self: TensorType, other: TensorType) -> TensorType:
-    """true_divide.Tensor(Tensor self, Tensor other) -> Tensor"""
-
-    raise NotImplementedError()
-
-
 @torch_op("aten::trunc")
 def aten_trunc(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
     """trunc(Tensor self) -> Tensor"""
```

onnxscript/tests/function_libs/torch_lib/error_reproduction.py

Lines changed: 107 additions, 4 deletions

````diff
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import difflib
 import pathlib
 import platform
 import sys
@@ -82,6 +83,57 @@
 """
 
 
+_MISMATCH_MARKDOWN_TEMPLATE = """\
+### Summary
+
+The output of ONNX Runtime does not match that of PyTorch when executing test
+`{test_name}`, `sample {sample_num}` in ONNX Script `TorchLib`.
+
+To recreate this report, use
+
+```bash
+CREATE_REPRODUCTION_REPORT=1 python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k {short_test_name}
+```
+
+### Inputs
+
+Shapes: `{input_shapes}`
+
+```python
+inputs = {inputs}
+kwargs = {kwargs}
+```
+
+### Expected output
+
+```python
+expected = {expected}
+```
+
+Shape: `{expected_shape}`
+
+### Actual output
+
+```python
+actual = {actual}
+```
+
+Shape: `{actual_shape}`
+
+### Difference
+
+```diff
+{diff}
+```
+
+### Full error stack
+
+```
+{error_stack}
+```
+"""
+
+
 def create_reproduction_report(
     test_name: str,
     onnx_model: onnx.ModelProto,
@@ -123,9 +175,60 @@ def create_reproduction_report(
 
     # Turn test name into a valid file name
     markdown_file_name = f'{short_test_name.replace("/", "-").replace(":", "-")}-{str(time.time()).replace(".", "_")}.md'
+    markdown_file_path = save_error_report(markdown_file_name, markdown)
+    print(f"Created reproduction report at {markdown_file_path}")
+
+
+def create_mismatch_report(
+    test_name: str,
+    sample_num: int,
+    inputs,
+    kwargs,
+    actual,
+    expected,
+    error: Exception,
+) -> None:
+    error_text = str(error)
+    error_stack = error_text + "\n" + "".join(traceback.format_tb(error.__traceback__))
+    short_test_name = test_name.split(".")[-1]
+    diff = difflib.unified_diff(
+        str(actual).splitlines(),
+        str(expected).splitlines(),
+        fromfile="actual",
+        tofile="expected",
+        lineterm="",
+    )
+    input_shapes = repr(
+        [
+            f"Tensor<{inp.shape}, dtype={inp.dtype}>" if isinstance(inp, torch.Tensor) else inp
+            for inp in inputs
+        ]
+    )
+    markdown = _MISMATCH_MARKDOWN_TEMPLATE.format(
+        test_name=test_name,
+        short_test_name=short_test_name,
+        sample_num=sample_num,
+        input_shapes=input_shapes,
+        inputs=inputs,
+        kwargs=kwargs,
+        expected=expected,
+        expected_shape=expected.shape if isinstance(expected, torch.Tensor) else None,
+        actual=actual,
+        actual_shape=actual.shape if isinstance(actual, torch.Tensor) else None,
+        diff="\n".join(diff),
+        error_stack=error_stack,
+    )
+
+    markdown_file_name = f'mismatch-{short_test_name.replace("/", "-").replace(":", "-")}-{str(time.time()).replace(".", "_")}.md'
+    markdown_file_path = save_error_report(markdown_file_name, markdown)
+    print(f"Created reproduction report at {markdown_file_path}")
+
+
+def save_error_report(file_name: str, text: str):
     reports_dir = pathlib.Path("error_reports")
     reports_dir.mkdir(parents=True, exist_ok=True)
-    markdown_file_path = reports_dir / markdown_file_name
-    with open(markdown_file_path, "w", encoding="utf-8") as f:
-        f.write(markdown)
-    print(f"Created reproduction report at {markdown_file_path}")
+    file_path = reports_dir / file_name
+    with open(file_path, "w", encoding="utf-8") as f:
+        f.write(text)
+
+    return file_path
````
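The unified diff embedded in each report comes from the standard-library `difflib.unified_diff` call shown above. A standalone demonstration with illustrative values (not taken from a real report):

```python
import difflib

actual = "tensor([1., 2., 3.])"
expected = "tensor([1., 2., 4.])"
diff = difflib.unified_diff(
    actual.splitlines(),
    expected.splitlines(),
    fromfile="actual",
    tofile="expected",
    lineterm="",
)
print("\n".join(diff))
# --- actual
# +++ expected
# @@ -1 +1 @@
# -tensor([1., 2., 3.])
# +tensor([1., 2., 4.])
```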

onnxscript/tests/function_libs/torch_lib/ops_test.py

Lines changed: 10 additions, 1 deletion

```diff
@@ -22,6 +22,7 @@
 """
 from __future__ import annotations
 
+import os
 import unittest
 from typing import Callable, Optional, Sequence, Tuple
 
@@ -36,7 +37,11 @@
 
 import onnxscript
 import onnxscript.evaluator
-from onnxscript.tests.function_libs.torch_lib import ops_test_common, ops_test_data
+from onnxscript.tests.function_libs.torch_lib import (
+    error_reproduction,
+    ops_test_common,
+    ops_test_data,
+)
 
 # All dtypes will be tested on the generated symbolic functions.
 # complex64 will be flattened to float32.
@@ -260,6 +265,10 @@ def run_test_output_match(
                         check_device=False,
                     )
                 except AssertionError as e:
+                    if os.environ.get("CREATE_REPRODUCTION_REPORT") == "1":
+                        error_reproduction.create_mismatch_report(
+                            test_name, i, inputs, cpu_sample.kwargs, actual, expected, e
+                        )
                     if len(flattened_torch_outputs) > 1:
                         raise AssertionError(f"Output {j} mismatch") from e
                     raise
```
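The hook is opt-in: report writing happens only when the `CREATE_REPRODUCTION_REPORT` environment variable is set to `1`, and the `AssertionError` still propagates, so test outcomes are unchanged. A minimal sketch of that pattern, with a hypothetical `check_output` helper standing in for the test body:

```python
import os

def check_output(actual: object, expected: object) -> None:
    try:
        assert actual == expected
    except AssertionError as e:
        if os.environ.get("CREATE_REPRODUCTION_REPORT") == "1":
            # Stand-in for error_reproduction.create_mismatch_report(...)
            print(f"writing mismatch report for: {e}")
        raise  # the test still fails either way
```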

onnxscript/tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 25 additions, 5 deletions

```diff
@@ -629,12 +629,26 @@ def _where_input_wrangler(
     TorchLibOpInfo("cosh", core_ops.aten_cosh),
     TorchLibOpInfo("cross", core_ops.aten_cross),
     # TorchLibOpInfo("detach", core_ops.aten_detach),  # detach is not in OP-TEST-DB
-    TorchLibOpInfo(
-        "div",
-        core_ops.aten_div,
-    ).skip(
+    TorchLibOpInfo("div", core_ops.aten_div).skip(
         matcher=lambda sample: sample.kwargs.get("rounding_mode") is not None,
-        reason="rounding_mode is not yet supported",
+        reason="this variation does not take the rounding_mode argument",
+    ),
+    TorchLibOpInfo("div_mode", core_ops.aten_div_mode, trace_only=True)
+    .skip(
+        variant_name="no_rounding_mode",
+        reason="this variation requires the rounding_mode argument",
+    )
+    .skip(
+        variant_name="trunc_rounding",
+        dtypes=(torch.float16,),
+        # Numbers match sometimes but not other times
+        reason="fixme: off-by-one. https://github.com/microsoft/onnxscript/issues/990",
+    )
+    .xfail(
+        variant_name="floor_rounding",
+        dtypes=(torch.float16,),
+        test_class_name="TestOutputConsistencyEager",
+        reason="fixme: off-by-one and inverted inf. https://github.com/microsoft/onnxscript/issues/989",
     ),
     TorchLibOpInfo("dot", core_ops.aten_dot),
     TorchLibOpInfo(
@@ -658,6 +672,11 @@ def _where_input_wrangler(
     TorchLibOpInfo("fill", core_ops.aten_fill),
     TorchLibOpInfo("flip", core_ops.aten_flip, input_wrangler=_flip_input_wrangler),
     TorchLibOpInfo("floor", core_ops.aten_floor),
+    TorchLibOpInfo("floor_divide", core_ops.aten_floor_divide).xfail(
+        dtypes=(torch.float16,),
+        test_class_name="TestOutputConsistencyEager",
+        reason="fixme: off-by-one issue due to numerical precision. https://github.com/microsoft/onnxscript/issues/989",
+    ),
     TorchLibOpInfo("fmod", core_ops.aten_fmod),
     TorchLibOpInfo("full", core_ops.aten_full),
     TorchLibOpInfo(
@@ -1838,6 +1857,7 @@ def _where_input_wrangler(
 ops_test_common.duplicate_opinfo(OPS_DB, "atleast_3d", ("atleast_3d_Sequence",))
 ops_test_common.duplicate_opinfo(OPS_DB, "cat", ("concat", "concatenate"))
 ops_test_common.duplicate_opinfo(OPS_DB, "clone", ("lift_fresh_copy",))
+ops_test_common.duplicate_opinfo(OPS_DB, "div", ("div_mode",))
 ops_test_common.duplicate_opinfo(OPS_DB, "full_like", ("full_like_dtype",))
 ops_test_common.duplicate_opinfo(OPS_DB, "index_put", ("index_put_bool",))
 ops_test_common.duplicate_opinfo(OPS_DB, "max", ("max_dim",))
```
