@@ -221,7 +221,9 @@ def _topk_input_wrangler(
221221
222222# Ops to be tested for numerical consistency between onnx and pytorch
223223# Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
224- OPINFO_FUNCTION_MAPPING : dict [
224+
225+ # Split the scripted and traced ops to make sure we don't forget to script an op
226+ OPINFO_FUNCTION_MAPPING_SCRIPTED : dict [
225227 str ,
226228 onnxscript .OnnxFunction
227229 | Callable [..., Any ]
@@ -245,7 +247,6 @@ def _topk_input_wrangler(
245247 "atan" : core_ops .aten_atan ,
246248 "atanh" : core_ops .aten_atanh ,
247249 "bmm" : core_ops .aten_bmm ,
248- "cat" : core_ops .aten_cat ,
249250 "ceil" : core_ops .aten_ceil ,
250251 "clamp_max" : core_ops .aten_clamp_max ,
251252 "clamp_min" : core_ops .aten_clamp_min ,
@@ -267,18 +268,17 @@ def _topk_input_wrangler(
267268 "full" : (core_ops .aten_full , _full_input_wrangler ),
268269 "full_like" : core_ops .aten_full_like ,
269270 "gt" : core_ops .aten_gt ,
270- "index_select" : core_ops .aten_index_select ,
271271 "isinf" : core_ops .aten_isinf ,
272272 "log" : core_ops .aten_log ,
273273 "log10" : core_ops .aten_log10 ,
274274 "log1p" : core_ops .aten_log1p ,
275+ "log_softmax" : (special_ops .aten_special_log_softmax , _log_softmax_input_wrangler ),
275276 "log2" : core_ops .aten_log2 ,
276277 "logaddexp" : core_ops .aten_logaddexp ,
277278 "logaddexp2" : core_ops .aten_logaddexp2 ,
278279 "logcumsumexp" : core_ops .aten_logcumsumexp ,
279280 "logdet" : core_ops .aten_logdet ,
280281 "logsumexp" : (core_ops .aten_logsumexp , _logcumsumexp_input_wrangler ),
281- "log_softmax" : (special_ops .aten_special_log_softmax , _log_softmax_input_wrangler ),
282282 "lt" : core_ops .aten_lt ,
283283 "matmul" : core_ops .aten_matmul ,
284284 "mm" : core_ops .aten_mm ,
@@ -315,12 +315,12 @@ def _topk_input_wrangler(
315315 "sign" : core_ops .aten_sign ,
316316 "sin" : core_ops .aten_sin ,
317317 "sinh" : core_ops .aten_sinh ,
318+ "slice" : core_ops .aten_slice ,
318319 "sqrt" : core_ops .aten_sqrt ,
319320 "sub" : core_ops .aten_sub ,
320321 "t" : core_ops .aten_t ,
321322 "tan" : core_ops .aten_tan ,
322323 "tanh" : core_ops .aten_tanh ,
323- "transpose" : core_ops .aten_transpose ,
324324 "topk" : (
325325 core_ops .aten_topk ,
326326 _topk_input_wrangler ,
@@ -333,6 +333,26 @@ def _topk_input_wrangler(
333333 "zeros_like" : core_ops .aten_zeros_like ,
334334}
335335
336+
337+ OPINFO_FUNCTION_MAPPING_TRACE_ONLY : dict [
338+ str ,
339+ Callable [..., Any ] | tuple [Callable [..., Any ], Callable [..., Any ]],
340+ ] = {
341+ "cat" : core_ops .aten_cat ,
342+ "index_select" : core_ops .aten_index_select ,
343+ "transpose" : core_ops .aten_transpose ,
344+ }
345+
346+ OPINFO_FUNCTION_MAPPING : dict [
347+ str ,
348+ onnxscript .OnnxFunction
349+ | Callable [..., Any ]
350+ | tuple [
351+ onnxscript .OnnxFunction | Callable [..., Any ],
352+ Callable [[list [Any ], dict [str , Any ]], tuple [list [Any ], dict [str , Any ]]],
353+ ],
354+ ] = {** OPINFO_FUNCTION_MAPPING_SCRIPTED , ** OPINFO_FUNCTION_MAPPING_TRACE_ONLY }
355+
336356TESTED_OPS = frozenset (OPINFO_FUNCTION_MAPPING )
337357
338358EXPECTED_SKIPS_OR_FAILS = (
@@ -420,6 +440,12 @@ def _topk_input_wrangler(
420440 matcher = lambda sample : "scale_factor" in sample .kwargs ,
421441 reason = "fixme: the scale_factor tests" ,
422442 ),
443+ skip (
444+ "slice" ,
445+ # kwargs {dim, start, end, step} is empty, we cannot give the default value
446+ matcher = lambda sample : len (sample .kwargs ) == 0 ,
447+ reason = "start and end must be 1-D array, cannot be optional, due to ort 1.13 does not support yet" ,
448+ ),
423449)
424450
425451duplicate_opinfo (
@@ -523,6 +549,14 @@ def setUp(self) -> None:
523549 torch .manual_seed (42 )
524550 np .random .seed (42 )
525551
552+ def test_all_script_functions_are_onnx_functions (self ):
553+ for func_with_wrangler in OPINFO_FUNCTION_MAPPING_SCRIPTED .values ():
554+ if isinstance (func_with_wrangler , tuple ):
555+ func = func_with_wrangler [0 ]
556+ else :
557+ func = func_with_wrangler
558+ self .assertIsInstance (func , onnxscript .OnnxFunction )
559+
526560 @common_device_type .ops ( # type: ignore[misc]
527561 [info for info in OPS_DB if info .name in TESTED_OPS ],
528562 allowed_dtypes = TESTED_DTYPES ,
0 commit comments