@@ -221,7 +221,9 @@ def _topk_input_wrangler(
221
221
222
222
# Ops to be tested for numerical consistency between onnx and pytorch
223
223
# Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
224
- OPINFO_FUNCTION_MAPPING : dict [
224
+
225
+ # Split the scripted and traced ops to make sure we don't forget to script an op
226
+ OPINFO_FUNCTION_MAPPING_SCRIPTED : dict [
225
227
str ,
226
228
onnxscript .OnnxFunction
227
229
| Callable [..., Any ]
@@ -245,7 +247,6 @@ def _topk_input_wrangler(
245
247
"atan" : core_ops .aten_atan ,
246
248
"atanh" : core_ops .aten_atanh ,
247
249
"bmm" : core_ops .aten_bmm ,
248
- "cat" : core_ops .aten_cat ,
249
250
"ceil" : core_ops .aten_ceil ,
250
251
"clamp_max" : core_ops .aten_clamp_max ,
251
252
"clamp_min" : core_ops .aten_clamp_min ,
@@ -267,18 +268,17 @@ def _topk_input_wrangler(
267
268
"full" : (core_ops .aten_full , _full_input_wrangler ),
268
269
"full_like" : core_ops .aten_full_like ,
269
270
"gt" : core_ops .aten_gt ,
270
- "index_select" : core_ops .aten_index_select ,
271
271
"isinf" : core_ops .aten_isinf ,
272
272
"log" : core_ops .aten_log ,
273
273
"log10" : core_ops .aten_log10 ,
274
274
"log1p" : core_ops .aten_log1p ,
275
+ "log_softmax" : (special_ops .aten_special_log_softmax , _log_softmax_input_wrangler ),
275
276
"log2" : core_ops .aten_log2 ,
276
277
"logaddexp" : core_ops .aten_logaddexp ,
277
278
"logaddexp2" : core_ops .aten_logaddexp2 ,
278
279
"logcumsumexp" : core_ops .aten_logcumsumexp ,
279
280
"logdet" : core_ops .aten_logdet ,
280
281
"logsumexp" : (core_ops .aten_logsumexp , _logcumsumexp_input_wrangler ),
281
- "log_softmax" : (special_ops .aten_special_log_softmax , _log_softmax_input_wrangler ),
282
282
"lt" : core_ops .aten_lt ,
283
283
"matmul" : core_ops .aten_matmul ,
284
284
"mm" : core_ops .aten_mm ,
@@ -319,7 +319,6 @@ def _topk_input_wrangler(
319
319
"t" : core_ops .aten_t ,
320
320
"tan" : core_ops .aten_tan ,
321
321
"tanh" : core_ops .aten_tanh ,
322
- "transpose" : core_ops .aten_transpose ,
323
322
"topk" : (
324
323
core_ops .aten_topk ,
325
324
_topk_input_wrangler ,
@@ -332,6 +331,26 @@ def _topk_input_wrangler(
332
331
"zeros_like" : core_ops .aten_zeros_like ,
333
332
}
334
333
334
+
335
+ OPINFO_FUNCTION_MAPPING_TRACE_ONLY : dict [
336
+ str ,
337
+ Callable [..., Any ] | tuple [Callable [..., Any ], Callable [..., Any ]],
338
+ ] = {
339
+ "cat" : core_ops .aten_cat ,
340
+ "index_select" : core_ops .aten_index_select ,
341
+ "transpose" : core_ops .aten_transpose ,
342
+ }
343
+
344
+ OPINFO_FUNCTION_MAPPING : dict [
345
+ str ,
346
+ onnxscript .OnnxFunction
347
+ | Callable [..., Any ]
348
+ | tuple [
349
+ onnxscript .OnnxFunction | Callable [..., Any ],
350
+ Callable [[list [Any ], dict [str , Any ]], tuple [list [Any ], dict [str , Any ]]],
351
+ ],
352
+ ] = {** OPINFO_FUNCTION_MAPPING_SCRIPTED , ** OPINFO_FUNCTION_MAPPING_TRACE_ONLY }
353
+
335
354
TESTED_OPS = frozenset (OPINFO_FUNCTION_MAPPING )
336
355
337
356
EXPECTED_SKIPS_OR_FAILS = (
@@ -522,6 +541,14 @@ def setUp(self) -> None:
522
541
torch .manual_seed (42 )
523
542
np .random .seed (42 )
524
543
544
+ def test_all_script_functions_are_onnx_functions (self ):
545
+ for func_with_wrangler in OPINFO_FUNCTION_MAPPING_SCRIPTED .values ():
546
+ if isinstance (func_with_wrangler , tuple ):
547
+ func = func_with_wrangler [0 ]
548
+ else :
549
+ func = func_with_wrangler
550
+ self .assertIsInstance (func , onnxscript .OnnxFunction )
551
+
525
552
@common_device_type .ops ( # type: ignore[misc]
526
553
[info for info in OPS_DB if info .name in TESTED_OPS ],
527
554
allowed_dtypes = TESTED_DTYPES ,
0 commit comments