from typing import Any, Callable, Collection, Iterable, Optional, Sequence, TypeVar

import numpy as np
+ import onnx
import onnxruntime.capi.onnxruntime_pybind11_state
+ import parameterized
import torch
from torch.testing._internal import common_device_type, common_methods_invocations
from torch.testing._internal.opinfo import core as opinfo_core
@@ -69,14 +71,15 @@ class DecorateMeta:
    decorator: Callable[..., Any]
    dtypes: Optional[Collection[torch.dtype]]
    reason: str
+     matcher: Optional[Callable[[Any], bool]] = None


def xfail(
    op_name: str,
    variant_name: str = "",
    *,
+     reason: str,
    dtypes: Optional[Collection[torch.dtype]] = None,
-     reason: Optional[str] = None,
):
    """Expects an OpInfo test to fail.

@@ -86,8 +89,6 @@ def xfail(
        dtypes: The dtypes to expect the failure.
        reason: The reason for the failure.
    """
-     if reason is None:
-         raise ValueError("Please specify a reason.")
    return DecorateMeta(
        op_name=op_name,
        variant_name=variant_name,
@@ -101,8 +102,9 @@ def skip(
    op_name: str,
    variant_name: str = "",
    *,
+     reason: str,
    dtypes: Optional[Collection[torch.dtype]] = None,
-     reason: Optional[str] = None,
+     matcher: Optional[Callable[[Any], Any]] = None,
):
    """Skips an OpInfo test.

@@ -111,15 +113,16 @@ def skip(
        variant_name: Optional OpInfo variant_test_name.
        dtypes: The dtypes to skip.
        reason: The reason for skipping.
+         matcher: A function that matches the test sample input. It is used only when
+             the skip is in the SKIP_SUBTESTS list.
    """
-     if reason is None:
-         raise ValueError("Please specify a reason.")
    return DecorateMeta(
        op_name=op_name,
        variant_name=variant_name,
        decorator=unittest.skip(f"Don't care: {reason}"),
        dtypes=dtypes,
        reason=reason,
+         matcher=matcher,
    )

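With this change, reason becomes a required keyword-only argument for both xfail and skip, so omitting it now fails at the call site with a TypeError from Python rather than through the removed ValueError check. Typical calls, taken from the tables added below:

    xfail("add", dtypes=BOOL_TYPES, reason="Add is not defined on bool tensors")
    skip(
        "ones_like",
        reason="dtype must be provided for aten_ones_like_dtype",
        matcher=lambda sample: "dtype" not in sample.kwargs,
    )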
@@ -159,17 +162,28 @@ def wrapped(fn):
# Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
OPINFO_FUNCTION_MAPPING: dict[str, Callable[..., Any]] = {
    "add": core_ops.aten_add,
+     "clamp_max": core_ops.aten_clamp_max_tensor,
+     "clamp_min": core_ops.aten_clamp_min_tensor,
+     "gt": core_ops.aten_gt,
+     "lt": core_ops.aten_lt,
    "mul": core_ops.aten_mul,
    "nn.functional.elu": nn_ops.aten_elu,
    "nn.functional.relu6": nn_ops.aten_relu6,
    "nn.functional.selu": core_ops.aten_selu,
+     "ones_like": core_ops.aten_ones_like_dtype,
+     "repeat": core_ops.aten_repeat,
+     "round": core_ops.aten_round,
    "sub": core_ops.aten_sub,
}

TESTED_OPS = frozenset(OPINFO_FUNCTION_MAPPING)

EXPECTED_SKIPS_OR_FAILS = (
    xfail("add", dtypes=BOOL_TYPES, reason="Add is not defined on bool tensors"),
+     xfail("clamp_max", dtypes=BOOL_TYPES, reason="Min is not defined on bool tensors"),
+     xfail("clamp_min", dtypes=BOOL_TYPES, reason="Max is not defined on bool tensors"),
+     xfail("gt", dtypes=BOOL_TYPES, reason="Greater is not defined on bool tensors"),
+     xfail("lt", dtypes=BOOL_TYPES, reason="Less is not defined on bool tensors"),
    xfail("mul", dtypes=BOOL_TYPES, reason="Mul is not defined on bool tensors"),
    xfail(
        "nn.functional.elu",
@@ -186,14 +200,123 @@ def wrapped(fn):
        dtypes=dtypes_except(torch.float16, torch.float32),
        reason="ONNX Runtime doesn't support float64 for Selu",
    ),
+     xfail(
+         "round",
+         variant_name="",
+         dtypes=dtypes_except(*FLOAT_TYPES),
+         reason="Round is not defined on non-float tensors",
+     ),
+     xfail("round", variant_name="decimals_0", reason="The ATen op does not support decimals"),
+     xfail("round", variant_name="decimals_3", reason="The ATen op does not support decimals"),
+     xfail(
+         "round", variant_name="decimals_neg_3", reason="The ATen op does not support decimals"
+     ),
    xfail("sub", dtypes=BOOL_TYPES, reason="Sub is not defined on bool tensors"),
)
+ 
+ 
+ SKIP_SUBTESTS = (
+     skip(
+         "clamp_max",
+         reason="Empty tensor not yet supported",
+         matcher=lambda sample: sample.input.size() == torch.Size([0]),
+     ),
+     skip(
+         "clamp_min",
+         reason="Empty tensor not yet supported",
+         matcher=lambda sample: sample.input.size() == torch.Size([0]),
+     ),
+     skip(
+         "repeat",
+         reason="repeating when input is a scalar and repeats is empty is not supported",
+         matcher=lambda sample: sample.args[0] == (),
+     ),
+     skip(
+         "ones_like",
+         # TODO(justinchuby): Test aten_ones_like
+         reason="dtype must be provided for aten_ones_like_dtype",
+         matcher=lambda sample: "dtype" not in sample.kwargs,
+     ),
+ )
+ OP_WITH_SKIPPED_SUBTESTS = frozenset(meta.op_name for meta in SKIP_SUBTESTS)
+ 
# END OF SECTION TO MODIFY #####################################################


OPS_DB = copy.deepcopy(common_methods_invocations.op_db)

+ TORCH_TYPE_TO_ONNX = {
+     torch.bool: onnx.TensorProto.BOOL,
+     torch.uint8: onnx.TensorProto.UINT8,
+     torch.int8: onnx.TensorProto.INT8,
+     torch.int16: onnx.TensorProto.INT16,
+     torch.int32: onnx.TensorProto.INT32,
+     torch.int64: onnx.TensorProto.INT64,
+     torch.float16: onnx.TensorProto.FLOAT16,
+     torch.float32: onnx.TensorProto.FLOAT,
+     torch.float64: onnx.TensorProto.DOUBLE,
+     torch.complex64: onnx.TensorProto.COMPLEX64,
+     torch.complex128: onnx.TensorProto.COMPLEX128,
+     torch.bfloat16: onnx.TensorProto.BFLOAT16,
+ }
+ 
+ 
+ class TestFunctionsCompilation(unittest.TestCase):
+     """Test all functions can be compiled."""
+ 
+     @parameterized.parameterized.expand(
+         list(OPINFO_FUNCTION_MAPPING.items()),
+     )
+     def test_function_compiles(self, _, function):
+         compiled = onnxscript.script()(function)
+         compiled.to_function_proto()
+ 
+ 
+ def _convert_tensor_to_numpy(input: Any) -> Any:
+     if isinstance(input, torch.Tensor):
+         return input.detach().cpu().numpy()
+     if isinstance(input, (tuple, list)):
+         if len(input) == 0:
+             return np.array((), dtype=np.int64)
+         if isinstance(input[0], torch.Tensor):
+             return [_convert_tensor_to_numpy(x) for x in input]
+         if isinstance(input[0], (int, float)):
+             # Just a tuple of numbers
+             return np.array(input)
+         return input
+ 
+     return input
+ 
+ 
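For reference, a rough sketch of what _convert_tensor_to_numpy returns for the kinds of values that appear in OpInfo samples (the inputs here are illustrative, not taken from the test suite):

    _convert_tensor_to_numpy(torch.tensor([1.0, 2.0]))            # numpy array([1., 2.], dtype=float32)
    _convert_tensor_to_numpy((3, 4))                              # np.array([3, 4]), a plain tuple of numbers
    _convert_tensor_to_numpy(())                                  # np.array((), dtype=np.int64); empty sequences default to int64
    _convert_tensor_to_numpy([torch.tensor(1), torch.tensor(2)])  # [array(1), array(2)]
    _convert_tensor_to_numpy("cpu")                               # anything else is returned unchanged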
+ def _convert_kwargs_for_onnx(kwargs: dict[str, Any]) -> dict[str, Any]:
+     """Converts kwargs to be compatible with ONNX Runtime.
+ 
+     Drops the "device" argument and maps torch dtypes to ONNX TensorProto types.
+     """
+     new_kwargs = {}
+     for key, value in kwargs.items():
+         if key == "device":
+             continue
+         if key == "dtype":
+             value = TORCH_TYPE_TO_ONNX[value]
+         new_kwargs[key] = value
+     return new_kwargs
+ 
+ 
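For example, a sample's kwargs of the form {"dtype": torch.float32, "device": "cpu"} (an illustrative value) comes out with the device entry dropped and the torch dtype mapped through TORCH_TYPE_TO_ONNX; every other kwarg passes through untouched:

    kwargs_onnx = _convert_kwargs_for_onnx({"dtype": torch.float32, "device": "cpu"})
    # kwargs_onnx == {"dtype": onnx.TensorProto.FLOAT}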
+ def _should_skip_test_sample(op_name: str, sample) -> Optional[str]:
+     """Returns a reason if a test sample should be skipped."""
+     if op_name not in OP_WITH_SKIPPED_SUBTESTS:
+         return None
+     for decorator_meta in SKIP_SUBTESTS:
+         # Linear search on SKIP_SUBTESTS. That's fine because the list is small.
+         if decorator_meta.op_name == op_name:
+             assert decorator_meta.matcher is not None, "Matcher must be defined"
+             if decorator_meta.matcher(sample):
+                 return decorator_meta.reason
+     return None
+ 
+ 
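Given the SKIP_SUBTESTS table above, the helper reports a reason only when the op has an entry and its matcher fires; for instance (sample stands for an OpInfo sample input and is illustrative):

    _should_skip_test_sample("repeat", sample)  # the skip reason string when sample.args[0] == (), otherwise None
    _should_skip_test_sample("add", sample)     # always None; "add" has no skipped subtests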
class TestOutputConsistency(unittest.TestCase):
    """Test output consistency between exported ONNX models and PyTorch eager mode.

@@ -236,10 +359,14 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
                inputs=repr(inputs),
                kwargs=repr(cpu_sample.kwargs),
            ):
-                 input_numpy = [x.numpy() for x in inputs if isinstance(x, torch.Tensor)]
-                 torch_output = op(*inputs, **cpu_sample.kwargs)
+                 skip_reason = _should_skip_test_sample(op.name, cpu_sample)
+                 if skip_reason is not None:
+                     self.skipTest(skip_reason)
+                 input_onnx = [_convert_tensor_to_numpy(x) for x in inputs]
+                 kwargs_onnx = _convert_kwargs_for_onnx(cpu_sample.kwargs)
+                 output_torch = op(*inputs, **cpu_sample.kwargs)
                try:
-                     function_output = scripted_function(*input_numpy, **cpu_sample.kwargs)
+                     function_output = scripted_function(*input_onnx, **kwargs_onnx)
                # pylint: disable=c-extension-no-member
                except onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented:
                    self.skipTest(
@@ -250,7 +377,7 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
                # Use torch testing to ensure dtypes and shapes match
                torch.testing.assert_close(
                    torch.tensor(function_output),
-                     torch_output,
+                     output_torch,
                )

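Read together, the per-sample flow inside test_output_match now looks roughly like the sketch below (simplified; scripted_function, inputs, and cpu_sample are bound by the surrounding test loop, which is outside this diff, and the NotImplemented handling is omitted):

    skip_reason = _should_skip_test_sample(op.name, cpu_sample)
    if skip_reason is not None:
        self.skipTest(skip_reason)
    input_onnx = [_convert_tensor_to_numpy(x) for x in inputs]
    kwargs_onnx = _convert_kwargs_for_onnx(cpu_sample.kwargs)
    output_torch = op(*inputs, **cpu_sample.kwargs)
    function_output = scripted_function(*input_onnx, **kwargs_onnx)
    torch.testing.assert_close(torch.tensor(function_output), output_torch)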