| 1 | +"""Test op correctness by comparing with PyTorch results.""" |
| 2 | +from __future__ import annotations |
| 3 | + |
| 4 | +import copy |
| 5 | +import dataclasses |
| 6 | +import unittest |
| 7 | +from typing import Callable, Collection, Iterable, Optional, Sequence |
| 8 | + |
| 9 | +import numpy as np |
| 10 | +import torch |
| 11 | +from torch.testing._internal import common_device_type, common_methods_invocations |
| 12 | +from torch.testing._internal.opinfo import core as opinfo_core |
| 13 | + |
| 14 | +import onnxscript |
from onnxscript.function_libs.torch_aten.ops import core as core_ops
| 16 | + |
| 17 | +SUPPORTED_DTYPES = ( |
| 18 | + # Boolean |
| 19 | + torch.bool, |
| 20 | + # Integers |
| 21 | + torch.uint8, |
| 22 | + torch.int8, |
| 23 | + torch.int16, |
| 24 | + torch.int32, |
| 25 | + torch.int64, |
| 26 | + # Floating types |
| 27 | + torch.float16, |
| 28 | + torch.float32, |
| 29 | + torch.float64, |
| 30 | +) |
| 31 | + |
| 32 | +# Convenience tuples for creating dtype lists when skipping or xfailing tests |
| 33 | + |
| 34 | +BOOL_TYPES = (torch.bool,) |
| 35 | + |
| 36 | +INT_TYPES = ( |
| 37 | + torch.int8, |
| 38 | + torch.int16, |
| 39 | + torch.int32, |
| 40 | + torch.int64, |
| 41 | + torch.uint8, |
| 42 | +) |
| 43 | + |
| 44 | +FLOAT_TYPES = ( |
| 45 | + torch.float16, |
| 46 | + torch.float32, |
| 47 | + torch.float64, |
| 48 | +) |
| 49 | + |
| 50 | + |
| 51 | +@dataclasses.dataclass |
| 52 | +class DecorateMeta: |
| 53 | + """A dataclass for storing information about a test case to skip or xfail. |
| 54 | +
|
| 55 | + Adapted from functorch: functorch/test/common_utils.py |
| 56 | + """ |
| 57 | + |
| 58 | + op_name: str |
| 59 | + variant_name: str |
| 60 | + decorator: Callable |
| 61 | + dtypes: Optional[Collection[torch.dtype]] |
| 62 | + reason: str |
| 63 | + |
| 64 | + |
def xfail(
    op_name: str,
    variant_name: str = "",
    *,
    dtypes: Optional[Collection[torch.dtype]] = None,
    reason: Optional[str] = None,
):
    """Expects an OpInfo test to fail.

    Args:
        op_name: The name of the operator.
        variant_name: Optional OpInfo variant_test_name.
        dtypes: The dtypes for which the test is expected to fail.
        reason: The reason for the failure.
    """
    if reason is None:
        raise ValueError("Please specify a reason.")
    return DecorateMeta(
        op_name=op_name,
        variant_name=variant_name,
        decorator=unittest.expectedFailure,
        dtypes=dtypes,
        reason=reason,
    )


def skip(
    op_name: str,
    variant_name: str = "",
    *,
    dtypes: Optional[Collection[torch.dtype]] = None,
    reason: Optional[str] = None,
):
    """Skips an OpInfo test.

    Args:
        op_name: The name of the operator.
        variant_name: Optional OpInfo variant_test_name.
        dtypes: The dtypes to skip.
        reason: The reason for skipping.
    """
    if reason is None:
        raise ValueError("Please specify a reason.")
    return DecorateMeta(
        op_name=op_name,
        variant_name=variant_name,
        decorator=unittest.skip(f"Don't care: {reason}"),
        dtypes=dtypes,
        reason=reason,
    )


def add_decorate_info(
    all_opinfos: Sequence[opinfo_core.OpInfo],
    test_class_name: str,
    base_test_name: str,
    skip_or_xfails: Iterable[DecorateMeta],
):
    """Decorates OpInfo tests with decorators based on the skip_or_xfails list."""
    ops_mapping = {(info.name, info.variant_test_name): info for info in all_opinfos}
    for decorate_meta in skip_or_xfails:
        opinfo = ops_mapping.get((decorate_meta.op_name, decorate_meta.variant_name))
        assert (
            opinfo is not None
        ), f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?"
        decorators = list(opinfo.decorators)
        new_decorator = opinfo_core.DecorateInfo(
            decorate_meta.decorator,
            test_class_name,
            base_test_name,
            dtypes=decorate_meta.dtypes,
        )
        decorators.append(new_decorator)
        opinfo.decorators = tuple(decorators)

    # This decorator doesn't modify fn in any way
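    # The decoration takes effect through the `opinfo.decorators` mutation above;
    # those decorators are applied when the device-type tests are instantiated.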
    def wrapped(fn):
        return fn

    return wrapped


# Modify this section ##########################################################

# Ops to be tested for numerical consistency between ONNX and PyTorch
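# Keys are OpInfo names from torch's op database; values are the onnxscript functions under test.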
OPINFO_FUNCTION_MAPPING = {
    "nn.functional.elu": core_ops.Elu,
    "nn.functional.relu": core_ops.Relu,
    "nn.functional.selu": core_ops.Selu,
}

TESTED_OPS = frozenset(OPINFO_FUNCTION_MAPPING)

EXPECTED_SKIPS_OR_FAILS = (
    xfail(
        "nn.functional.elu", dtypes=[torch.float64], reason="ORT does not support Elu float64"
    ),
)
# END OF SECTION TO MODIFY #####################################################


OPS_DB = copy.deepcopy(common_methods_invocations.op_db)


class TestOutputConsistency(unittest.TestCase):
    """Test output consistency between exported ONNX models and PyTorch eager mode.

    This is a parameterized test suite.
    """

    def setUp(self) -> None:
        torch.manual_seed(42)
        np.random.seed(42)

    @common_device_type.ops(
        [info for info in OPS_DB if info.name in TESTED_OPS],
        allowed_dtypes=SUPPORTED_DTYPES,
    )
    @add_decorate_info(
        OPS_DB,
        "TestOutputConsistency",
        "test_output_match",
        skip_or_xfails=EXPECTED_SKIPS_OR_FAILS,
    )
    def test_output_match(self, device: str, dtype: torch.dtype, op):
        """Base test method for testing each op, used by instantiate_device_type_tests."""
        # device is provided by instantiate_device_type_tests, but we only want to run on CPU.
        assert device == "cpu"

        samples = op.sample_inputs(
            device,
            dtype,
            requires_grad=False,
        )

        onnx_function = OPINFO_FUNCTION_MAPPING[op.name]
        scripted_function = onnxscript.script()(onnx_function)
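        # onnxscript.script() compiles the function into an ONNX function; calling it
        # below with numpy inputs runs it in eager mode (the default evaluator uses
        # ONNX Runtime).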

        for i, cpu_sample in enumerate(samples):
            inputs = (cpu_sample.input, *cpu_sample.args)
            # Provide the repr to subtest because tensors are not serializable in parallel test runs
            with self.subTest(
                sample_num=i,
                inputs=repr(inputs),
                kwargs=repr(cpu_sample.kwargs),
            ):
                input_numpy = [x.numpy() for x in inputs if isinstance(x, torch.Tensor)]
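                # Only torch.Tensor inputs are converted to numpy; non-tensor args are dropped.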
                function_output = scripted_function(*input_numpy, **cpu_sample.kwargs)
                torch_output = op(*inputs, **cpu_sample.kwargs)

                np.testing.assert_allclose(
                    function_output,
                    torch_output.numpy(),
                )


common_device_type.instantiate_device_type_tests(
    TestOutputConsistency, globals(), only_for="cpu"
)
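# instantiate_device_type_tests generates per-device test classes
# (e.g. TestOutputConsistencyCPU) in this module's globals.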


if __name__ == "__main__":
    unittest.main()