
Implement aten::div.Tensor_mode | feat(torchlib) #988


Merged: 16 commits, Aug 8, 2023
42 changes: 30 additions & 12 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -2190,18 +2190,41 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2.0) -> TensorType
raise NotImplementedError()


@torch_op(("aten::div", "aten::div.Tensor"))
@torch_op(
(
"aten::div",
"aten::div.Tensor",
"aten::div.Scalar",
# When rounding_mode is None, performs a true division
# https://pytorch.org/docs/stable/generated/torch.div.html
@titaiwangms (Contributor), Aug 8, 2023:

Is the dispatcher expected to filter out any attribute whose value is None?

Author (Collaborator):

I would consider this to be a better match, I think. Any suggestions?

Author (Collaborator):

I do think the dispatcher should strip None keyword args

Contributor:

No, I think that makes sense. It's just that we are altering attributes, and the param_schema matching then diverges from the inputs/attributes actually sent into the OnnxFunction. There are many implicit assumptions around dispatching and OnnxFunction param_schema, and that is not good for debugging.

The dispatcher alters inputs/attributes based on hidden assumptions but never returns the altered inputs/attributes. So from the OnnxFunction's perspective, it runs directly as the dispatched function with attributes it doesn't need (which won't error).
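For illustration only, a minimal sketch of the None-stripping being discussed (`strip_none_attrs` is a hypothetical helper, not the actual dispatcher code):

```python
# Hypothetical helper: drop keyword attributes whose value is None so the call
# falls back to the overload that omits them entirely.
def strip_none_attrs(kwargs: dict) -> dict:
    return {key: value for key, value in kwargs.items() if value is not None}


# {"rounding_mode": None} becomes {}, which then matches the plain aten::div overload.
print(strip_none_attrs({"rounding_mode": None}))     # {}
print(strip_none_attrs({"rounding_mode": "floor"}))  # {'rounding_mode': 'floor'}
```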

"aten::div.Tensor_mode",
"aten::div.Scalar_mode",
"aten::divide",
"aten::true_divide",
)
)
def aten_div(self: TFloat, other: TFloat) -> TFloat:
"""div.Tensor(Tensor self, Tensor other) -> Tensor"""

# Int inputs will be promoted to float by PyTorch
return op.Div(self, other)
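As a quick check of the promotion behavior noted in the comment above (standard PyTorch behavior, shown only for context):

```python
import torch

# True division promotes integer inputs to the default float dtype.
a = torch.tensor([3, 7])
b = torch.tensor([2, 2])
print(torch.div(a, b))        # tensor([1.5000, 3.5000])
print(torch.div(a, b).dtype)  # torch.float32
```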


def aten_divide(self: TensorType, other: TensorType) -> TensorType:
"""divide.Tensor(Tensor self, Tensor other) -> Tensor"""
@torch_op(("aten::div.Tensor_mode", "aten::div.Scalar_mode"), trace_only=True)
def aten_div_mode(self: TFloat, other: TFloat, rounding_mode: str) -> TFloat:
Contributor:

Do you recall what kinds of attributes can have defaults?

Author (Collaborator):

str, int, float, and bool attributes can have defaults, I think. But I suppose any attribute should be able to have a default via the attribute proto. Is this what you are asking?

Contributor:

I wonder if there is a situation where two ONNX variants differ only in one default attribute. In that case, the dispatcher won't be able to dispatch between them.

def aten_op_attr(X, Y, attr="Good"):
    ...

def aten_op(X, Y):
    ...

@justinchuby (Collaborator, Author), Aug 8, 2023:

True. I would just hope for, and make sure, that we don't create variants like these.

I wonder if there is a way to test for it. I think the matching logic you created can come in handy here: we can use it to verify that no two variants registered in torchlib are compatible with each other.
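A rough sketch of what such a check could look like; the registry shape and names below are assumptions for illustration, not torchlib's actual API:

```python
import itertools

# Hypothetical registry: ONNX variant name -> set of required input/attribute names
# (attributes with defaults are excluded, since a call may omit them).
REGISTERED_VARIANTS = {
    "aten_op": frozenset({"X", "Y"}),
    "aten_op_attr": frozenset({"X", "Y"}),  # attr="Good" has a default, so it is not required
}


def assert_variants_distinguishable(variants: dict) -> None:
    """Fail if two variants collapse onto the same required signature."""
    for (name_a, sig_a), (name_b, sig_b) in itertools.combinations(variants.items(), 2):
        assert sig_a != sig_b, f"{name_a} and {name_b} cannot be told apart by the dispatcher"


assert_variants_distinguishable(REGISTERED_VARIANTS)  # raises for the example above
```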

Author (Collaborator):

In the dispatcher, if we do see this case, we can only pick one, I suppose?

Contributor:

Presumably, if I pick either of them there shouldn't be an issue, because they should be equivalent when no attribute is specified.

"""div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor"""

raise NotImplementedError()
# TODO(justinchuby): trace_only=False when we use opset19 which supports string comparison
assert rounding_mode in {"trunc", "floor"}

if rounding_mode == "trunc":
# Rounds the results of the division towards zero.
# Equivalent to C-style integer division
result = aten_trunc(op.Div(self, other))
Author (Collaborator):

Will move to a common function when #834 is done.

Contributor:

I missed this. Could you share more about why we can use nested OnnxFunction now?

@justinchuby (Collaborator, Author), Aug 9, 2023:

This is a trace-only function, so calling other functions is fine. However, we still prefer not to call other aten functions. When we can have nested OnnxFunction calls, I will extract the trunc logic into a common function and call it from both aten_trunc and this function.

Right now I am doing it this way so that aten_trunc doesn't become trace_only.

else: # rounding_mode == "floor"
result = op.Floor(op.Div(self, other))

return result
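For context, the two branches mirror how `torch.div` rounds negative quotients (plain PyTorch, shown only to illustrate why `trunc` and `floor` differ):

```python
import torch

a = torch.tensor([7.0, -7.0])
b = torch.tensor([2.0, 2.0])

print(torch.div(a, b))                         # tensor([ 3.5000, -3.5000])  true division
print(torch.div(a, b, rounding_mode="trunc"))  # tensor([ 3., -3.])  rounds toward zero
print(torch.div(a, b, rounding_mode="floor"))  # tensor([ 3., -4.])  rounds toward -inf
```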


@torch_op("aten::dot")
@@ -2746,10 +2769,11 @@ def aten_floor(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
return op.Floor(self)


def aten_floor_divide(self: TensorType, other: TensorType) -> TensorType:
@torch_op("aten::floor_divide")
def aten_floor_divide(self: TFloat, other: TFloat) -> TFloat:
"""floor_divide(Tensor self, Tensor other) -> Tensor"""

raise NotImplementedError()
return op.Floor(op.Div(self, other))


def aten_fmax(self: TensorType, other: TensorType) -> TensorType:
@@ -6918,12 +6942,6 @@ def aten_triu_indices(row: int, col: int, offset: int = 0) -> TensorType:
raise NotImplementedError()


def aten_true_divide(self: TensorType, other: TensorType) -> TensorType:
"""true_divide.Tensor(Tensor self, Tensor other) -> Tensor"""

raise NotImplementedError()


@torch_op("aten::trunc")
def aten_trunc(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
"""trunc(Tensor self) -> Tensor"""
111 changes: 107 additions & 4 deletions onnxscript/tests/function_libs/torch_lib/error_reproduction.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import difflib
import pathlib
import platform
import sys
@@ -82,6 +83,57 @@
"""


_MISMATCH_MARKDOWN_TEMPLATE = """\
### Summary

The output of ONNX Runtime does not match that of PyTorch when executing test
`{test_name}`, `sample {sample_num}` in ONNX Script `TorchLib`.

To recreate this report, use

```bash
CREATE_REPRODUCTION_REPORT=1 python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k {short_test_name}
```

### Inputs

Shapes: `{input_shapes}`

```python
inputs = {inputs}
kwargs = {kwargs}
```

### Expected output

```python
expected = {expected}
```

Shape: `{expected_shape}`

### Actual output

```python
actual = {actual}
```

Shape: `{actual_shape}`

### Difference

```diff
{diff}
```

### Full error stack

```
{error_stack}
```
"""


def create_reproduction_report(
test_name: str,
onnx_model: onnx.ModelProto,
@@ -123,9 +175,60 @@ def create_reproduction_report(

# Turn test name into a valid file name
markdown_file_name = f'{short_test_name.replace("/", "-").replace(":", "-")}-{str(time.time()).replace(".", "_")}.md'
markdown_file_path = save_error_report(markdown_file_name, markdown)
print(f"Created reproduction report at {markdown_file_path}")


def create_mismatch_report(
test_name: str,
sample_num: int,
inputs,
kwargs,
actual,
expected,
error: Exception,
) -> None:
error_text = str(error)
error_stack = error_text + "\n" + "".join(traceback.format_tb(error.__traceback__))
short_test_name = test_name.split(".")[-1]
diff = difflib.unified_diff(
str(actual).splitlines(),
str(expected).splitlines(),
fromfile="actual",
tofile="expected",
lineterm="",
)
input_shapes = repr(
[
f"Tensor<{inp.shape}, dtype={inp.dtype}>" if isinstance(inp, torch.Tensor) else inp
for inp in inputs
]
)
markdown = _MISMATCH_MARKDOWN_TEMPLATE.format(
test_name=test_name,
short_test_name=short_test_name,
sample_num=sample_num,
input_shapes=input_shapes,
inputs=inputs,
kwargs=kwargs,
expected=expected,
expected_shape=expected.shape if isinstance(expected, torch.Tensor) else None,
actual=actual,
actual_shape=actual.shape if isinstance(actual, torch.Tensor) else None,
diff="\n".join(diff),
error_stack=error_stack,
)

markdown_file_name = f'mismatch-{short_test_name.replace("/", "-").replace(":", "-")}-{str(time.time()).replace(".", "_")}.md'
markdown_file_path = save_error_report(markdown_file_name, markdown)
print(f"Created reproduction report at {markdown_file_path}")


def save_error_report(file_name: str, text: str):
reports_dir = pathlib.Path("error_reports")
reports_dir.mkdir(parents=True, exist_ok=True)
markdown_file_path = reports_dir / markdown_file_name
with open(markdown_file_path, "w", encoding="utf-8") as f:
f.write(markdown)
print(f"Created reproduction report at {markdown_file_path}")
file_path = reports_dir / file_name
with open(file_path, "w", encoding="utf-8") as f:
f.write(text)

return file_path
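A hedged usage sketch (the test name and tensors below are made up) of how a test harness might call `create_mismatch_report` when outputs diverge:

```python
import torch

from onnxscript.tests.function_libs.torch_lib import error_reproduction

actual = torch.tensor([1.0, 2.0, 3.0])
expected = torch.tensor([1.0, 2.0, 4.0])

try:
    torch.testing.assert_close(actual, expected)
except AssertionError as e:
    # Writes a markdown report under error_reports/ using the template above.
    error_reproduction.create_mismatch_report(
        "TestOutputConsistencyEager.test_output_match_div_cpu_float32",  # hypothetical test name
        sample_num=0,
        inputs=(actual,),
        kwargs={},
        actual=actual,
        expected=expected,
        error=e,
    )
```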
11 changes: 10 additions & 1 deletion onnxscript/tests/function_libs/torch_lib/ops_test.py
@@ -22,6 +22,7 @@
"""
from __future__ import annotations

import os
import unittest
from typing import Callable, Optional, Sequence, Tuple

@@ -36,7 +37,11 @@

import onnxscript
import onnxscript.evaluator
from onnxscript.tests.function_libs.torch_lib import ops_test_common, ops_test_data
from onnxscript.tests.function_libs.torch_lib import (
error_reproduction,
ops_test_common,
ops_test_data,
)

# All dtypes will be tested on the generated symbolic functions.
# complex64 will be flattened to float32.
@@ -260,6 +265,10 @@ def run_test_output_match(
check_device=False,
)
except AssertionError as e:
if os.environ.get("CREATE_REPRODUCTION_REPORT") == "1":
error_reproduction.create_mismatch_report(
test_name, i, inputs, cpu_sample.kwargs, actual, expected, e
)
if len(flattened_torch_outputs) > 1:
raise AssertionError(f"Output {j} mismatch") from e
raise
30 changes: 25 additions & 5 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -629,12 +629,26 @@ def _where_input_wrangler(
TorchLibOpInfo("cosh", core_ops.aten_cosh),
TorchLibOpInfo("cross", core_ops.aten_cross),
# TorchLibOpInfo("detach", core_ops.aten_detach), # detach is not in OP-TEST-DB
TorchLibOpInfo(
"div",
core_ops.aten_div,
).skip(
TorchLibOpInfo("div", core_ops.aten_div).skip(
matcher=lambda sample: sample.kwargs.get("rounding_mode") is not None,
reason="rounding_mode is not yet supported",
reason="this variation does not take the rounding_mode argument",
),
TorchLibOpInfo("div_mode", core_ops.aten_div_mode, trace_only=True)
.skip(
variant_name="no_rounding_mode",
reason="this variation requires the rounding_mode argument",
)
.skip(
variant_name="trunc_rounding",
dtypes=(torch.float16,),
# Numbers match sometimes but not other times
reason="fixme: off-by-one. https://github.com/microsoft/onnxscript/issues/990",
)
.xfail(
variant_name="floor_rounding",
dtypes=(torch.float16,),
test_class_name="TestOutputConsistencyEager",
reason="fixme: off-by-one and inverted inf. https://github.com/microsoft/onnxscript/issues/989",
),
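To illustrate how the `matcher` filter above selects samples (the `Sample` class below is a stand-in for the real OpInfo sample type, used only for this sketch):

```python
from dataclasses import dataclass, field


@dataclass
class Sample:  # stand-in for an OpInfo sample input
    kwargs: dict = field(default_factory=dict)


# Same predicate as the "div" skip above: samples that pass rounding_mode are
# skipped for plain aten_div, which does not take that argument.
matcher = lambda sample: sample.kwargs.get("rounding_mode") is not None

print(matcher(Sample()))                                   # False -> exercised by aten_div
print(matcher(Sample(kwargs={"rounding_mode": "trunc"})))  # True  -> skipped for aten_div
```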
TorchLibOpInfo("dot", core_ops.aten_dot),
TorchLibOpInfo(
@@ -658,6 +672,11 @@
TorchLibOpInfo("fill", core_ops.aten_fill),
TorchLibOpInfo("flip", core_ops.aten_flip, input_wrangler=_flip_input_wrangler),
TorchLibOpInfo("floor", core_ops.aten_floor),
TorchLibOpInfo("floor_divide", core_ops.aten_floor_divide).xfail(
dtypes=(torch.float16,),
test_class_name="TestOutputConsistencyEager",
reason="fixme: off-by-one issue due to numerical precision. https://github.com/microsoft/onnxscript/issues/989",
),
TorchLibOpInfo("fmod", core_ops.aten_fmod),
TorchLibOpInfo("full", core_ops.aten_full),
TorchLibOpInfo(
@@ -1838,6 +1857,7 @@ def _where_input_wrangler(
ops_test_common.duplicate_opinfo(OPS_DB, "atleast_3d", ("atleast_3d_Sequence",))
ops_test_common.duplicate_opinfo(OPS_DB, "cat", ("concat", "concatenate"))
ops_test_common.duplicate_opinfo(OPS_DB, "clone", ("lift_fresh_copy",))
ops_test_common.duplicate_opinfo(OPS_DB, "div", ("div_mode",))
ops_test_common.duplicate_opinfo(OPS_DB, "full_like", ("full_like_dtype",))
ops_test_common.duplicate_opinfo(OPS_DB, "index_put", ("index_put_bool",))
ops_test_common.duplicate_opinfo(OPS_DB, "max", ("max_dim",))