@@ -501,7 +501,11 @@ def _where_input_wrangler(
501
501
TorchLibOpInfo ("acosh" , core_ops .aten_acosh ),
502
502
TorchLibOpInfo ("add" , core_ops .aten_add , tolerance = {torch .float16 : (1e-3 , 1e-3 )}),
503
503
TorchLibOpInfo ("add" , core_ops .aten_add_complex , complex = True , trace_only = True ),
504
- TorchLibOpInfo ("addbmm" , core_ops .aten_addbmm , tolerance = {torch .float32 : (2e-5 , 2e-5 )}),
504
+ TorchLibOpInfo (
505
+ "addbmm" ,
506
+ core_ops .aten_addbmm ,
507
+ tolerance = {torch .float32 : (2e-5 , 2e-5 ), torch .float16 : (2e-1 , 2e-2 )},
508
+ ),
505
509
TorchLibOpInfo ("addcdiv" , core_ops .aten_addcdiv ),
506
510
TorchLibOpInfo ("addcmul" , core_ops .aten_addcmul , tolerance = {torch .float16 : (4e-3 , 3e-3 )}),
507
511
TorchLibOpInfo ("addmm" , core_ops .aten_addmm )
@@ -522,7 +526,7 @@ def _where_input_wrangler(
522
526
dtypes = (torch .int16 , torch .int32 , torch .int64 ),
523
527
reason = "ONNX Runtime does not support int inputs to Gemm" ,
524
528
),
525
- TorchLibOpInfo ("addmv" , core_ops .aten_addmv ),
529
+ TorchLibOpInfo ("addmv" , core_ops .aten_addmv , tolerance = { torch . float16 : ( 1e-3 , 1e-2 )} ),
526
530
TorchLibOpInfo (
527
531
"addr" ,
528
532
core_ops .aten_addr ,
@@ -640,7 +644,7 @@ def _where_input_wrangler(
640
644
"https://github.com/microsoft/onnxscript/issues/1007"
641
645
),
642
646
),
643
- TorchLibOpInfo ("baddbmm" , core_ops .aten_baddbmm ),
647
+ TorchLibOpInfo ("baddbmm" , core_ops .aten_baddbmm , tolerance = { torch . float16 : ( 1e-3 , 1e-2 )} ),
644
648
TorchLibOpInfo ("bernoulli" , core_ops .aten_bernoulli , nondeterministic = True ),
645
649
TorchLibOpInfo (
646
650
# This string is a unique ID. In extra_opinfo.py, we
@@ -845,6 +849,12 @@ def _where_input_wrangler(
845
849
dtypes = (torch .int64 , torch .int32 ),
846
850
reason = "fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854" ,
847
851
)
852
+ .xfail (
853
+ variant_name = "tensor_overload" ,
854
+ dtypes = (torch .int64 , torch .int32 , torch .float16 ),
855
+ reason = "fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854" ,
856
+ enabled_if = not version_utils .torch_older_than ("2.1" ),
857
+ )
848
858
.xfail (
849
859
dtypes = (torch .float16 ,),
850
860
reason = "op 'Range' doesn't support float16." ,
@@ -884,7 +894,7 @@ def _where_input_wrangler(
884
894
"matmul" ,
885
895
core_ops .aten_matmul ,
886
896
# Windows requires a more relaxed tolerance
887
- tolerance = {torch .float32 : (2e-5 , 2e-5 )},
897
+ tolerance = {torch .float32 : (2e-5 , 2e-5 ), torch . float16 : ( 2e-3 , 2e-2 ) },
888
898
).skip (
889
899
matcher = lambda sample : torch .numel (sample .input ) == 0 ,
890
900
reason = "values of matmul of [m, 0] and [0, n] matrices are undefined" ,
@@ -1700,7 +1710,12 @@ def _where_input_wrangler(
1700
1710
variant_name = "empty_strides" ,
1701
1711
reason = "fixme: 'shape' do not match: torch.Size([2, 3, 4, 3]) != torch.Size([2, 3, 4, 2]). https://github.com/microsoft/onnxscript/issues/975" ,
1702
1712
),
1703
- TorchLibOpInfo ("native_batch_norm" , core_ops .aten_native_batch_norm , trace_only = True ),
1713
+ TorchLibOpInfo (
1714
+ "native_batch_norm" ,
1715
+ core_ops .aten_native_batch_norm ,
1716
+ trace_only = True ,
1717
+ tolerance = {torch .float16 : (9e-3 , 7e-4 )},
1718
+ ),
1704
1719
TorchLibOpInfo (
1705
1720
"ops.aten._native_batch_norm_legit" , core_ops .aten_native_batch_norm , trace_only = True
1706
1721
),
@@ -1719,9 +1734,11 @@ def _where_input_wrangler(
1719
1734
"ops.aten.native_group_norm" ,
1720
1735
core_ops .aten_native_group_norm ,
1721
1736
trace_only = True ,
1737
+ tolerance = {torch .float16 : (1e-2 , 7e-3 )},
1722
1738
).xfail (
1723
1739
dtypes = (torch .float16 ,),
1724
1740
reason = "fixme: 'GroupNormKernelImpl' not implemented for 'Half' in nightly and weekly" ,
1741
+ enabled_if = version_utils .torch_older_than ("2.2" ),
1725
1742
),
1726
1743
TorchLibOpInfo (
1727
1744
"native_layer_norm" ,
@@ -1809,7 +1826,11 @@ def _where_input_wrangler(
1809
1826
matcher = lambda sample : len (sample .args ) != 1 ,
1810
1827
reason = "this overload is implemented for bias=None" ,
1811
1828
),
1812
- TorchLibOpInfo ("nn.functional.linear_bias" , nn_ops .aten_linear_bias ).skip (
1829
+ TorchLibOpInfo (
1830
+ "nn.functional.linear_bias" ,
1831
+ nn_ops .aten_linear_bias ,
1832
+ tolerance = {torch .float16 : (2e-1 , 4e-4 )},
1833
+ ).skip (
1813
1834
# input: input, args: weight, bias; so len(args) == 2 means bias is provided
1814
1835
matcher = lambda sample : len (sample .args ) != 2 ,
1815
1836
reason = "this overload is implemented for bias!=None" ,
0 commit comments