@@ -526,7 +526,7 @@ def _where_input_wrangler(
526
526
core_ops .aten_addbmm ,
527
527
tolerance = {torch .float32 : (2e-5 , 2e-5 ), torch .float16 : (2e-1 , 2e-2 )},
528
528
),
529
- TorchLibOpInfo ("addcdiv" , core_ops .aten_addcdiv ),
529
+ TorchLibOpInfo ("addcdiv" , core_ops .aten_addcdiv , tolerance = { torch . float16 : ( 3e-2 , 1e-3 )} ),
530
530
TorchLibOpInfo ("addcmul" , core_ops .aten_addcmul , tolerance = {torch .float16 : (4e-3 , 3e-3 )}),
531
531
TorchLibOpInfo ("addmm" , core_ops .aten_addmm )
532
532
.xfail (
@@ -592,7 +592,7 @@ def _where_input_wrangler(
592
592
TorchLibOpInfo ("asin" , core_ops .aten_asin ),
593
593
TorchLibOpInfo ("asinh" , core_ops .aten_asinh ),
594
594
TorchLibOpInfo ("atan" , core_ops .aten_atan ),
595
- TorchLibOpInfo ("atan2" , core_ops .aten_atan2 ),
595
+ TorchLibOpInfo ("atan2" , core_ops .aten_atan2 , tolerance = { torch . float16 : ( 1e-3 , 1e-3 )} ),
596
596
TorchLibOpInfo ("atanh" , core_ops .aten_atanh ),
597
597
TorchLibOpInfo ("atleast_1d" , core_ops .aten_atleast_1d ).skip (
598
598
matcher = lambda sample : isinstance (sample .input , (list , tuple )),
@@ -737,7 +737,7 @@ def _where_input_wrangler(
737
737
# TorchLibOpInfo("copy", core_ops.aten_copy), # copy is not in OPS_DB
738
738
TorchLibOpInfo ("cos" , core_ops .aten_cos ),
739
739
TorchLibOpInfo ("cosh" , core_ops .aten_cosh ),
740
- TorchLibOpInfo ("cross" , core_ops .aten_cross ),
740
+ TorchLibOpInfo ("cross" , core_ops .aten_cross , tolerance = { torch . float16 : ( 6e-3 , 3e-3 )} ),
741
741
# TorchLibOpInfo("detach", core_ops.aten_detach), # detach is not in OP-TEST-DB
742
742
TorchLibOpInfo ("diagonal" , core_ops .aten_diagonal , trace_only = True ),
743
743
TorchLibOpInfo ("diagonal_bool" , core_ops .aten_diagonal_bool , trace_only = True ),
@@ -920,8 +920,10 @@ def _where_input_wrangler(
920
920
reason = "fixme: LogSoftMax does not support empty tensor as input" ,
921
921
),
922
922
TorchLibOpInfo ("log2" , core_ops .aten_log2 ),
923
- TorchLibOpInfo ("logaddexp" , core_ops .aten_logaddexp ),
924
- TorchLibOpInfo ("logaddexp2" , core_ops .aten_logaddexp2 ),
923
+ TorchLibOpInfo ("logaddexp" , core_ops .aten_logaddexp , tolerance = {torch .float16 : (1 , 1e-4 )}),
924
+ TorchLibOpInfo (
925
+ "logaddexp2" , core_ops .aten_logaddexp2 , tolerance = {torch .float16 : (2e-2 , 6e-4 )}
926
+ ),
925
927
TorchLibOpInfo (
926
928
"logcumsumexp" , core_ops .aten_logcumsumexp , tolerance = {torch .float16 : (1e-2 , 1e-1 )}
927
929
),
@@ -1087,10 +1089,16 @@ def _where_input_wrangler(
1087
1089
TorchLibOpInfo (
1088
1090
"nn.functional.adaptive_avg_pool1d" ,
1089
1091
nn_ops .aten_adaptive_avg_pool1d ,
1090
- ).xfail (
1092
+ )
1093
+ .xfail (
1091
1094
# Shape should be [N, C, D1]
1092
1095
matcher = lambda sample : sample .args [0 ] not in {1 , (1 ,)},
1093
1096
reason = "only global pooling is supported; only batched inputs are supported" ,
1097
+ )
1098
+ .xfail (
1099
+ reason = "ORT fails on a cast node it inserts for float16. https://github.com/microsoft/onnxruntime/issues/16449" ,
1100
+ dtypes = (torch .float16 ,),
1101
+ test_class_name = "TestOutputConsistencyEager" ,
1094
1102
),
1095
1103
TorchLibOpInfo (
1096
1104
"nn.functional.adaptive_avg_pool2d" ,
@@ -1718,7 +1726,9 @@ def _where_input_wrangler(
1718
1726
dtypes = (torch .int64 ,),
1719
1727
reason = "fixme: ORT `LayerNormKernelImpl` not implemented for int64" ,
1720
1728
),
1721
- TorchLibOpInfo ("logit" , core_ops .aten_logit , trace_only = True ),
1729
+ TorchLibOpInfo (
1730
+ "logit" , core_ops .aten_logit , trace_only = True , tolerance = {torch .float16 : (1e-1 , 7e-4 )}
1731
+ ),
1722
1732
TorchLibOpInfo ("max_dim" , core_ops .aten_max_dim )
1723
1733
.skip (
1724
1734
variant_name = "reduction_with_dim" ,
@@ -1869,7 +1879,7 @@ def _where_input_wrangler(
1869
1879
reason = "String padding is not accepted by aten::conv2d" ,
1870
1880
),
1871
1881
TorchLibOpInfo (
1872
- "nn.functional .conv3d" ,
1882
+ "ops.aten .conv3d" ,
1873
1883
core_ops .aten_conv3d ,
1874
1884
trace_only = True ,
1875
1885
tolerance = {torch .float32 : (3.7e-5 , 1.8e-4 )},
@@ -1974,6 +1984,16 @@ def _where_input_wrangler(
1974
1984
.skip (
1975
1985
matcher = lambda sample : sample .kwargs .get ("dropout_p" ) != 0.0 ,
1976
1986
reason = "dropout is random so the results do not match" ,
1987
+ )
1988
+ .xfail (
1989
+ dtypes = (torch .float16 ,),
1990
+ reason = "fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438" ,
1991
+ test_class_name = "TestOutputConsistencyFullGraph" ,
1992
+ )
1993
+ .xfail (
1994
+ reason = "fixme: ORT fails on type mismatch in Add" ,
1995
+ dtypes = (torch .float16 ,),
1996
+ test_class_name = "TestOutputConsistencyEager" ,
1977
1997
),
1978
1998
TorchLibOpInfo (
1979
1999
"ops.aten._scaled_dot_product_flash_attention" ,
@@ -2000,6 +2020,16 @@ def _where_input_wrangler(
2000
2020
.skip (
2001
2021
matcher = lambda sample : sample .kwargs .get ("dropout_p" ) != 0.0 ,
2002
2022
reason = "dropout is random so the results do not match" ,
2023
+ )
2024
+ .xfail (
2025
+ dtypes = (torch .float16 ,),
2026
+ reason = "fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438" ,
2027
+ test_class_name = "TestOutputConsistencyFullGraph" ,
2028
+ )
2029
+ .xfail (
2030
+ reason = "fixme: ORT fails on type mismatch in Add" ,
2031
+ dtypes = (torch .float16 ,),
2032
+ test_class_name = "TestOutputConsistencyEager" ,
2003
2033
),
2004
2034
TorchLibOpInfo (
2005
2035
"nn.functional.upsample_bilinear2d" ,
0 commit comments