Skip to content

Commit 7e536ef

Browse files
committed
feat(atenlib): Create sample functions and tests
ghstack-source-id: 22d5615 Pull Request resolved: #208
1 parent 656999f commit 7e536ef

File tree

3 files changed

+231
-2
lines changed

3 files changed

+231
-2
lines changed

onnxscript/fuction_libs/torch_aten/ops/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ def decorator(func):
2222

2323

2424
# TODO: put this in nn
25-
@atenop("aten::relu6")
26-
def Relu6(self):
25+
@atenop("aten::relu")
26+
def Relu(self):
2727
zero = op.CastLike(0, self)
2828
return op.Max(self, zero)
2929

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
"""Test op correctness by comparing with PyTorch results."""
2+
from __future__ import annotations
3+
4+
import copy
5+
import dataclasses
6+
import unittest
7+
from typing import Callable, Collection, Iterable, Optional, Sequence
8+
9+
import numpy as np
10+
import torch
11+
from torch.testing._internal import common_device_type, common_methods_invocations
12+
from torch.testing._internal.opinfo import core as opinfo_core
13+
14+
import onnxscript
15+
from onnxscript.fuction_libs.torch_aten.ops import core as core_ops
16+
17+
SUPPORTED_DTYPES = (
18+
# Boolean
19+
torch.bool,
20+
# Integers
21+
torch.uint8,
22+
torch.int8,
23+
torch.int16,
24+
torch.int32,
25+
torch.int64,
26+
# Floating types
27+
torch.float16,
28+
torch.float32,
29+
torch.float64,
30+
)
31+
32+
# Convenience tuples for creating dtype lists when skipping or xfailing tests
33+
34+
BOOL_TYPES = (torch.bool,)
35+
36+
INT_TYPES = (
37+
torch.int8,
38+
torch.int16,
39+
torch.int32,
40+
torch.int64,
41+
torch.uint8,
42+
)
43+
44+
FLOAT_TYPES = (
45+
torch.float16,
46+
torch.float32,
47+
torch.float64,
48+
)
49+
50+
51+
@dataclasses.dataclass
52+
class DecorateMeta:
53+
"""A dataclass for storing information about a test case to skip or xfail.
54+
55+
Adapted from functorch: functorch/test/common_utils.py
56+
"""
57+
58+
op_name: str
59+
variant_name: str
60+
decorator: Callable
61+
dtypes: Optional[Collection[torch.dtype]]
62+
reason: str
63+
64+
65+
def xfail(
66+
op_name: str,
67+
variant_name: str = "",
68+
*,
69+
dtypes: Optional[Collection[torch.dtype]] = None,
70+
reason: Optional[str] = None,
71+
):
72+
"""Expects an OpInfo test to fail.
73+
74+
Args:
75+
op_name: The name of the operator.
76+
variant_name: Optional OpInfo variant_test_name.
77+
dtypes: The dtypes to expect the failure.
78+
reason: The reason for the failure.
79+
"""
80+
if reason is None:
81+
raise ValueError("Please specify a reason.")
82+
return DecorateMeta(
83+
op_name=op_name,
84+
variant_name=variant_name,
85+
decorator=unittest.expectedFailure,
86+
dtypes=dtypes,
87+
reason=reason,
88+
)
89+
90+
91+
def skip(
92+
op_name: str,
93+
variant_name: str = "",
94+
*,
95+
dtypes: Optional[Collection[torch.dtype]] = None,
96+
reason: Optional[str] = None,
97+
):
98+
"""Skips an OpInfo test.
99+
100+
Args:
101+
op_name: The name of the operator.
102+
variant_name: Optional OpInfo variant_test_name.
103+
dtypes: The dtypes to skip.
104+
reason: The reason for skipping.
105+
"""
106+
if reason is None:
107+
raise ValueError("Please specify a reason.")
108+
return DecorateMeta(
109+
op_name=op_name,
110+
variant_name=variant_name,
111+
decorator=unittest.skip(f"Don't care: {reason}"),
112+
dtypes=dtypes,
113+
reason=reason,
114+
)
115+
116+
117+
def add_decorate_info(
118+
all_opinfos: Sequence[opinfo_core.OpInfo],
119+
test_class_name: str,
120+
base_test_name: str,
121+
skip_or_xfails: Iterable[DecorateMeta],
122+
):
123+
"""Decorates OpInfo tests with decorators based on the skip_or_xfails list."""
124+
ops_mapping = {(info.name, info.variant_test_name): info for info in all_opinfos}
125+
for decorate_meta in skip_or_xfails:
126+
opinfo = ops_mapping.get((decorate_meta.op_name, decorate_meta.variant_name))
127+
assert (
128+
opinfo is not None
129+
), f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?"
130+
decorators = list(opinfo.decorators)
131+
new_decorator = opinfo_core.DecorateInfo(
132+
decorate_meta.decorator,
133+
test_class_name,
134+
base_test_name,
135+
dtypes=decorate_meta.dtypes,
136+
)
137+
decorators.append(new_decorator)
138+
opinfo.decorators = tuple(decorators)
139+
140+
# This decorator doesn't modify fn in any way
141+
def wrapped(fn):
142+
return fn
143+
144+
return wrapped
145+
146+
147+
# Modify this section ##########################################################
148+
149+
# Ops to be tested for numerical consistency between onnx and pytorch
150+
OPINFO_FUNCTION_MAPPING = {
151+
"nn.functional.elu": core_ops.Elu,
152+
"nn.functional.relu": core_ops.Relu,
153+
"nn.functional.selu": core_ops.Selu,
154+
}
155+
156+
TESTED_OPS = frozenset(OPINFO_FUNCTION_MAPPING)
157+
158+
EXPECTED_SKIPS_OR_FAILS = (
159+
xfail(
160+
"nn.functional.elu", dtypes=[torch.float64], reason="ORT does not support Elu float64"
161+
),
162+
)
163+
# END OF SECTION TO MODIFY #####################################################
164+
165+
166+
OPS_DB = copy.deepcopy(common_methods_invocations.op_db)
167+
168+
169+
class TestOutputConsistency(unittest.TestCase):
170+
"""Test output consistency between exported ONNX models and PyTorch eager mode.
171+
172+
This is a parameterized test suite.
173+
"""
174+
175+
def setUp(self) -> None:
176+
torch.manual_seed(42)
177+
np.random.seed(42)
178+
179+
@common_device_type.ops(
180+
[info for info in OPS_DB if info.name in TESTED_OPS],
181+
allowed_dtypes=SUPPORTED_DTYPES,
182+
)
183+
@add_decorate_info(
184+
OPS_DB,
185+
"TestOutputConsistency",
186+
"test_output_match",
187+
skip_or_xfails=EXPECTED_SKIPS_OR_FAILS,
188+
)
189+
def test_output_match(self, device: str, dtype: torch.dtype, op):
190+
"""Base test method for testing each opset, used by instantiate_device_type_tests."""
191+
# device is provided by instantiate_device_type_tests, but we only want to run in cpu.
192+
assert device == "cpu"
193+
194+
samples = op.sample_inputs(
195+
device,
196+
dtype,
197+
requires_grad=False,
198+
)
199+
200+
onnx_function = OPINFO_FUNCTION_MAPPING[op.name]
201+
scripted_function = onnxscript.script()(onnx_function)
202+
203+
for (i, cpu_sample) in enumerate(samples):
204+
inputs = (cpu_sample.input, *cpu_sample.args)
205+
# Provide the repr to subtest because tensors are not serializable in parallel test runs
206+
with self.subTest(
207+
sample_num=i,
208+
inputs=repr(inputs),
209+
kwargs=repr(cpu_sample.kwargs),
210+
):
211+
input_numpy = [x.numpy() for x in inputs if isinstance(x, torch.Tensor)]
212+
function_output = scripted_function(*input_numpy, **cpu_sample.kwargs)
213+
torch_output = op(*inputs, **cpu_sample.kwargs)
214+
215+
np.testing.assert_allclose(
216+
function_output,
217+
torch_output.numpy(),
218+
)
219+
220+
221+
common_device_type.instantiate_device_type_tests(
222+
TestOutputConsistency, globals(), only_for="cpu"
223+
)
224+
225+
226+
if __name__ == "__main__":
227+
unittest.main()

requirements-dev.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ pytest-subtests
1919
pytest-xdist
2020
parameterized
2121
torch
22+
expecttest
23+
pyyaml
2224

2325
# Lint
2426
lintrunner

0 commit comments

Comments
 (0)