@@ -162,35 +162,56 @@ def wrapped(fn):
# Maps OpInfo names (as registered in
# torch/testing/_internal/common_methods_invocations.py) to the onnxscript
# function that implements the corresponding ATen operator.
OPINFO_FUNCTION_MAPPING: dict[str, Callable[..., Any]] = {
    "add": core_ops.aten_add,
    "clamp_max": core_ops.aten_clamp_max_tensor,
    "clamp_min": core_ops.aten_clamp_min_tensor,
    "clamp": core_ops.aten_clamp,
    "gt": core_ops.aten_gt,
    "lt": core_ops.aten_lt,
    "matmul": core_ops.aten_matmul,
    "mm": core_ops.aten_mm,
    "mul": core_ops.aten_mul,
    "nn.functional.elu": nn_ops.aten_elu,
    "nn.functional.linear": nn_ops.aten_linear,
    "nn.functional.relu6": nn_ops.aten_relu6,
    "nn.functional.selu": core_ops.aten_selu,
    "ones_like": core_ops.aten_ones_like,
    "ones": core_ops.aten_ones,
    "repeat": core_ops.aten_repeat,
    "round": core_ops.aten_round,
    "sub": core_ops.aten_sub,
    "t": core_ops.aten_t,
    # "transpose": core_ops.aten_transpose,  # TODO(justinchuby): Enable when onnxscript errors are fixed
}
179185
# Immutable view of every OpInfo name that has a registered onnxscript
# implementation above; used to drive skip/xfail bookkeeping.
TESTED_OPS = frozenset(OPINFO_FUNCTION_MAPPING.keys())
181187
182188EXPECTED_SKIPS_OR_FAILS = (
183189 xfail ("add" , dtypes = BOOL_TYPES , reason = "Add is not defined on bool tensors" ),
190+ skip ("clamp" , reason = "Enable when onnxscript errors are fixed" ),
184191 xfail ("clamp_max" , dtypes = BOOL_TYPES , reason = "Min is not defined on bool tensors" ),
185192 xfail ("clamp_min" , dtypes = BOOL_TYPES , reason = "Max is not defined on bool tensors" ),
186193 xfail ("gt" , dtypes = BOOL_TYPES , reason = "Greater is not defined on bool tensors" ),
187194 xfail ("lt" , dtypes = BOOL_TYPES , reason = "Less is not defined on bool tensors" ),
195+ xfail (
196+ "matmul" ,
197+ dtypes = [torch .uint8 , torch .int8 , torch .int16 ],
198+ reason = "MatMul is not defined on int16/int8/uint8 tensors" ,
199+ ),
200+ xfail (
201+ "mm" ,
202+ dtypes = [torch .uint8 , torch .int8 , torch .int16 ],
203+ reason = "MatMul is not defined on int16/int8/uint8 tensors" ,
204+ ),
188205 xfail ("mul" , dtypes = BOOL_TYPES , reason = "Mul is not defined on bool tensors" ),
189206 xfail (
190207 "nn.functional.elu" ,
191208 dtypes = dtypes_except (torch .float16 , torch .float32 ),
192209 reason = "ONNX Runtime doesn't support float64 for Elu" ,
193210 ),
211+ xfail (
212+ "nn.functional.linear" ,
213+ reason = "ONNX Runtime thinks the graph is invalid" ,
214+ ),
194215 xfail (
195216 "nn.functional.relu6" ,
196217 dtypes = dtypes_except (torch .float16 , torch .float32 ),
@@ -213,6 +234,7 @@ def wrapped(fn):
213234 "round" , variant_name = "decimals_neg_3" , reason = "The ATen op does not support decimals"
214235 ),
215236 xfail ("sub" , dtypes = BOOL_TYPES , reason = "Sub is not defined on bool tensors" ),
237+ xfail ("transpose" , reason = "Enable when onnxscript errors are fixed" ),
216238)
217239
218240
@@ -240,6 +262,10 @@ def wrapped(fn):
240262
# Work on a private copy so mutations here never leak back into the shared
# torch OpInfo database.
OPS_DB = copy.deepcopy(common_methods_invocations.op_db)

# All op names known to the OpInfo database.
ALL_OPS_IN_DB = frozenset({op_info.name for op_info in OPS_DB})
# Sanity check: every op we claim to test must actually exist in the OpInfo
# database, otherwise its test would silently never run.
assert TESTED_OPS.issubset(ALL_OPS_IN_DB), f"{TESTED_OPS - ALL_OPS_IN_DB} not in OPS_DB"
268+
243269
244270TORCH_TYPE_TO_ONNX = {
245271 torch .bool : onnx .TensorProto .BOOL ,
@@ -369,10 +395,21 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
369395 )
370396 # pylint: enable=c-extension-no-member
371397
398+ if dtype == torch .float32 :
399+ # Relax atol and rtol for float32 based on empirical results
400+ # The current most relaxed values are for aten::matmul
401+ rtol = 3.7e-6
402+ atol = 1.8e-5
403+ else :
404+ rtol = None
405+ atol = None
406+
372407 # Use torch testing to ensure dtypes and shapes match
373408 torch .testing .assert_close (
374409 torch .tensor (function_output ),
375410 output_torch ,
411+ rtol = rtol ,
412+ atol = atol ,
376413 )
377414
378415
0 commit comments