@@ -471,21 +471,41 @@ def _where_input_wrangler(
471471 ),
472472 TorchLibOpInfo ("ops.aten._log_softmax" , core_ops .aten__log_softmax ),
473473 TorchLibOpInfo (
474- "ops.aten._log_softmax_half" , core_ops .aten__log_softmax_half , trace_only = True
475- ).xfail (
474+ "ops.aten._log_softmax_half" ,
475+ core_ops .aten__log_softmax_half ,
476+ trace_only = True ,
477+ tolerance = {torch .float16 : (1e-3 , 1e-3 )},
478+ )
479+ .xfail (
476480 reason = "PyTorch does not implement _log_softmax for float16 on CPU" ,
477481 dtypes = (torch .float16 ,),
482+ enabled_if = version_utils .torch_older_than ("2.2" ),
483+ )
484+ .xfail (
485+ dtypes = (torch .float16 ,),
486+ reason = "fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438" ,
487+ test_class_name = "TestOutputConsistencyFullGraph" ,
478488 ),
479489 TorchLibOpInfo ("ops.aten._softmax" , core_ops .aten__softmax , trace_only = True ),
480- TorchLibOpInfo (
481- "ops.aten._softmax_half" , core_ops .aten__softmax_half , trace_only = True
482- ).xfail (
490+ TorchLibOpInfo ("ops.aten._softmax_half" , core_ops .aten__softmax_half , trace_only = True )
491+ .xfail (
483492 reason = "PyTorch does not implement _softmax for float16 on CPU" ,
484493 dtypes = (torch .float16 ,),
494+ enabled_if = version_utils .torch_older_than ("2.2" ),
495+ )
496+ .xfail (
497+ dtypes = (torch .float16 ,),
498+ reason = "fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438" ,
499+ test_class_name = "TestOutputConsistencyFullGraph" ,
500+ ),
501+ TorchLibOpInfo ("all_dim" , core_ops .aten_all_dim ).skip (
502+ matcher = lambda sample : not (len (sample .kwargs ) > 0 )
503+ or isinstance (sample .kwargs .get ("dim" ), tuple ),
504+ reason = "this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design. dim must be an integer" ,
485505 ),
486- TorchLibOpInfo ("all_dim " , core_ops .aten_all_dim ). xfail (
487- matcher = lambda sample : not ( len ( sample .kwargs ) > 0 ),
488- reason = "this Aten overload only support one tensor as input and { dim,keepdim} as kwargs by design " ,
506+ TorchLibOpInfo ("all_dims " , core_ops .aten_all_dims , trace_only = True ). skip (
507+ matcher = lambda sample : not isinstance ( sample .kwargs . get ( "dim" ), tuple ),
508+ reason = "this overload requires dim to be a tuple " ,
489509 ),
490510 TorchLibOpInfo ("allclose" , core_ops .aten_allclose ),
491511 TorchLibOpInfo (
@@ -501,7 +521,11 @@ def _where_input_wrangler(
501521 TorchLibOpInfo ("acosh" , core_ops .aten_acosh ),
502522 TorchLibOpInfo ("add" , core_ops .aten_add , tolerance = {torch .float16 : (1e-3 , 1e-3 )}),
503523 TorchLibOpInfo ("add" , core_ops .aten_add_complex , complex = True , trace_only = True ),
504- TorchLibOpInfo ("addbmm" , core_ops .aten_addbmm , tolerance = {torch .float32 : (2e-5 , 2e-5 )}),
524+ TorchLibOpInfo (
525+ "addbmm" ,
526+ core_ops .aten_addbmm ,
527+ tolerance = {torch .float32 : (2e-5 , 2e-5 ), torch .float16 : (2e-1 , 2e-2 )},
528+ ),
505529 TorchLibOpInfo ("addcdiv" , core_ops .aten_addcdiv ),
506530 TorchLibOpInfo ("addcmul" , core_ops .aten_addcmul , tolerance = {torch .float16 : (4e-3 , 3e-3 )}),
507531 TorchLibOpInfo ("addmm" , core_ops .aten_addmm )
@@ -522,7 +546,7 @@ def _where_input_wrangler(
522546 dtypes = (torch .int16 , torch .int32 , torch .int64 ),
523547 reason = "ONNX Runtime does not support int inputs to Gemm" ,
524548 ),
525- TorchLibOpInfo ("addmv" , core_ops .aten_addmv ),
549+ TorchLibOpInfo ("addmv" , core_ops .aten_addmv , tolerance = { torch . float16 : ( 1e-3 , 1e-2 )} ),
526550 TorchLibOpInfo (
527551 "addr" ,
528552 core_ops .aten_addr ,
@@ -557,8 +581,13 @@ def _where_input_wrangler(
557581 "any_dim" ,
558582 core_ops .aten_any_dim ,
559583 ).skip (
560- matcher = lambda sample : not (len (sample .kwargs ) > 0 ),
561- reason = "this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design" ,
584+ matcher = lambda sample : not (len (sample .kwargs ) > 0 )
585+ or isinstance (sample .kwargs .get ("dim" ), tuple ),
586+ reason = "this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design. dim must be an integer" ,
587+ ),
588+ TorchLibOpInfo ("any_dims" , core_ops .aten_any_dims , trace_only = True ).skip (
589+ matcher = lambda sample : not isinstance (sample .kwargs .get ("dim" ), tuple ),
590+ reason = "this overload requires dim to be a tuple" ,
562591 ),
563592 TorchLibOpInfo ("asin" , core_ops .aten_asin ),
564593 TorchLibOpInfo ("asinh" , core_ops .aten_asinh ),
@@ -640,7 +669,7 @@ def _where_input_wrangler(
640669 "https://github.com/microsoft/onnxscript/issues/1007"
641670 ),
642671 ),
643- TorchLibOpInfo ("baddbmm" , core_ops .aten_baddbmm ),
672+ TorchLibOpInfo ("baddbmm" , core_ops .aten_baddbmm , tolerance = { torch . float16 : ( 1e-3 , 1e-2 )} ),
644673 TorchLibOpInfo ("bernoulli" , core_ops .aten_bernoulli , nondeterministic = True ),
645674 TorchLibOpInfo (
646675 # This string is a unique ID. In extra_opinfo.py, we
@@ -845,6 +874,12 @@ def _where_input_wrangler(
845874 dtypes = (torch .int64 , torch .int32 ),
846875 reason = "fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854" ,
847876 )
877+ .xfail (
878+ variant_name = "tensor_overload" ,
879+ dtypes = (torch .int64 , torch .int32 , torch .float16 ),
880+ reason = "fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854" ,
881+ enabled_if = not version_utils .torch_older_than ("2.2" ),
882+ )
848883 .xfail (
849884 dtypes = (torch .float16 ,),
850885 reason = "op 'Range' doesn't support float16." ,
@@ -861,17 +896,35 @@ def _where_input_wrangler(
861896 TorchLibOpInfo (
862897 "log_softmax" ,
863898 special_ops .aten_special_log_softmax ,
899+ trace_only = True ,
864900 tolerance = {torch .float32 : (3.7e-5 , 1.8e-4 ), torch .float16 : (4e-4 , 6e-3 )},
865- ).xfail (
901+ )
902+ .xfail (
903+ dtypes = (torch .float16 ,),
904+ reason = "fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438" ,
905+ test_class_name = "TestOutputConsistencyFullGraph" ,
906+ )
907+ .xfail (
866908 variant_name = "with_dtype" ,
867909 dtypes = (torch .float16 ,),
868910 reason = "fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438" ,
869911 test_class_name = "TestOutputConsistencyFullGraph" ,
912+ )
913+ .skip (
914+ matcher = lambda sample : len (sample .input .shape ) == 0 ,
915+ reason = "fixme: LogSoftMax does not support empty tensor as input" ,
916+ )
917+ .skip (
918+ variant_name = "with_dtype" ,
919+ matcher = lambda sample : len (sample .input .shape ) == 0 ,
920+ reason = "fixme: LogSoftMax does not support empty tensor as input" ,
870921 ),
871922 TorchLibOpInfo ("log2" , core_ops .aten_log2 ),
872923 TorchLibOpInfo ("logaddexp" , core_ops .aten_logaddexp ),
873924 TorchLibOpInfo ("logaddexp2" , core_ops .aten_logaddexp2 ),
874- TorchLibOpInfo ("logcumsumexp" , core_ops .aten_logcumsumexp ),
925+ TorchLibOpInfo (
926+ "logcumsumexp" , core_ops .aten_logcumsumexp , tolerance = {torch .float16 : (1e-2 , 1e-1 )}
927+ ),
875928 TorchLibOpInfo ("logdet" , core_ops .aten_logdet ),
876929 TorchLibOpInfo ("logsumexp" , core_ops .aten_logsumexp ),
877930 TorchLibOpInfo ("lt" , core_ops .aten_lt ),
@@ -884,7 +937,7 @@ def _where_input_wrangler(
884937 "matmul" ,
885938 core_ops .aten_matmul ,
886939 # Windows requires a more relaxed tolerance
887- tolerance = {torch .float32 : (2e-5 , 2e-5 )},
940+ tolerance = {torch .float32 : (2e-5 , 2e-5 ), torch . float16 : ( 2e-3 , 2e-2 ) },
888941 ).skip (
889942 matcher = lambda sample : torch .numel (sample .input ) == 0 ,
890943 reason = "values of matmul of [m, 0] and [0, n] matrices are undefined" ,
@@ -1339,12 +1392,28 @@ def _where_input_wrangler(
13391392 TorchLibOpInfo (
13401393 "softmax" ,
13411394 core_ops .aten_softmax ,
1395+ trace_only = True ,
13421396 tolerance = {torch .float32 : (3.7e-5 , 1.8e-4 ), torch .float16 : (3e-4 , 4e-4 )},
1343- ).xfail (
1397+ )
1398+ .xfail (
1399+ dtypes = (torch .float16 ,),
1400+ reason = "fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438" ,
1401+ test_class_name = "TestOutputConsistencyFullGraph" ,
1402+ )
1403+ .xfail (
13441404 variant_name = "with_dtype" ,
13451405 dtypes = (torch .float16 ,),
13461406 reason = "fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438" ,
13471407 test_class_name = "TestOutputConsistencyFullGraph" ,
1408+ )
1409+ .skip (
1410+ matcher = lambda sample : len (sample .input .shape ) == 0 ,
1411+ reason = "fixme: SoftMax does not support empty tensor as input" ,
1412+ )
1413+ .skip (
1414+ variant_name = "with_dtype" ,
1415+ matcher = lambda sample : len (sample .input .shape ) == 0 ,
1416+ reason = "fixme: SoftMax does not support empty tensor as input" ,
13481417 ),
13491418 TorchLibOpInfo ("nn.functional.softplus" , nn_ops .aten_softplus ).xfail (
13501419 dtypes = (torch .float16 ,),
@@ -1700,7 +1769,12 @@ def _where_input_wrangler(
17001769 variant_name = "empty_strides" ,
17011770 reason = "fixme: 'shape' do not match: torch.Size([2, 3, 4, 3]) != torch.Size([2, 3, 4, 2]). https://github.com/microsoft/onnxscript/issues/975" ,
17021771 ),
1703- TorchLibOpInfo ("native_batch_norm" , core_ops .aten_native_batch_norm , trace_only = True ),
1772+ TorchLibOpInfo (
1773+ "native_batch_norm" ,
1774+ core_ops .aten_native_batch_norm ,
1775+ trace_only = True ,
1776+ tolerance = {torch .float16 : (9e-3 , 7e-4 )},
1777+ ),
17041778 TorchLibOpInfo (
17051779 "ops.aten._native_batch_norm_legit" , core_ops .aten_native_batch_norm , trace_only = True
17061780 ),
@@ -1719,9 +1793,11 @@ def _where_input_wrangler(
17191793 "ops.aten.native_group_norm" ,
17201794 core_ops .aten_native_group_norm ,
17211795 trace_only = True ,
1796+ tolerance = {torch .float16 : (1e-2 , 7e-3 )},
17221797 ).xfail (
17231798 dtypes = (torch .float16 ,),
17241799 reason = "fixme: 'GroupNormKernelImpl' not implemented for 'Half' in nightly and weekly" ,
1800+ enabled_if = version_utils .torch_older_than ("2.2" ),
17251801 ),
17261802 TorchLibOpInfo (
17271803 "native_layer_norm" ,
@@ -1809,7 +1885,11 @@ def _where_input_wrangler(
18091885 matcher = lambda sample : len (sample .args ) != 1 ,
18101886 reason = "this overload is implemented for bias=None" ,
18111887 ),
1812- TorchLibOpInfo ("nn.functional.linear_bias" , nn_ops .aten_linear_bias ).skip (
1888+ TorchLibOpInfo (
1889+ "nn.functional.linear_bias" ,
1890+ nn_ops .aten_linear_bias ,
1891+ tolerance = {torch .float16 : (2e-1 , 4e-4 )},
1892+ ).skip (
18131893 # input: input, args: weight, bias; so len(args) == 2 means bias is provided
18141894 matcher = lambda sample : len (sample .args ) != 2 ,
18151895 reason = "this overload is implemented for bias!=None" ,
@@ -2059,8 +2139,8 @@ def _where_input_wrangler(
20592139 TorchLibOpInfo ("zeros_like" , core_ops .aten_zeros_like , trace_only = True ),
20602140)
20612141
2062- ops_test_common .duplicate_opinfo (OPS_DB , "all" , ("all_dim" ,))
2063- ops_test_common .duplicate_opinfo (OPS_DB , "any" , ("any_dim" ,))
2142+ ops_test_common .duplicate_opinfo (OPS_DB , "all" , ("all_dim" , "all_dims" ))
2143+ ops_test_common .duplicate_opinfo (OPS_DB , "any" , ("any_dim" , "any_dims" ))
20642144ops_test_common .duplicate_opinfo (OPS_DB , "arange" , ("arange_start" , "arange_start_step" ))
20652145ops_test_common .duplicate_opinfo (OPS_DB , "argmax" , ("argmax_dim" ,))
20662146ops_test_common .duplicate_opinfo (OPS_DB , "argmin" , ("argmin_dim" ,))
0 commit comments