@@ -162,35 +162,56 @@ def wrapped(fn):
 # Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
 OPINFO_FUNCTION_MAPPING: dict[str, Callable[..., Any]] = {
     "add": core_ops.aten_add,
-    # "clamp": core_ops.aten_clamp,  # TODO(justinchuby): Enable
     "clamp_max": core_ops.aten_clamp_max_tensor,
     "clamp_min": core_ops.aten_clamp_min_tensor,
+    "clamp": core_ops.aten_clamp,
     "gt": core_ops.aten_gt,
     "lt": core_ops.aten_lt,
+    "matmul": core_ops.aten_matmul,
+    "mm": core_ops.aten_mm,
     "mul": core_ops.aten_mul,
     "nn.functional.elu": nn_ops.aten_elu,
+    "nn.functional.linear": nn_ops.aten_linear,
     "nn.functional.relu6": nn_ops.aten_relu6,
     "nn.functional.selu": core_ops.aten_selu,
     "ones_like": core_ops.aten_ones_like,
+    "ones": core_ops.aten_ones,
     "repeat": core_ops.aten_repeat,
     "round": core_ops.aten_round,
     "sub": core_ops.aten_sub,
+    "t": core_ops.aten_t,
+    # "transpose": core_ops.aten_transpose,  # TODO(justinchuby): Enable when onnxscript errors are fixed
 }

 TESTED_OPS = frozenset(OPINFO_FUNCTION_MAPPING)

 EXPECTED_SKIPS_OR_FAILS = (
     xfail("add", dtypes=BOOL_TYPES, reason="Add is not defined on bool tensors"),
+    skip("clamp", reason="Enable when onnxscript errors are fixed"),
     xfail("clamp_max", dtypes=BOOL_TYPES, reason="Min is not defined on bool tensors"),
     xfail("clamp_min", dtypes=BOOL_TYPES, reason="Max is not defined on bool tensors"),
     xfail("gt", dtypes=BOOL_TYPES, reason="Greater is not defined on bool tensors"),
     xfail("lt", dtypes=BOOL_TYPES, reason="Less is not defined on bool tensors"),
+    xfail(
+        "matmul",
+        dtypes=[torch.uint8, torch.int8, torch.int16],
+        reason="MatMul is not defined on int16/int8/uint8 tensors",
+    ),
+    xfail(
+        "mm",
+        dtypes=[torch.uint8, torch.int8, torch.int16],
+        reason="MatMul is not defined on int16/int8/uint8 tensors",
+    ),
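+    # ONNX MatMul only accepts floating-point and 32-/64-bit integer tensors,
+    # so the narrower integer dtypes above are expected to fail.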
     xfail("mul", dtypes=BOOL_TYPES, reason="Mul is not defined on bool tensors"),
     xfail(
         "nn.functional.elu",
         dtypes=dtypes_except(torch.float16, torch.float32),
         reason="ONNX Runtime doesn't support float64 for Elu",
     ),
+    xfail(
+        "nn.functional.linear",
+        reason="ONNX Runtime thinks the graph is invalid",
+    ),
     xfail(
         "nn.functional.relu6",
         dtypes=dtypes_except(torch.float16, torch.float32),
@@ -213,6 +234,7 @@ def wrapped(fn):
         "round", variant_name="decimals_neg_3", reason="The ATen op does not support decimals"
     ),
     xfail("sub", dtypes=BOOL_TYPES, reason="Sub is not defined on bool tensors"),
+    xfail("transpose", reason="Enable when onnxscript errors are fixed"),
 )

@@ -240,6 +262,10 @@ def wrapped(fn):

 OPS_DB = copy.deepcopy(common_methods_invocations.op_db)

+ALL_OPS_IN_DB = frozenset(op_info.name for op_info in OPS_DB)
+# Assert all ops in OPINFO_FUNCTION_MAPPING are in the OPS_DB
+assert TESTED_OPS.issubset(ALL_OPS_IN_DB), f"{TESTED_OPS - ALL_OPS_IN_DB} not in OPS_DB"
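+# This catches mapping keys that do not name any OpInfo (e.g. typos) at import time,
+# before any test is collected or run.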
+

 TORCH_TYPE_TO_ONNX = {
     torch.bool: onnx.TensorProto.BOOL,
@@ -369,10 +395,21 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
         )
         # pylint: enable=c-extension-no-member

+        if dtype == torch.float32:
+            # Relax atol and rtol for float32 based on empirical results
+            # The current most relaxed values are for aten::matmul
+            rtol = 3.7e-6
+            atol = 1.8e-5
+        else:
+            rtol = None
+            atol = None
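+            # With both set to None, assert_close picks its default tolerances
+            # based on the tensors' dtype.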
+
         # Use torch testing to ensure dtypes and shapes match
         torch.testing.assert_close(
             torch.tensor(function_output),
             output_torch,
+            rtol=rtol,
+            atol=atol,
         )