From eb03a135da40f86de4be1f3d776121c8c1180cef Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 7 Aug 2023 22:25:19 +0000 Subject: [PATCH 01/15] WIP --- onnxscript/function_libs/torch_lib/ops/core.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 811b27463d..6af35c7318 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -2190,13 +2190,19 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2.0) -> TensorType raise NotImplementedError() -@torch_op(("aten::div", "aten::div.Tensor")) +@torch_op(("aten::div", "aten::div.Tensor", "aten::div.Scalar")) def aten_div(self: TFloat, other: TFloat) -> TFloat: """div.Tensor(Tensor self, Tensor other) -> Tensor""" # Int inputs will be promoted to float by PyTorch return op.Div(self, other) +@torch_op(("aten::div.Tensor_mode", "aten::div.Scalar_mode"), rounding_mode: str) +def aten_div(self: TFloat, other: TFloat) -> TFloat: + """div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor""" + + pass + def aten_divide(self: TensorType, other: TensorType) -> TensorType: """divide.Tensor(Tensor self, Tensor other) -> Tensor""" From 4d98de57bb7acd6d34c1c9935b2941d6721d38f1 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 7 Aug 2023 23:18:39 +0000 Subject: [PATCH 02/15] Snapshot --- .../function_libs/torch_lib/ops/core.py | 39 ++++++++++++------- .../function_libs/torch_lib/ops_test_data.py | 15 ++++--- 2 files changed, 35 insertions(+), 19 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 6af35c7318..48c9542c5c 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -2190,24 +2190,41 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2.0) -> TensorType raise NotImplementedError() -@torch_op(("aten::div", "aten::div.Tensor", "aten::div.Scalar")) +@torch_op( + ( + "aten::div", + "aten::div.Tensor", + "aten::div.Scalar", + # When rounding_mode is None, performs a true division + # https://pytorch.org/docs/stable/generated/torch.div.html + "aten::div.Tensor_mode", + "aten::div.Scalar_mode", + "aten::divide", + "aten::true_divide", + ) +) def aten_div(self: TFloat, other: TFloat) -> TFloat: """div.Tensor(Tensor self, Tensor other) -> Tensor""" # Int inputs will be promoted to float by PyTorch return op.Div(self, other) -@torch_op(("aten::div.Tensor_mode", "aten::div.Scalar_mode"), rounding_mode: str) -def aten_div(self: TFloat, other: TFloat) -> TFloat: - """div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor""" - pass +@torch_op(("aten::div.Tensor_mode", "aten::div.Scalar_mode"), trace_only=True) +def aten_div_mode(self: TFloat, other: TFloat, rounding_mode: str) -> TFloat: + """div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor""" + # TODO(justinchuby): trace_only=False when we use opset19 which supports string comparison + assert rounding_mode in {"trunc", "floor"} -def aten_divide(self: TensorType, other: TensorType) -> TensorType: - """divide.Tensor(Tensor self, Tensor other) -> Tensor""" + if rounding_mode == "trunc": + # Rounds the results of the division towards zero. + # Equivalent to C-style integer division + result = aten_trunc(op.Div(self, other)) + else: # rounding_mode == "floor" + result = op.Floor(op.Div(self, other)) - raise NotImplementedError() + return result @torch_op("aten::dot") @@ -6924,12 +6941,6 @@ def aten_triu_indices(row: int, col: int, offset: int = 0) -> TensorType: raise NotImplementedError() -def aten_true_divide(self: TensorType, other: TensorType) -> TensorType: - """true_divide.Tensor(Tensor self, Tensor other) -> Tensor""" - - raise NotImplementedError() - - @torch_op("aten::trunc") def aten_trunc(self: TFloatOrBFloat16) -> TFloatOrBFloat16: """trunc(Tensor self) -> Tensor""" diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 99d2ce686e..709385b16a 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -629,12 +629,16 @@ def _where_input_wrangler( TorchLibOpInfo("cosh", core_ops.aten_cosh), TorchLibOpInfo("cross", core_ops.aten_cross), # TorchLibOpInfo("detach", core_ops.aten_detach), # detach is not in OP-TEST-DB - TorchLibOpInfo( - "div", - core_ops.aten_div, - ).skip( + TorchLibOpInfo("div", core_ops.aten_div).skip( matcher=lambda sample: sample.kwargs.get("rounding_mode") is not None, - reason="rounding_mode is not yet supported", + reason="this variation does not take the rounding_mode argument", + ), + TorchLibOpInfo("div_mode", core_ops.aten_div_mode, trace_only=True).skip( + matcher=lambda sample: sample.kwargs.get("rounding_mode") is None, + reason="this variation requires the rounding_mode argument", + # ).xfail( + # dtypes=(torch.float16,), + # reason="fixme: division" ), TorchLibOpInfo("dot", core_ops.aten_dot), TorchLibOpInfo( @@ -1838,6 +1842,7 @@ def _where_input_wrangler( ops_test_common.duplicate_opinfo(OPS_DB, "atleast_3d", ("atleast_3d_Sequence",)) ops_test_common.duplicate_opinfo(OPS_DB, "cat", ("concat", "concatenate")) ops_test_common.duplicate_opinfo(OPS_DB, "clone", ("lift_fresh_copy",)) +ops_test_common.duplicate_opinfo(OPS_DB, "div", ("div_mode",)) ops_test_common.duplicate_opinfo(OPS_DB, "full_like", ("full_like_dtype",)) ops_test_common.duplicate_opinfo(OPS_DB, "index_put", ("index_put_bool",)) ops_test_common.duplicate_opinfo(OPS_DB, "max", ("max_dim",)) From 2ca50d5a3c6a7210b86c90c4a204427ef47ddf68 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 8 Aug 2023 00:13:53 +0000 Subject: [PATCH 03/15] Cast to float --- onnxscript/function_libs/torch_lib/ops/core.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 48c9542c5c..e762352db2 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -2217,13 +2217,17 @@ def aten_div_mode(self: TFloat, other: TFloat, rounding_mode: str) -> TFloat: # TODO(justinchuby): trace_only=False when we use opset19 which supports string comparison assert rounding_mode in {"trunc", "floor"} + # Cast inputs to float to preserve numerical precision + self_float = op.Cast(self, to=FLOAT.dtype) + other_float = op.Cast(other, to=FLOAT.dtype) if rounding_mode == "trunc": # Rounds the results of the division towards zero. # Equivalent to C-style integer division - result = aten_trunc(op.Div(self, other)) + result = aten_trunc(op.Div(self_float, other_float)) else: # rounding_mode == "floor" - result = op.Floor(op.Div(self, other)) + result = op.Floor(op.Div(self_float, other_float)) + result = op.CastLike(result, self) return result From 571c076c19df01b82f0619aaa9e6ff3622e4b9b5 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 8 Aug 2023 16:40:07 +0000 Subject: [PATCH 04/15] floor_divide --- onnxscript/function_libs/torch_lib/ops/core.py | 5 +++-- onnxscript/tests/function_libs/torch_lib/ops_test_data.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index e762352db2..d48638aa9f 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -2773,10 +2773,11 @@ def aten_floor(self: TFloatOrBFloat16) -> TFloatOrBFloat16: return op.Floor(self) -def aten_floor_divide(self: TensorType, other: TensorType) -> TensorType: +@torch_op("aten::floor_divide") +def aten_floor_divide(self: TFloat, other: TFloat) -> TFloat: """floor_divide(Tensor self, Tensor other) -> Tensor""" - raise NotImplementedError() + return op.Floor(op.Div(self, other)) def aten_fmax(self: TensorType, other: TensorType) -> TensorType: diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 709385b16a..fe4d4c4a44 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -662,6 +662,7 @@ def _where_input_wrangler( TorchLibOpInfo("fill", core_ops.aten_fill), TorchLibOpInfo("flip", core_ops.aten_flip, input_wrangler=_flip_input_wrangler), TorchLibOpInfo("floor", core_ops.aten_floor), + TorchLibOpInfo("floor_divide", core_ops.aten_floor_divide), TorchLibOpInfo("fmod", core_ops.aten_fmod), TorchLibOpInfo("full", core_ops.aten_full), TorchLibOpInfo( From 325b4b1c23e6821766403787a6d567e23cc2ae65 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 8 Aug 2023 17:19:44 +0000 Subject: [PATCH 05/15] Create mismatch report --- .../torch_lib/error_reproduction.py | 81 ++++++++++++++++++- .../tests/function_libs/torch_lib/ops_test.py | 11 ++- .../function_libs/torch_lib/ops_test_data.py | 3 - 3 files changed, 87 insertions(+), 8 deletions(-) diff --git a/onnxscript/tests/function_libs/torch_lib/error_reproduction.py b/onnxscript/tests/function_libs/torch_lib/error_reproduction.py index 11ae464927..900f4f4819 100644 --- a/onnxscript/tests/function_libs/torch_lib/error_reproduction.py +++ b/onnxscript/tests/function_libs/torch_lib/error_reproduction.py @@ -82,6 +82,45 @@ """ +_MISMATCH_MARKDOWN_TEMPLATE = """\ +### Summary + +The output of ONNX Runtime does not match that of PyTorch when executing test +`{test_name}`, `sample {sample_num}` in ONNX Script `TorchLib`. + +To recreate this report, use + +```bash +CREATE_REPRODUCTION_REPORT=1 python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k {short_test_name} +``` + +### Inputs + +```python +inputs = {inputs} +kwargs = {kwargs} +``` + +### Expected output + +```python +expected = {expected} +``` + +### Actual output + +```python +actual = {actual} +``` + +### Full error stack + +``` +{error_stack} +``` +""" + + def create_reproduction_report( test_name: str, onnx_model: onnx.ModelProto, @@ -123,9 +162,43 @@ def create_reproduction_report( # Turn test name into a valid file name markdown_file_name = f'{short_test_name.replace("/", "-").replace(":", "-")}-{str(time.time()).replace(".", "_")}.md' + markdown_file_path = save_error_report(markdown_file_name, markdown) + print(f"Created reproduction report at {markdown_file_path}") + + +def create_mismatch_report( + test_name: str, + sample_num: int, + inputs, + kwargs, + actual, + expected, + error: Exception, +) -> None: + error_text = str(error) + error_stack = error_text + "\n" + "".join(traceback.format_tb(error.__traceback__)) + short_test_name = test_name.split(".")[-1] + markdown = _MISMATCH_MARKDOWN_TEMPLATE.format( + test_name=test_name, + short_test_name=short_test_name, + sample_num=sample_num, + inputs=inputs, + kwargs=kwargs, + expected=expected, + actual=actual, + error_stack=error_stack, + ) + + markdown_file_name = f'mismatch-{short_test_name.replace("/", "-").replace(":", "-")}-{str(time.time()).replace(".", "_")}.md' + markdown_file_path = save_error_report(markdown_file_name, markdown) + print(f"Created reproduction report at {markdown_file_path}") + + +def save_error_report(file_name: str, text: str): reports_dir = pathlib.Path("error_reports") reports_dir.mkdir(parents=True, exist_ok=True) - markdown_file_path = reports_dir / markdown_file_name - with open(markdown_file_path, "w", encoding="utf-8") as f: - f.write(markdown) - print(f"Created reproduction report at {markdown_file_path}") + file_path = reports_dir / file_name + with open(file_path, "w", encoding="utf-8") as f: + f.write(text) + + return file_path diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test.py b/onnxscript/tests/function_libs/torch_lib/ops_test.py index c840707880..fbb3bfb73f 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test.py @@ -22,6 +22,7 @@ """ from __future__ import annotations +import os import unittest from typing import Callable, Optional, Sequence, Tuple @@ -36,7 +37,11 @@ import onnxscript import onnxscript.evaluator -from onnxscript.tests.function_libs.torch_lib import ops_test_common, ops_test_data +from onnxscript.tests.function_libs.torch_lib import ( + error_reproduction, + ops_test_common, + ops_test_data, +) # All dtypes will be tested on the generated symbolic functions. # complex64 will be flattened to float32. @@ -260,6 +265,10 @@ def run_test_output_match( check_device=False, ) except AssertionError as e: + if os.environ.get("CREATE_REPRODUCTION_REPORT") == "1": + error_reproduction.create_mismatch_report( + test_name, j, inputs, cpu_sample.kwargs, actual, expected, e + ) if len(flattened_torch_outputs) > 1: raise AssertionError(f"Output {j} mismatch") from e raise diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index fe4d4c4a44..50ead1ce37 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -636,9 +636,6 @@ def _where_input_wrangler( TorchLibOpInfo("div_mode", core_ops.aten_div_mode, trace_only=True).skip( matcher=lambda sample: sample.kwargs.get("rounding_mode") is None, reason="this variation requires the rounding_mode argument", - # ).xfail( - # dtypes=(torch.float16,), - # reason="fixme: division" ), TorchLibOpInfo("dot", core_ops.aten_dot), TorchLibOpInfo( From 53dcd0d554189d65166b9f163887f7bb72ea4d17 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 8 Aug 2023 17:47:02 +0000 Subject: [PATCH 06/15] Diff --- onnxscript/function_libs/torch_lib/ops/core.py | 8 ++------ .../function_libs/torch_lib/error_reproduction.py | 15 +++++++++++++++ 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index d48638aa9f..3a1eebe876 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -2217,17 +2217,13 @@ def aten_div_mode(self: TFloat, other: TFloat, rounding_mode: str) -> TFloat: # TODO(justinchuby): trace_only=False when we use opset19 which supports string comparison assert rounding_mode in {"trunc", "floor"} - # Cast inputs to float to preserve numerical precision - self_float = op.Cast(self, to=FLOAT.dtype) - other_float = op.Cast(other, to=FLOAT.dtype) if rounding_mode == "trunc": # Rounds the results of the division towards zero. # Equivalent to C-style integer division - result = aten_trunc(op.Div(self_float, other_float)) + result = aten_trunc(op.Div(self, other)) else: # rounding_mode == "floor" - result = op.Floor(op.Div(self_float, other_float)) + result = op.Floor(op.Div(self, other)) - result = op.CastLike(result, self) return result diff --git a/onnxscript/tests/function_libs/torch_lib/error_reproduction.py b/onnxscript/tests/function_libs/torch_lib/error_reproduction.py index 900f4f4819..926128c8ab 100644 --- a/onnxscript/tests/function_libs/torch_lib/error_reproduction.py +++ b/onnxscript/tests/function_libs/torch_lib/error_reproduction.py @@ -6,6 +6,7 @@ import time import traceback from typing import Any, Mapping +import difflib import numpy as np import onnx @@ -113,6 +114,12 @@ actual = {actual} ``` +### Difference + +```diff +{diff} +``` + ### Full error stack ``` @@ -178,6 +185,13 @@ def create_mismatch_report( error_text = str(error) error_stack = error_text + "\n" + "".join(traceback.format_tb(error.__traceback__)) short_test_name = test_name.split(".")[-1] + diff = difflib.unified_diff( + str(actual).splitlines(), + str(expected).splitlines(), + tofile="actual", + fromfile="expected", + lineterm="", + ) markdown = _MISMATCH_MARKDOWN_TEMPLATE.format( test_name=test_name, short_test_name=short_test_name, @@ -186,6 +200,7 @@ def create_mismatch_report( kwargs=kwargs, expected=expected, actual=actual, + diff="\n".join(diff), error_stack=error_stack, ) From 1aad280860dc2476747171330576822b1c6a009d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 8 Aug 2023 18:11:28 +0000 Subject: [PATCH 07/15] Display shapes --- .../torch_lib/error_reproduction.py | 21 ++++++++++++++++--- .../tests/function_libs/torch_lib/ops_test.py | 2 +- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/onnxscript/tests/function_libs/torch_lib/error_reproduction.py b/onnxscript/tests/function_libs/torch_lib/error_reproduction.py index 926128c8ab..3f6cd4b3c3 100644 --- a/onnxscript/tests/function_libs/torch_lib/error_reproduction.py +++ b/onnxscript/tests/function_libs/torch_lib/error_reproduction.py @@ -1,12 +1,12 @@ from __future__ import annotations +import difflib import pathlib import platform import sys import time import traceback from typing import Any, Mapping -import difflib import numpy as np import onnx @@ -97,6 +97,8 @@ ### Inputs +Shapes: `{input_shapes}` + ```python inputs = {inputs} kwargs = {kwargs} @@ -108,12 +110,16 @@ expected = {expected} ``` +Shape: `{expected_shape}` + ### Actual output ```python actual = {actual} ``` +Shape: `{actual_shape}` + ### Difference ```diff @@ -188,18 +194,27 @@ def create_mismatch_report( diff = difflib.unified_diff( str(actual).splitlines(), str(expected).splitlines(), - tofile="actual", - fromfile="expected", + fromfile="actual", + tofile="expected", lineterm="", ) + input_shapes = repr( + [ + f"Tensor<{inp.shape}, dtype={inp.dtype}>" if isinstance(inp, torch.Tensor) else inp + for inp in inputs + ] + ) markdown = _MISMATCH_MARKDOWN_TEMPLATE.format( test_name=test_name, short_test_name=short_test_name, sample_num=sample_num, + input_shapes=input_shapes, inputs=inputs, kwargs=kwargs, expected=expected, + expected_shape=expected.shape if isinstance(expected, torch.Tensor) else None, actual=actual, + actual_shape=actual.shape if isinstance(actual, torch.Tensor) else None, diff="\n".join(diff), error_stack=error_stack, ) diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test.py b/onnxscript/tests/function_libs/torch_lib/ops_test.py index fbb3bfb73f..76c39b8c73 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test.py @@ -267,7 +267,7 @@ def run_test_output_match( except AssertionError as e: if os.environ.get("CREATE_REPRODUCTION_REPORT") == "1": error_reproduction.create_mismatch_report( - test_name, j, inputs, cpu_sample.kwargs, actual, expected, e + test_name, i, inputs, cpu_sample.kwargs, actual, expected, e ) if len(flattened_torch_outputs) > 1: raise AssertionError(f"Output {j} mismatch") from e From acbec9497297944877cc3b5e475072212f420860 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 8 Aug 2023 18:45:17 +0000 Subject: [PATCH 08/15] Skip tests --- .../function_libs/torch_lib/ops_test_data.py | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 50ead1ce37..8f428787d9 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -633,9 +633,21 @@ def _where_input_wrangler( matcher=lambda sample: sample.kwargs.get("rounding_mode") is not None, reason="this variation does not take the rounding_mode argument", ), - TorchLibOpInfo("div_mode", core_ops.aten_div_mode, trace_only=True).skip( + TorchLibOpInfo("div_mode", core_ops.aten_div_mode, trace_only=True) + .skip( matcher=lambda sample: sample.kwargs.get("rounding_mode") is None, reason="this variation requires the rounding_mode argument", + ) + .xfail( + variant_name="trunc_rounding", + dtypes=(torch.float16,), + reason="fixme: off-by-one. https://github.com/microsoft/onnxscript/issues/990", + ) + .xfail( + variant_name="floor_rounding", + dtypes=(torch.float16,), + test_class_name="TestOutputConsistencyEager", + reason="fixme: off-by-one and inverted inf. https://github.com/microsoft/onnxscript/issues/989", ), TorchLibOpInfo("dot", core_ops.aten_dot), TorchLibOpInfo( @@ -659,7 +671,11 @@ def _where_input_wrangler( TorchLibOpInfo("fill", core_ops.aten_fill), TorchLibOpInfo("flip", core_ops.aten_flip, input_wrangler=_flip_input_wrangler), TorchLibOpInfo("floor", core_ops.aten_floor), - TorchLibOpInfo("floor_divide", core_ops.aten_floor_divide), + TorchLibOpInfo("floor_divide", core_ops.aten_floor_divide).xfail( + dtypes=(torch.float16,), + test_class_name="TestOutputConsistencyEager", + reason="fixme: off-by-one issue due to numerical precision. https://github.com/microsoft/onnxscript/issues/989", + ), TorchLibOpInfo("fmod", core_ops.aten_fmod), TorchLibOpInfo("full", core_ops.aten_full), TorchLibOpInfo( From 17117228457169c6d55daa6c958ab1c7a0c5b8eb Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 8 Aug 2023 18:54:58 +0000 Subject: [PATCH 09/15] IS_MACOS --- onnxscript/tests/function_libs/torch_lib/ops_test_common.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_common.py b/onnxscript/tests/function_libs/torch_lib/ops_test_common.py index 7a2e5b4040..c5c51a18cc 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_common.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_common.py @@ -8,6 +8,7 @@ import multiprocessing import os import pprint +import sys import unittest import warnings from typing import ( @@ -57,6 +58,7 @@ TEST_OPSET_VERSION = 18 IS_WINDOWS = os.name == "nt" +IS_MACOS = sys.platform == "darwin" @dataclasses.dataclass From 8af4056ff75de987369c25f93fbeb50acdda6b74 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 8 Aug 2023 18:56:57 +0000 Subject: [PATCH 10/15] Skip MacOS --- onnxscript/tests/function_libs/torch_lib/ops_test_data.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 8f428787d9..412a037465 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -641,6 +641,8 @@ def _where_input_wrangler( .xfail( variant_name="trunc_rounding", dtypes=(torch.float16,), + # Numbers match on MacOS + enabled_if=not ops_test_common.IS_MACOS, reason="fixme: off-by-one. https://github.com/microsoft/onnxscript/issues/990", ) .xfail( From 761ed779a5077a8697ddd6c22e9fae6db4cd2277 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 8 Aug 2023 19:04:59 +0000 Subject: [PATCH 11/15] mac --- onnxscript/tests/function_libs/torch_lib/ops_test_data.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 412a037465..f3f6a3b589 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -641,7 +641,13 @@ def _where_input_wrangler( .xfail( variant_name="trunc_rounding", dtypes=(torch.float16,), - # Numbers match on MacOS + enabled_if=not ops_test_common.IS_MACOS, + reason="fixme: off-by-one. https://github.com/microsoft/onnxscript/issues/990", + ) + .skip( + variant_name="trunc_rounding", + dtypes=(torch.float16,), + # Numbers match on MacOS sometimes but not other times enabled_if=not ops_test_common.IS_MACOS, reason="fixme: off-by-one. https://github.com/microsoft/onnxscript/issues/990", ) From 35410ef24b54823fd494ee6fa9624389cccf20ff Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 8 Aug 2023 19:05:18 +0000 Subject: [PATCH 12/15] macos --- onnxscript/tests/function_libs/torch_lib/ops_test_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index f3f6a3b589..99187a86fb 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -648,7 +648,7 @@ def _where_input_wrangler( variant_name="trunc_rounding", dtypes=(torch.float16,), # Numbers match on MacOS sometimes but not other times - enabled_if=not ops_test_common.IS_MACOS, + enabled_if=ops_test_common.IS_MACOS, reason="fixme: off-by-one. https://github.com/microsoft/onnxscript/issues/990", ) .xfail( From 0657ab39388388ce19ae6abe691deb8866097861 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 8 Aug 2023 19:23:47 +0000 Subject: [PATCH 13/15] skip --- .../tests/function_libs/torch_lib/ops_test_common.py | 2 -- .../tests/function_libs/torch_lib/ops_test_data.py | 9 +-------- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_common.py b/onnxscript/tests/function_libs/torch_lib/ops_test_common.py index c5c51a18cc..7a2e5b4040 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_common.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_common.py @@ -8,7 +8,6 @@ import multiprocessing import os import pprint -import sys import unittest import warnings from typing import ( @@ -58,7 +57,6 @@ TEST_OPSET_VERSION = 18 IS_WINDOWS = os.name == "nt" -IS_MACOS = sys.platform == "darwin" @dataclasses.dataclass diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 99187a86fb..3198381108 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -647,15 +647,8 @@ def _where_input_wrangler( .skip( variant_name="trunc_rounding", dtypes=(torch.float16,), - # Numbers match on MacOS sometimes but not other times - enabled_if=ops_test_common.IS_MACOS, + # Numbers match sometimes but not other times reason="fixme: off-by-one. https://github.com/microsoft/onnxscript/issues/990", - ) - .xfail( - variant_name="floor_rounding", - dtypes=(torch.float16,), - test_class_name="TestOutputConsistencyEager", - reason="fixme: off-by-one and inverted inf. https://github.com/microsoft/onnxscript/issues/989", ), TorchLibOpInfo("dot", core_ops.aten_dot), TorchLibOpInfo( From 597b3c81a52f1bb47f88a48f6d213c4611628a92 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 8 Aug 2023 19:28:10 +0000 Subject: [PATCH 14/15] IS_MACOS --- .../tests/function_libs/torch_lib/ops_test_data.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 3198381108..b54c3f97e0 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -638,17 +638,17 @@ def _where_input_wrangler( matcher=lambda sample: sample.kwargs.get("rounding_mode") is None, reason="this variation requires the rounding_mode argument", ) - .xfail( - variant_name="trunc_rounding", - dtypes=(torch.float16,), - enabled_if=not ops_test_common.IS_MACOS, - reason="fixme: off-by-one. https://github.com/microsoft/onnxscript/issues/990", - ) .skip( variant_name="trunc_rounding", dtypes=(torch.float16,), # Numbers match sometimes but not other times reason="fixme: off-by-one. https://github.com/microsoft/onnxscript/issues/990", + ) + .xfail( + variant_name="floor_rounding", + dtypes=(torch.float16,), + test_class_name="TestOutputConsistencyEager", + reason="fixme: off-by-one and inverted inf. https://github.com/microsoft/onnxscript/issues/989", ), TorchLibOpInfo("dot", core_ops.aten_dot), TorchLibOpInfo( From ea0158ce5d6fd7b264edc11888aeec82e568382a Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 8 Aug 2023 19:30:36 +0000 Subject: [PATCH 15/15] div_mode --- onnxscript/tests/function_libs/torch_lib/ops_test_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index b54c3f97e0..f8996e416e 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -635,7 +635,7 @@ def _where_input_wrangler( ), TorchLibOpInfo("div_mode", core_ops.aten_div_mode, trace_only=True) .skip( - matcher=lambda sample: sample.kwargs.get("rounding_mode") is None, + variant_name="no_rounding_mode", reason="this variation requires the rounding_mode argument", ) .skip(