 from typing import Any, Callable, Collection, Iterable, Optional, Sequence, TypeVar
 
 import numpy as np
+import onnx
 import onnxruntime.capi.onnxruntime_pybind11_state
+import parameterized
 import torch
 from torch.testing._internal import common_device_type, common_methods_invocations
 from torch.testing._internal.opinfo import core as opinfo_core
@@ -69,14 +71,15 @@ class DecorateMeta:
     decorator: Callable[..., Any]
     dtypes: Optional[Collection[torch.dtype]]
     reason: str
+    matcher: Optional[Callable[[Any], bool]] = None
 
 
 def xfail(
     op_name: str,
     variant_name: str = "",
     *,
+    reason: str,
     dtypes: Optional[Collection[torch.dtype]] = None,
-    reason: Optional[str] = None,
 ):
     """Expects an OpInfo test to fail.
 
@@ -86,8 +89,6 @@ def xfail(
         dtypes: The dtypes to expect the failure.
         reason: The reason for the failure.
     """
-    if reason is None:
-        raise ValueError("Please specify a reason.")
     return DecorateMeta(
         op_name=op_name,
         variant_name=variant_name,
@@ -101,8 +102,9 @@ def skip(
     op_name: str,
     variant_name: str = "",
     *,
+    reason: str,
     dtypes: Optional[Collection[torch.dtype]] = None,
-    reason: Optional[str] = None,
+    matcher: Optional[Callable[[Any], bool]] = None,
 ):
     """Skips an OpInfo test.
 
@@ -111,15 +113,16 @@ def skip(
         variant_name: Optional OpInfo variant_test_name.
         dtypes: The dtypes to skip.
         reason: The reason for skipping.
+        matcher: A function that matches the test sample input. It is used only when
+            the skip is in the SKIP_SUBTESTS list.
     """
-    if reason is None:
-        raise ValueError("Please specify a reason.")
     return DecorateMeta(
         op_name=op_name,
         variant_name=variant_name,
         decorator=unittest.skip(f"Don't care: {reason}"),
         dtypes=dtypes,
         reason=reason,
+        matcher=matcher,
     )
 
 
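For reference, the updated helpers take a keyword-only `reason`, and `skip` additionally accepts a `matcher` that inspects each OpInfo sample. A minimal sketch of call sites, mirroring entries added later in this diff:

```python
# `reason` is now required and keyword-only, so every entry documents itself.
xfail("add", dtypes=BOOL_TYPES, reason="Add is not defined on bool tensors")

# `matcher` receives a SampleInput and returns True when that subtest should be skipped.
skip(
    "repeat",
    reason="repeating when input is a scalar and repeats is empty is not supported",
    matcher=lambda sample: sample.args[0] == (),
)
```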
@@ -159,17 +162,29 @@ def wrapped(fn):
 # Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
 OPINFO_FUNCTION_MAPPING: dict[str, Callable[..., Any]] = {
     "add": core_ops.aten_add,
+    # "clamp": core_ops.aten_clamp,  # TODO(justinchuby): Enable
+    "clamp_max": core_ops.aten_clamp_max_tensor,
+    "clamp_min": core_ops.aten_clamp_min_tensor,
+    "gt": core_ops.aten_gt,
+    "lt": core_ops.aten_lt,
     "mul": core_ops.aten_mul,
     "nn.functional.elu": nn_ops.aten_elu,
     "nn.functional.relu6": nn_ops.aten_relu6,
     "nn.functional.selu": core_ops.aten_selu,
+    "ones_like": core_ops.aten_ones_like,
+    "repeat": core_ops.aten_repeat,
+    "round": core_ops.aten_round,
     "sub": core_ops.aten_sub,
 }
 
 TESTED_OPS = frozenset(OPINFO_FUNCTION_MAPPING)
 
 EXPECTED_SKIPS_OR_FAILS = (
     xfail("add", dtypes=BOOL_TYPES, reason="Add is not defined on bool tensors"),
+    xfail("clamp_max", dtypes=BOOL_TYPES, reason="Min is not defined on bool tensors"),
+    xfail("clamp_min", dtypes=BOOL_TYPES, reason="Max is not defined on bool tensors"),
+    xfail("gt", dtypes=BOOL_TYPES, reason="Greater is not defined on bool tensors"),
+    xfail("lt", dtypes=BOOL_TYPES, reason="Less is not defined on bool tensors"),
     xfail("mul", dtypes=BOOL_TYPES, reason="Mul is not defined on bool tensors"),
     xfail(
         "nn.functional.elu",
@@ -186,14 +201,117 @@ def wrapped(fn):
         dtypes=dtypes_except(torch.float16, torch.float32),
         reason="ONNX Runtime doesn't support float64 for Selu",
     ),
+    xfail(
+        "round",
+        variant_name="",
+        dtypes=dtypes_except(*FLOAT_TYPES),
+        reason="Round is not defined on non-float tensors",
+    ),
+    xfail("round", variant_name="decimals_0", reason="The ATen op does not support decimals"),
+    xfail("round", variant_name="decimals_3", reason="The ATen op does not support decimals"),
+    xfail(
+        "round", variant_name="decimals_neg_3", reason="The ATen op does not support decimals"
+    ),
     xfail("sub", dtypes=BOOL_TYPES, reason="Sub is not defined on bool tensors"),
 )
+
+
+SKIP_SUBTESTS = (
+    skip(
+        "clamp_max",
+        reason="Empty tensor not yet supported",
+        matcher=lambda sample: sample.input.size() == torch.Size([0]),
+    ),
+    skip(
+        "clamp_min",
+        reason="Empty tensor not yet supported",
+        matcher=lambda sample: sample.input.size() == torch.Size([0]),
+    ),
+    skip(
+        "repeat",
+        reason="repeating when input is a scalar and repeats is empty is not supported",
+        matcher=lambda sample: sample.args[0] == (),
+    ),
+)
+OP_WITH_SKIPPED_SUBTESTS = frozenset(meta.op_name for meta in SKIP_SUBTESTS)
+
 # END OF SECTION TO MODIFY #####################################################
 
 
 OPS_DB = copy.deepcopy(common_methods_invocations.op_db)
 
 
+TORCH_TYPE_TO_ONNX = {
+    torch.bool: onnx.TensorProto.BOOL,
+    torch.uint8: onnx.TensorProto.UINT8,
+    torch.int8: onnx.TensorProto.INT8,
+    torch.int16: onnx.TensorProto.INT16,
+    torch.int32: onnx.TensorProto.INT32,
+    torch.int64: onnx.TensorProto.INT64,
+    torch.float16: onnx.TensorProto.FLOAT16,
+    torch.float32: onnx.TensorProto.FLOAT,
+    torch.float64: onnx.TensorProto.DOUBLE,
+    torch.complex64: onnx.TensorProto.COMPLEX64,
+    torch.complex128: onnx.TensorProto.COMPLEX128,
+    torch.bfloat16: onnx.TensorProto.BFLOAT16,
+}
+
+
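The mapping above translates torch dtypes into `onnx.TensorProto` enum values so that a `dtype` kwarg can be forwarded to ONNX Runtime. A quick illustration (the enum values come from the onnx protobuf definitions):

```python
import onnx
import torch

assert TORCH_TYPE_TO_ONNX[torch.float32] == onnx.TensorProto.FLOAT  # enum value 1
assert TORCH_TYPE_TO_ONNX[torch.int64] == onnx.TensorProto.INT64    # enum value 7
```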
+class TestFunctionsCompilation(unittest.TestCase):
+    """Test all functions can be compiled."""
+
+    @parameterized.parameterized.expand(
+        list(OPINFO_FUNCTION_MAPPING.items()),
+    )
+    def test_function_compiles(self, _, function):
+        compiled = onnxscript.script()(function)
+        compiled.to_function_proto()
+
+
+def _convert_tensor_to_numpy(input: Any) -> Any:
+    if isinstance(input, torch.Tensor):
+        return input.detach().cpu().numpy()
+    if isinstance(input, (tuple, list)):
+        if len(input) == 0:
+            return np.array((), dtype=np.int64)
+        if isinstance(input[0], torch.Tensor):
+            return [_convert_tensor_to_numpy(x) for x in input]
+        if isinstance(input[0], (int, float)):
+            # Just a tuple of numbers
+            return np.array(input)
+        return input
+
+    return input
+
+
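`_convert_tensor_to_numpy` normalizes one positional argument at a time: tensors become numpy arrays, sequences of tensors are converted element-wise, plain number sequences become arrays, and everything else passes through. Illustrative cases, assuming the branches above:

```python
_convert_tensor_to_numpy(torch.tensor([1.0, 2.0]))  # -> np.ndarray of float32
_convert_tensor_to_numpy([])                        # -> np.array((), dtype=np.int64)
_convert_tensor_to_numpy((2, 3))                    # -> np.array([2, 3])
_convert_tensor_to_numpy("mean")                    # -> "mean", passed through unchanged
```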
+def _convert_kwargs_for_onnx(kwargs: dict[str, Any]) -> dict[str, Any]:
+    """Converts kwargs to be compatible with ONNX Runtime.
+
+    Drops the `device` argument and maps torch dtypes to their ONNX TensorProto
+    equivalents, since ONNX Runtime does not understand torch-specific values.
+    """
+    new_kwargs = {}
+    for key, value in kwargs.items():
+        if key == "device":
+            continue
+        if key == "dtype":
+            value = TORCH_TYPE_TO_ONNX[value]
+        new_kwargs[key] = value
+    return new_kwargs
+
+
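In practice this strips torch-only kwargs before they reach ONNX Runtime. For example, with illustrative inputs:

```python
_convert_kwargs_for_onnx({"dtype": torch.float32, "device": "cpu", "alpha": 2.0})
# -> {"dtype": onnx.TensorProto.FLOAT, "alpha": 2.0}
```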
+def _should_skip_test_sample(op_name: str, sample) -> Optional[str]:
+    """Returns a reason if a test sample should be skipped."""
+    if op_name not in OP_WITH_SKIPPED_SUBTESTS:
+        return None
+    for decorator_meta in SKIP_SUBTESTS:
+        # Linear search on SKIP_SUBTESTS. That's fine because the list is small.
+        if decorator_meta.op_name == op_name:
+            assert decorator_meta.matcher is not None, "Matcher must be defined"
+            if decorator_meta.matcher(sample):
+                return decorator_meta.reason
+    return None
+
+
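A sketch of how the lookup behaves with the SKIP_SUBTESTS defined above; the SampleInput construction here is illustrative:

```python
from torch.testing._internal.opinfo import core as opinfo_core

sample = opinfo_core.SampleInput(torch.ones(0))  # empty tensor input
_should_skip_test_sample("clamp_max", sample)    # -> "Empty tensor not yet supported"
_should_skip_test_sample("add", sample)          # -> None (no subtest skips registered)
```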
 class TestOutputConsistency(unittest.TestCase):
     """Test output consistency between exported ONNX models and PyTorch eager mode.
 
@@ -236,10 +354,14 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
                 inputs=repr(inputs),
                 kwargs=repr(cpu_sample.kwargs),
             ):
-                input_numpy = [x.numpy() for x in inputs if isinstance(x, torch.Tensor)]
-                torch_output = op(*inputs, **cpu_sample.kwargs)
+                skip_reason = _should_skip_test_sample(op.name, cpu_sample)
+                if skip_reason is not None:
+                    self.skipTest(skip_reason)
+                input_onnx = [_convert_tensor_to_numpy(x) for x in inputs]
+                kwargs_onnx = _convert_kwargs_for_onnx(cpu_sample.kwargs)
+                output_torch = op(*inputs, **cpu_sample.kwargs)
                 try:
-                    function_output = scripted_function(*input_numpy, **cpu_sample.kwargs)
+                    function_output = scripted_function(*input_onnx, **kwargs_onnx)
                     # pylint: disable=c-extension-no-member
                 except onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented:
                     self.skipTest(
@@ -250,7 +372,7 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
                 # Use torch testing to ensure dtypes and shapes match
                 torch.testing.assert_close(
                     torch.tensor(function_output),
-                    torch_output,
+                    output_torch,
                 )
 
 
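`torch.testing.assert_close` compares values and, by default, requires matching dtypes and shapes, which is why `function_output` is wrapped in `torch.tensor(...)` before the comparison. For instance:

```python
torch.testing.assert_close(torch.tensor([1.0]), torch.tensor([1.0]))  # passes
torch.testing.assert_close(torch.tensor([1]), torch.tensor([1.0]))   # raises: dtype mismatch
```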