Skip to content

Commit 6c5f9a8

Browse files
authored
feat(test): use ort 1.14 (#435)
Use ort 1.14 in test Fix errors in logsumexp and upsample_nearest2d and enable their tests Parameterized and refactored the backend tests.
1 parent 7cdd248 commit 6c5f9a8

File tree

12 files changed

+366
-322
lines changed

12 files changed

+366
-322
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,6 @@ dmypy.json
9494
onnxscript/tests/models/testoutputs/*
9595
docs/auto_examples/*
9696
onnxscript/tests/mylib.onnxlib
97-
onnxscript/tests/onnx_backend_test_code
97+
**/onnx_backend_test_code/**
9898
*.onnxlib
9999
*.onnx

noxfile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
"pyyaml",
2929
)
3030
ONNX = "onnx==1.13"
31-
ONNX_RUNTIME = "onnxruntime==1.13.1"
31+
ONNX_RUNTIME = "onnxruntime==1.14"
3232
PYTORCH = "torch==1.13"
3333

3434

onnxscript/backend/onnx_backend.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import os
88
import textwrap
9+
from typing import Iterator
910

1011
import numpy
1112
import onnx
@@ -274,7 +275,7 @@ def to_python(self):
274275
return final
275276

276277

277-
def enumerate_onnx_tests(series, fct_filter=None):
278+
def enumerate_onnx_tests(series, fct_filter=None) -> Iterator[OnnxBackendTest]:
278279
"""Collects test from a sub folder of `onnx/backend/test`.
279280
Works as an enumerator to start processing them
280281
without waiting or storing too much of them.
@@ -285,7 +286,7 @@ def enumerate_onnx_tests(series, fct_filter=None):
285286
fct_filter: function `lambda testname: boolean` to load or skip
286287
the test, None for all
287288
288-
Returns:
289+
Yields:
289290
list of @see cl OnnxBackendTest
290291
"""
291292
root = os.path.dirname(backend_folder)
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# -------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License.
4+
# --------------------------------------------------------------------------
5+
6+
import os
7+
import unittest
8+
9+
import onnxruntime as ort
10+
11+
from onnxscript.backend import onnx_backend
12+
13+
14+
def load_function(obj):
15+
return ort.InferenceSession(obj.SerializeToString())
16+
17+
18+
def run_function(obj, *inputs):
19+
names = [i.name for i in obj.get_inputs()]
20+
if len(names) < len(inputs):
21+
raise AssertionError(f"Got {len(inputs)} inputs but expecting {len(names)}.")
22+
feeds = {names[i]: inputs[i] for i in range(len(inputs))}
23+
got = obj.run(None, feeds)
24+
return got
25+
26+
27+
class TestOnnxBackEnd(unittest.TestCase):
28+
29+
folder = os.path.join(os.path.abspath(os.path.dirname(__file__)), "onnx_backend_test_code")
30+
31+
def test_enumerate_onnx_tests(self):
32+
name = "test_abs"
33+
code = list(onnx_backend.enumerate_onnx_tests("node", lambda folder: folder == name))
34+
self.assertEqual(len(code), 1)
35+
36+
def test_enumerate_onnx_tests_run_one(self):
37+
done = 0
38+
for backend_test in onnx_backend.enumerate_onnx_tests(
39+
"node", lambda folder: folder == "test_abs"
40+
):
41+
self.assertIn(backend_test.name, repr(backend_test))
42+
self.assertGreater(len(backend_test), 0)
43+
backend_test.run(load_function, run_function)
44+
done += 1
45+
self.assertEqual(done, 1)
46+
47+
48+
if __name__ == "__main__":
49+
unittest.main(verbosity=2)
Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
# -------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License.
4+
# --------------------------------------------------------------------------
5+
from __future__ import annotations
6+
7+
import dataclasses
8+
import importlib
9+
import pathlib
10+
import re
11+
import unittest
12+
from typing import Pattern
13+
14+
import onnxruntime as ort
15+
import parameterized
16+
from onnxruntime.capi import onnxruntime_pybind11_state
17+
18+
import onnxscript
19+
from onnxscript.backend import onnx_backend, onnx_export
20+
from onnxscript.tests.models import type_double
21+
22+
23+
@dataclasses.dataclass
24+
class SkipInfo:
25+
pattern: Pattern
26+
reason: str
27+
condition: bool
28+
29+
30+
def skip(pattern: str | Pattern, reason: str, *, condition: bool = True):
31+
"""Create a SkipInfo object.
32+
33+
Args:
34+
pattern: A string or a regular expression to match the ONNX backend test name.
35+
reason: The reason why the test is skipped.
36+
condition: If False, the test is not skipped.
37+
"""
38+
if isinstance(pattern, str):
39+
pattern = re.compile(pattern)
40+
41+
return SkipInfo(pattern, reason, condition)
42+
43+
44+
SKIP_TESTS = (
45+
skip(r"_scan_", "Operator Scan is not supported by onnx-script"),
46+
skip(r"^test_scan", "Operator Scan is not supported by onnx-script"),
47+
skip(
48+
r"^test_split",
49+
"split has an undefined number of outputs. Current implementation of eager mode is not aware of them",
50+
),
51+
skip(
52+
r"^test_lstm_defaults",
53+
"LSTM has an undefined number of outputs. Current implementation of eager mode is not aware of them",
54+
),
55+
skip(
56+
r"^test_lstm_with_initial_bias",
57+
"LSTM has an undefined number of outputs. Current implementation of eager mode is not aware of them",
58+
),
59+
skip(
60+
r"^test_lstm_with_peepholes",
61+
"LSTM has an undefined number of outputs. Current implementation of eager mode is not aware of them",
62+
),
63+
skip(
64+
r"^test_optional_get_element_tensor",
65+
"ORT Unable to create onnxruntime InferenceSession for executing .OptionalGetElement op with onnx model",
66+
condition=ort.__version__ == "1.14.0",
67+
),
68+
skip(
69+
r"test_loop",
70+
"Change when the converter supports support something like 'while i < n and cond:'",
71+
),
72+
skip(
73+
r"^test_range_float_type_positive_delta_expanded",
74+
"Change when the converter supports support something like 'while i < n and cond:'",
75+
),
76+
skip(
77+
r"^test_range_int32_type_negative_delta_expanded",
78+
"Change when the converter supports support something like 'while i < n and cond:'",
79+
),
80+
)
81+
82+
83+
def load_function(obj):
84+
return ort.InferenceSession(obj.SerializeToString())
85+
86+
87+
def run_function(obj, *inputs):
88+
names = [i.name for i in obj.get_inputs()]
89+
if len(names) < len(inputs):
90+
raise AssertionError(f"Got {len(inputs)} inputs but expecting {len(names)}.")
91+
feeds = {names[i]: inputs[i] for i in range(len(inputs))}
92+
got = obj.run(None, feeds)
93+
return got
94+
95+
96+
def extract_functions(name: str, content: str, test_folder: pathlib.Path):
97+
"""Write the content into a file and import all OnnxFunctions from it."""
98+
if not test_folder.exists():
99+
test_folder.mkdir()
100+
init = test_folder / "__init__.py"
101+
init.touch()
102+
file = test_folder / f"{name}.py"
103+
file.write_text(content, encoding="utf-8")
104+
105+
import_name = f"onnxscript.tests.{test_folder.parts[-1]}.{name}"
106+
try:
107+
mod = importlib.import_module(import_name)
108+
except (SyntaxError, ImportError) as e:
109+
raise AssertionError(
110+
f"Unable to import {import_name!r} (file: {file!r})\n----\n{content}"
111+
) from e
112+
functions = {
113+
k: v for k, v in mod.__dict__.items() if isinstance(v, onnxscript.OnnxFunction)
114+
}
115+
return functions
116+
117+
118+
def exec_main(f, *inputs):
119+
output = f(*inputs)
120+
if isinstance(output, tuple):
121+
return list(output)
122+
return [output]
123+
124+
125+
class TestOnnxBackEnd(unittest.TestCase):
126+
127+
test_folder = pathlib.Path(__file__).parent.parent / "tests" / "onnx_backend_test_code"
128+
129+
def test_export2python(self):
130+
proto = type_double.double_abs_subgraph.to_model_proto()
131+
code = onnx_export.export2python(proto, rename=True, use_operators=True)
132+
self.assertIn("v4 = v2 > v1", code)
133+
134+
@parameterized.parameterized.expand( # type: ignore[misc]
135+
[
136+
(backend_test.name, backend_test)
137+
for backend_test in onnx_backend.enumerate_onnx_tests("node")
138+
]
139+
)
140+
def test_export2python_produces_correct_onnx_script_model(
141+
self, _: str, backend_test: onnx_backend.OnnxBackendTest
142+
):
143+
for skip_info in SKIP_TESTS:
144+
if skip_info.pattern.match(backend_test.name) and skip_info.condition:
145+
self.skipTest(skip_info.reason)
146+
147+
self.assertIn(backend_test.name, repr(backend_test))
148+
self.assertGreater(len(backend_test), 0)
149+
try:
150+
backend_test.run(load_function, run_function)
151+
except NotImplementedError as e:
152+
self.skipTest(f"Not implemented {e}")
153+
except (
154+
IndexError,
155+
RuntimeError,
156+
TypeError,
157+
ValueError,
158+
AttributeError,
159+
onnxruntime_pybind11_state.Fail, # pylint: disable=c-extension-no-member
160+
onnxruntime_pybind11_state.NotImplemented, # pylint: disable=c-extension-no-member
161+
onnxruntime_pybind11_state.InvalidArgument, # pylint: disable=c-extension-no-member
162+
) as e:
163+
self.skipTest(f"Unable to load the model: {e}")
164+
except onnxruntime_pybind11_state.RuntimeException as e: # pylint: disable=c-extension-no-member
165+
self.skipTest(f"Unable to run the model: {e}")
166+
except AssertionError as e:
167+
self.skipTest(f"ORT result mismatches with the expected: {e}")
168+
169+
code = onnx_export.export2python(
170+
backend_test.onnx_model, function_name=f"bck_{backend_test.name}"
171+
)
172+
self.assertIn("@script()", code)
173+
self.assertIn(f"def bck_{backend_test.name}(", code)
174+
175+
if backend_test.name == "test_resize_downsample_scales_cubic":
176+
self.assertIn("Resize(X, None, scales,", code)
177+
178+
functions = extract_functions(backend_test.name, code, self.test_folder)
179+
main_function = functions[f"bck_{backend_test.name}"]
180+
self.assertIsNotNone(main_function)
181+
proto = main_function.to_model_proto()
182+
183+
# Opset may be different when an binary operator is used.
184+
if backend_test.onnx_model.ir_version != proto.ir_version:
185+
if (
186+
not backend_test.name.startswith( # pylint: disable=too-many-boolean-expressions
187+
"test_add"
188+
)
189+
and not backend_test.name.startswith("test_and")
190+
and not backend_test.name.startswith("test_div")
191+
and not backend_test.name.startswith("test_equal")
192+
and not backend_test.name.startswith("test_greater")
193+
and not backend_test.name.startswith("test_less")
194+
and not backend_test.name.startswith("test_matmul")
195+
and not backend_test.name.startswith("test_mod")
196+
and not backend_test.name.startswith("test_mul")
197+
and not backend_test.name.startswith("test_not")
198+
and not backend_test.name.startswith("test_or")
199+
and not backend_test.name.startswith("test_pow")
200+
and not backend_test.name.startswith("test_sub")
201+
and (backend_test.onnx_model.ir_version, proto.ir_version)
202+
not in {(3, 4), (5, 6)}
203+
):
204+
# Unexpected behavior for old opsets
205+
raise AssertionError(
206+
f"Incompatible ir_version {(backend_test.onnx_model.ir_version)} !="
207+
f" {(proto.ir_version)}\n"
208+
f"{backend_test.onnx_model}\n"
209+
f"-----\n"
210+
f"{proto}"
211+
)
212+
213+
try:
214+
session = ort.InferenceSession(proto.SerializeToString())
215+
except Exception as e:
216+
raise AssertionError(
217+
f"Unable to load onnx for test {backend_test.name!r}.\n"
218+
f"{onnxscript.proto2text(proto)}\n"
219+
f"-----\n"
220+
f"{backend_test.onnx_model}"
221+
) from e
222+
223+
# Check converted onnx
224+
def _load_function(_):
225+
return session
226+
227+
def _run_function(obj, *inputs):
228+
print(" run ONNX")
229+
for i, inp in enumerate(inputs):
230+
if inp is None:
231+
print(f" input {i}: None")
232+
else:
233+
print(
234+
f" input {i}: "
235+
f"dtype={inp.dtype!r} shape={inp.shape!r}"
236+
f"{inp.ravel().tolist()!r}"
237+
)
238+
try:
239+
return run_function(obj, *inputs)
240+
except Exception as e:
241+
raise AssertionError(
242+
f"Unable to run test {backend_test.name!r} after conversion.\n"
243+
f"{onnxscript.proto2text(proto)}"
244+
) from e
245+
246+
backend_test.run(_load_function, _run_function)
247+
248+
# Check eager mode
249+
backend_test.run(lambda _: main_function, exec_main)
250+
251+
252+
if __name__ == "__main__":
253+
unittest.main(verbosity=2)

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3072,7 +3072,12 @@ def aten_logspace(start: float, end: float, steps: int, base: float = 10.0) -> T
30723072
def aten_logsumexp(self: TReal, dim: INT64, keepdim: int = False) -> TReal:
30733073
"""logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor"""
30743074

3075-
return op.ReduceLogSumExp(self, dim, keepdims=keepdim)
3075+
if op.Size(op.Shape(self)) == 0:
3076+
# A scalar
3077+
result = self
3078+
else:
3079+
result = op.ReduceLogSumExp(self, dim, keepdims=keepdim)
3080+
return result
30763081

30773082

30783083
def aten_lshift(self: TensorType, other: TensorType) -> TensorType:

0 commit comments

Comments
 (0)