77from typing import Any , Callable , Collection , Iterable , Optional , Sequence , TypeVar
88
99import numpy as np
10+ import onnx
1011import onnxruntime .capi .onnxruntime_pybind11_state
12+ import parameterized
1113import torch
1214from torch .testing ._internal import common_device_type , common_methods_invocations
1315from torch .testing ._internal .opinfo import core as opinfo_core
@@ -69,14 +71,15 @@ class DecorateMeta:
6971 decorator : Callable [..., Any ]
7072 dtypes : Optional [Collection [torch .dtype ]]
7173 reason : str
74+ matcher : Optional [Callable [[Any ], bool ]] = None
7275
7376
7477def xfail (
7578 op_name : str ,
7679 variant_name : str = "" ,
7780 * ,
81+ reason : str ,
7882 dtypes : Optional [Collection [torch .dtype ]] = None ,
79- reason : Optional [str ] = None ,
8083):
8184 """Expects an OpInfo test to fail.
8285
@@ -86,8 +89,6 @@ def xfail(
8689 dtypes: The dtypes to expect the failure.
8790 reason: The reason for the failure.
8891 """
89- if reason is None :
90- raise ValueError ("Please specify a reason." )
9192 return DecorateMeta (
9293 op_name = op_name ,
9394 variant_name = variant_name ,
@@ -101,8 +102,9 @@ def skip(
101102 op_name : str ,
102103 variant_name : str = "" ,
103104 * ,
105+ reason : str ,
104106 dtypes : Optional [Collection [torch .dtype ]] = None ,
105- reason : Optional [str ] = None ,
107+ matcher : Optional [Callable [[ Any ], Any ] ] = None ,
106108):
107109 """Skips an OpInfo test.
108110
@@ -111,15 +113,16 @@ def skip(
111113 variant_name: Optional OpInfo variant_test_name.
112114 dtypes: The dtypes to skip.
113115 reason: The reason for skipping.
116+ matcher: A function that matches the test sample input. It is used only when
117+ xfail is in the SKIP_SUBTESTS list.
114118 """
115- if reason is None :
116- raise ValueError ("Please specify a reason." )
117119 return DecorateMeta (
118120 op_name = op_name ,
119121 variant_name = variant_name ,
120122 decorator = unittest .skip (f"Don't care: { reason } " ),
121123 dtypes = dtypes ,
122124 reason = reason ,
125+ matcher = matcher ,
123126 )
124127
125128
@@ -159,17 +162,28 @@ def wrapped(fn):
159162# Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
160163OPINFO_FUNCTION_MAPPING : dict [str , Callable [..., Any ]] = {
161164 "add" : core_ops .aten_add ,
165+ "clamp_max" : core_ops .aten_clamp_max_tensor ,
166+ "clamp_min" : core_ops .aten_clamp_min_tensor ,
167+ "gt" : core_ops .aten_gt ,
168+ "lt" : core_ops .aten_lt ,
162169 "mul" : core_ops .aten_mul ,
163170 "nn.functional.elu" : nn_ops .aten_elu ,
164171 "nn.functional.relu6" : nn_ops .aten_relu6 ,
165172 "nn.functional.selu" : core_ops .aten_selu ,
173+ "ones_like" : core_ops .aten_ones_like_dtype ,
174+ "repeat" : core_ops .aten_repeat ,
175+ "round" : core_ops .aten_round ,
166176 "sub" : core_ops .aten_sub ,
167177}
168178
169179TESTED_OPS = frozenset (OPINFO_FUNCTION_MAPPING )
170180
171181EXPECTED_SKIPS_OR_FAILS = (
172182 xfail ("add" , dtypes = BOOL_TYPES , reason = "Add is not defined on bool tensors" ),
183+ xfail ("clamp_max" , dtypes = BOOL_TYPES , reason = "Min is not defined on bool tensors" ),
184+ xfail ("clamp_min" , dtypes = BOOL_TYPES , reason = "Max is not defined on bool tensors" ),
185+ xfail ("gt" , dtypes = BOOL_TYPES , reason = "Greater is not defined on bool tensors" ),
186+ xfail ("lt" , dtypes = BOOL_TYPES , reason = "Less is not defined on bool tensors" ),
173187 xfail ("mul" , dtypes = BOOL_TYPES , reason = "Mul is not defined on bool tensors" ),
174188 xfail (
175189 "nn.functional.elu" ,
@@ -186,14 +200,123 @@ def wrapped(fn):
186200 dtypes = dtypes_except (torch .float16 , torch .float32 ),
187201 reason = "ONNX Runtime doesn't support float64 for Selu" ,
188202 ),
203+ xfail (
204+ "round" ,
205+ variant_name = "" ,
206+ dtypes = dtypes_except (* FLOAT_TYPES ),
207+ reason = "Round is not defined on non-float tensors" ,
208+ ),
209+ xfail ("round" , variant_name = "decimals_0" , reason = "The ATen op does not support decimals" ),
210+ xfail ("round" , variant_name = "decimals_3" , reason = "The ATen op does not support decimals" ),
211+ xfail (
212+ "round" , variant_name = "decimals_neg_3" , reason = "The ATen op does not support decimals"
213+ ),
189214 xfail ("sub" , dtypes = BOOL_TYPES , reason = "Sub is not defined on bool tensors" ),
190215)
216+
217+
218+ SKIP_SUBTESTS = (
219+ skip (
220+ "clamp_max" ,
221+ reason = "Empty tensor not yet supported" ,
222+ matcher = lambda sample : sample .input .size () == torch .Size ([0 ]),
223+ ),
224+ skip (
225+ "clamp_min" ,
226+ reason = "Empty tensor not yet supported" ,
227+ matcher = lambda sample : sample .input .size () == torch .Size ([0 ]),
228+ ),
229+ skip (
230+ "repeat" ,
231+ reason = "repeating when input is a scalar and repeats is empty is not supported" ,
232+ matcher = lambda sample : sample .args [0 ] == (),
233+ ),
234+ skip (
235+ "ones_like" ,
236+ # TODO(justinchuby): Test aten_ones_like
237+ reason = "dtype must be provided for aten_ones_like_dtype" ,
238+ matcher = lambda sample : "dtype" not in sample .kwargs ,
239+ ),
240+ )
241+ OP_WITH_SKIPPED_SUBTESTS = frozenset (meta .op_name for meta in SKIP_SUBTESTS )
242+
191243# END OF SECTION TO MODIFY #####################################################
192244
193245
194246OPS_DB = copy .deepcopy (common_methods_invocations .op_db )
195247
196248
249+ TORCH_TYPE_TO_ONNX = {
250+ torch .bool : onnx .TensorProto .BOOL ,
251+ torch .uint8 : onnx .TensorProto .UINT8 ,
252+ torch .int8 : onnx .TensorProto .INT8 ,
253+ torch .int16 : onnx .TensorProto .INT16 ,
254+ torch .int32 : onnx .TensorProto .INT32 ,
255+ torch .int64 : onnx .TensorProto .INT64 ,
256+ torch .float16 : onnx .TensorProto .FLOAT16 ,
257+ torch .float32 : onnx .TensorProto .FLOAT ,
258+ torch .float64 : onnx .TensorProto .DOUBLE ,
259+ torch .complex64 : onnx .TensorProto .COMPLEX64 ,
260+ torch .complex128 : onnx .TensorProto .COMPLEX128 ,
261+ torch .bfloat16 : onnx .TensorProto .BFLOAT16 ,
262+ }
263+
264+
265+ class TestFunctionsCompilation (unittest .TestCase ):
266+ """Test all functions can be compiled."""
267+
268+ @parameterized .parameterized .expand (
269+ list (OPINFO_FUNCTION_MAPPING .items ()),
270+ )
271+ def test_function_compiles (self , _ , function ):
272+ compiled = onnxscript .script ()(function )
273+ compiled .to_function_proto ()
274+
275+
276+ def _convert_tensor_to_numpy (input : Any ) -> Any :
277+ if isinstance (input , torch .Tensor ):
278+ return input .detach ().cpu ().numpy ()
279+ if isinstance (input , (tuple , list )):
280+ if len (input ) == 0 :
281+ return np .array ((), dtype = np .int64 )
282+ if isinstance (input [0 ], torch .Tensor ):
283+ return [_convert_tensor_to_numpy (x ) for x in input ]
284+ if isinstance (input [0 ], (int , float )):
285+ # Just a tuple of numbers
286+ return np .array (input )
287+ return input
288+
289+ return input
290+
291+
292+ def _convert_kwargs_for_onnx (kwargs : dict [str , Any ]) -> dict [str , Any ]:
293+ """Converts kwargs to be compatible with ONNX Runtime.
294+
295+ ONNX Runtime doesn't support torch.bool, so we convert them to torch.uint8.
296+ """
297+ new_kwargs = {}
298+ for key , value in kwargs .items ():
299+ if key == "device" :
300+ continue
301+ if key == "dtype" :
302+ value = TORCH_TYPE_TO_ONNX [value ]
303+ new_kwargs [key ] = value
304+ return new_kwargs
305+
306+
307+ def _should_skip_test_sample (op_name : str , sample ) -> Optional [str ]:
308+ """Returns a reason if a test sample should be skipped."""
309+ if op_name not in OP_WITH_SKIPPED_SUBTESTS :
310+ return None
311+ for decorator_meta in SKIP_SUBTESTS :
312+ # Linear search on SKIP_SUBTESTS. That's fine because the list is small.
313+ if decorator_meta .op_name == op_name :
314+ assert decorator_meta .matcher is not None , "Matcher must be defined"
315+ if decorator_meta .matcher (sample ):
316+ return decorator_meta .reason
317+ return None
318+
319+
197320class TestOutputConsistency (unittest .TestCase ):
198321 """Test output consistency between exported ONNX models and PyTorch eager mode.
199322
@@ -236,10 +359,14 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
236359 inputs = repr (inputs ),
237360 kwargs = repr (cpu_sample .kwargs ),
238361 ):
239- input_numpy = [x .numpy () for x in inputs if isinstance (x , torch .Tensor )]
240- torch_output = op (* inputs , ** cpu_sample .kwargs )
362+ skip_reason = _should_skip_test_sample (op .name , cpu_sample )
363+ if skip_reason is not None :
364+ self .skipTest (skip_reason )
365+ input_onnx = [_convert_tensor_to_numpy (x ) for x in inputs ]
366+ kwargs_onnx = _convert_kwargs_for_onnx (cpu_sample .kwargs )
367+ output_torch = op (* inputs , ** cpu_sample .kwargs )
241368 try :
242- function_output = scripted_function (* input_numpy , ** cpu_sample . kwargs )
369+ function_output = scripted_function (* input_onnx , ** kwargs_onnx )
243370 # pylint: disable=c-extension-no-member
244371 except onnxruntime .capi .onnxruntime_pybind11_state .NotImplemented :
245372 self .skipTest (
@@ -250,7 +377,7 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
250377 # Use torch testing to ensure dtypes and shapes match
251378 torch .testing .assert_close (
252379 torch .tensor (function_output ),
253- torch_output ,
380+ output_torch ,
254381 )
255382
256383
0 commit comments