@@ -471,21 +471,41 @@ def _where_input_wrangler(
     ),
     TorchLibOpInfo("ops.aten._log_softmax", core_ops.aten__log_softmax),
     TorchLibOpInfo(
-        "ops.aten._log_softmax_half", core_ops.aten__log_softmax_half, trace_only=True
-    ).xfail(
+        "ops.aten._log_softmax_half",
+        core_ops.aten__log_softmax_half,
+        trace_only=True,
+        tolerance={torch.float16: (1e-3, 1e-3)},
+    )
+    .xfail(
         reason="PyTorch does not implement _log_softmax for float16 on CPU",
         dtypes=(torch.float16,),
+        enabled_if=version_utils.torch_older_than("2.2"),
+    )
+    .xfail(
+        dtypes=(torch.float16,),
+        reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
+        test_class_name="TestOutputConsistencyFullGraph",
     ),
     TorchLibOpInfo("ops.aten._softmax", core_ops.aten__softmax, trace_only=True),
-    TorchLibOpInfo(
-        "ops.aten._softmax_half", core_ops.aten__softmax_half, trace_only=True
-    ).xfail(
+    TorchLibOpInfo("ops.aten._softmax_half", core_ops.aten__softmax_half, trace_only=True)
+    .xfail(
         reason="PyTorch does not implement _softmax for float16 on CPU",
         dtypes=(torch.float16,),
+        enabled_if=version_utils.torch_older_than("2.2"),
+    )
+    .xfail(
+        dtypes=(torch.float16,),
+        reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
+        test_class_name="TestOutputConsistencyFullGraph",
+    ),
+    TorchLibOpInfo("all_dim", core_ops.aten_all_dim).skip(
+        matcher=lambda sample: not (len(sample.kwargs) > 0)
+        or isinstance(sample.kwargs.get("dim"), tuple),
+        reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design. dim must be an integer",
     ),
-    TorchLibOpInfo("all_dim", core_ops.aten_all_dim).xfail(
-        matcher=lambda sample: not (len(sample.kwargs) > 0),
-        reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design",
+    TorchLibOpInfo("all_dims", core_ops.aten_all_dims, trace_only=True).skip(
+        matcher=lambda sample: not isinstance(sample.kwargs.get("dim"), tuple),
+        reason="this overload requires dim to be a tuple",
     ),
     TorchLibOpInfo("allclose", core_ops.aten_allclose),
     TorchLibOpInfo(
@@ -501,7 +521,11 @@ def _where_input_wrangler(
     TorchLibOpInfo("acosh", core_ops.aten_acosh),
     TorchLibOpInfo("add", core_ops.aten_add, tolerance={torch.float16: (1e-3, 1e-3)}),
     TorchLibOpInfo("add", core_ops.aten_add_complex, complex=True, trace_only=True),
-    TorchLibOpInfo("addbmm", core_ops.aten_addbmm, tolerance={torch.float32: (2e-5, 2e-5)}),
+    TorchLibOpInfo(
+        "addbmm",
+        core_ops.aten_addbmm,
+        tolerance={torch.float32: (2e-5, 2e-5), torch.float16: (2e-1, 2e-2)},
+    ),
     TorchLibOpInfo("addcdiv", core_ops.aten_addcdiv),
     TorchLibOpInfo("addcmul", core_ops.aten_addcmul, tolerance={torch.float16: (4e-3, 3e-3)}),
     TorchLibOpInfo("addmm", core_ops.aten_addmm)
@@ -522,7 +546,7 @@ def _where_input_wrangler(
         dtypes=(torch.int16, torch.int32, torch.int64),
         reason="ONNX Runtime does not support int inputs to Gemm",
     ),
-    TorchLibOpInfo("addmv", core_ops.aten_addmv),
+    TorchLibOpInfo("addmv", core_ops.aten_addmv, tolerance={torch.float16: (1e-3, 1e-2)}),
     TorchLibOpInfo(
         "addr",
         core_ops.aten_addr,
@@ -557,8 +581,13 @@ def _where_input_wrangler(
         "any_dim",
         core_ops.aten_any_dim,
     ).skip(
-        matcher=lambda sample: not (len(sample.kwargs) > 0),
-        reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design",
+        matcher=lambda sample: not (len(sample.kwargs) > 0)
+        or isinstance(sample.kwargs.get("dim"), tuple),
+        reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design. dim must be an integer",
+    ),
+    TorchLibOpInfo("any_dims", core_ops.aten_any_dims, trace_only=True).skip(
+        matcher=lambda sample: not isinstance(sample.kwargs.get("dim"), tuple),
+        reason="this overload requires dim to be a tuple",
     ),
     TorchLibOpInfo("asin", core_ops.aten_asin),
     TorchLibOpInfo("asinh", core_ops.aten_asinh),
@@ -640,7 +669,7 @@ def _where_input_wrangler(
             "https://github.com/microsoft/onnxscript/issues/1007"
         ),
     ),
-    TorchLibOpInfo("baddbmm", core_ops.aten_baddbmm),
+    TorchLibOpInfo("baddbmm", core_ops.aten_baddbmm, tolerance={torch.float16: (1e-3, 1e-2)}),
     TorchLibOpInfo("bernoulli", core_ops.aten_bernoulli, nondeterministic=True),
     TorchLibOpInfo(
         # This string is a unique ID. In extra_opinfo.py, we
@@ -845,6 +874,12 @@ def _where_input_wrangler(
         dtypes=(torch.int64, torch.int32),
         reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854",
     )
+    .xfail(
+        variant_name="tensor_overload",
+        dtypes=(torch.int64, torch.int32, torch.float16),
+        reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854",
+        enabled_if=not version_utils.torch_older_than("2.2"),
+    )
     .xfail(
         dtypes=(torch.float16,),
         reason="op 'Range' doesn't support float16.",
@@ -861,17 +896,35 @@ def _where_input_wrangler(
     TorchLibOpInfo(
         "log_softmax",
         special_ops.aten_special_log_softmax,
+        trace_only=True,
         tolerance={torch.float32: (3.7e-5, 1.8e-4), torch.float16: (4e-4, 6e-3)},
-    ).xfail(
+    )
+    .xfail(
+        dtypes=(torch.float16,),
+        reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
+        test_class_name="TestOutputConsistencyFullGraph",
+    )
+    .xfail(
         variant_name="with_dtype",
         dtypes=(torch.float16,),
         reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
         test_class_name="TestOutputConsistencyFullGraph",
+    )
+    .skip(
+        matcher=lambda sample: len(sample.input.shape) == 0,
+        reason="fixme: LogSoftMax does not support empty tensor as input",
+    )
+    .skip(
+        variant_name="with_dtype",
+        matcher=lambda sample: len(sample.input.shape) == 0,
+        reason="fixme: LogSoftMax does not support empty tensor as input",
     ),
     TorchLibOpInfo("log2", core_ops.aten_log2),
     TorchLibOpInfo("logaddexp", core_ops.aten_logaddexp),
     TorchLibOpInfo("logaddexp2", core_ops.aten_logaddexp2),
-    TorchLibOpInfo("logcumsumexp", core_ops.aten_logcumsumexp),
+    TorchLibOpInfo(
+        "logcumsumexp", core_ops.aten_logcumsumexp, tolerance={torch.float16: (1e-2, 1e-1)}
+    ),
     TorchLibOpInfo("logdet", core_ops.aten_logdet),
     TorchLibOpInfo("logsumexp", core_ops.aten_logsumexp),
     TorchLibOpInfo("lt", core_ops.aten_lt),
@@ -884,7 +937,7 @@ def _where_input_wrangler(
         "matmul",
         core_ops.aten_matmul,
         # Windows requires a more relaxed tolerance
-        tolerance={torch.float32: (2e-5, 2e-5)},
+        tolerance={torch.float32: (2e-5, 2e-5), torch.float16: (2e-3, 2e-2)},
     ).skip(
         matcher=lambda sample: torch.numel(sample.input) == 0,
         reason="values of matmul of [m, 0] and [0, n] matrices are undefined",
@@ -1339,12 +1392,28 @@ def _where_input_wrangler(
     TorchLibOpInfo(
         "softmax",
         core_ops.aten_softmax,
+        trace_only=True,
         tolerance={torch.float32: (3.7e-5, 1.8e-4), torch.float16: (3e-4, 4e-4)},
-    ).xfail(
+    )
+    .xfail(
+        dtypes=(torch.float16,),
+        reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
+        test_class_name="TestOutputConsistencyFullGraph",
+    )
+    .xfail(
         variant_name="with_dtype",
         dtypes=(torch.float16,),
         reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
         test_class_name="TestOutputConsistencyFullGraph",
+    )
+    .skip(
+        matcher=lambda sample: len(sample.input.shape) == 0,
+        reason="fixme: SoftMax does not support empty tensor as input",
+    )
+    .skip(
+        variant_name="with_dtype",
+        matcher=lambda sample: len(sample.input.shape) == 0,
+        reason="fixme: SoftMax does not support empty tensor as input",
     ),
     TorchLibOpInfo("nn.functional.softplus", nn_ops.aten_softplus).xfail(
         dtypes=(torch.float16,),
@@ -1700,7 +1769,12 @@ def _where_input_wrangler(
         variant_name="empty_strides",
         reason="fixme: 'shape' do not match: torch.Size([2, 3, 4, 3]) != torch.Size([2, 3, 4, 2]). https://github.com/microsoft/onnxscript/issues/975",
     ),
-    TorchLibOpInfo("native_batch_norm", core_ops.aten_native_batch_norm, trace_only=True),
+    TorchLibOpInfo(
+        "native_batch_norm",
+        core_ops.aten_native_batch_norm,
+        trace_only=True,
+        tolerance={torch.float16: (9e-3, 7e-4)},
+    ),
     TorchLibOpInfo(
         "ops.aten._native_batch_norm_legit", core_ops.aten_native_batch_norm, trace_only=True
     ),
@@ -1719,9 +1793,11 @@ def _where_input_wrangler(
         "ops.aten.native_group_norm",
         core_ops.aten_native_group_norm,
         trace_only=True,
+        tolerance={torch.float16: (1e-2, 7e-3)},
     ).xfail(
         dtypes=(torch.float16,),
         reason="fixme: 'GroupNormKernelImpl' not implemented for 'Half' in nightly and weekly",
+        enabled_if=version_utils.torch_older_than("2.2"),
     ),
     TorchLibOpInfo(
         "native_layer_norm",
@@ -1809,7 +1885,11 @@ def _where_input_wrangler(
         matcher=lambda sample: len(sample.args) != 1,
         reason="this overload is implemented for bias=None",
     ),
-    TorchLibOpInfo("nn.functional.linear_bias", nn_ops.aten_linear_bias).skip(
+    TorchLibOpInfo(
+        "nn.functional.linear_bias",
+        nn_ops.aten_linear_bias,
+        tolerance={torch.float16: (2e-1, 4e-4)},
+    ).skip(
         # input: input, args: weight, bias; so len(args) == 2 means bias is provided
         matcher=lambda sample: len(sample.args) != 2,
         reason="this overload is implemented for bias!=None",
@@ -2059,8 +2139,8 @@ def _where_input_wrangler(
     TorchLibOpInfo("zeros_like", core_ops.aten_zeros_like, trace_only=True),
 )
 
-ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim",))
-ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim",))
+ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim", "all_dims"))
+ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim", "any_dims"))
 ops_test_common.duplicate_opinfo(OPS_DB, "arange", ("arange_start", "arange_start_step"))
 ops_test_common.duplicate_opinfo(OPS_DB, "argmax", ("argmax_dim",))
 ops_test_common.duplicate_opinfo(OPS_DB, "argmin", ("argmin_dim",))