@@ -471,17 +471,32 @@ def _where_input_wrangler(
471
471
),
472
472
TorchLibOpInfo ("ops.aten._log_softmax" , core_ops .aten__log_softmax ),
473
473
TorchLibOpInfo (
474
- "ops.aten._log_softmax_half" , core_ops .aten__log_softmax_half , trace_only = True
475
- ).xfail (
474
+ "ops.aten._log_softmax_half" ,
475
+ core_ops .aten__log_softmax_half ,
476
+ trace_only = True ,
477
+ tolerance = {torch .float16 : (1e-3 , 1e-3 )},
478
+ )
479
+ .xfail (
476
480
reason = "PyTorch does not implement _log_softmax for float16 on CPU" ,
477
481
dtypes = (torch .float16 ,),
482
+ enabled_if = version_utils .torch_older_than ("2.2" ),
483
+ )
484
+ .xfail (
485
+ dtypes = (torch .float16 ,),
486
+ reason = "fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438" ,
487
+ test_class_name = "TestOutputConsistencyFullGraph" ,
478
488
),
479
489
TorchLibOpInfo ("ops.aten._softmax" , core_ops .aten__softmax , trace_only = True ),
480
- TorchLibOpInfo (
481
- "ops.aten._softmax_half" , core_ops .aten__softmax_half , trace_only = True
482
- ).xfail (
490
+ TorchLibOpInfo ("ops.aten._softmax_half" , core_ops .aten__softmax_half , trace_only = True )
491
+ .xfail (
483
492
reason = "PyTorch does not implement _softmax for float16 on CPU" ,
484
493
dtypes = (torch .float16 ,),
494
+ enabled_if = version_utils .torch_older_than ("2.2" ),
495
+ )
496
+ .xfail (
497
+ dtypes = (torch .float16 ,),
498
+ reason = "fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438" ,
499
+ test_class_name = "TestOutputConsistencyFullGraph" ,
485
500
),
486
501
TorchLibOpInfo ("all_dim" , core_ops .aten_all_dim ).skip (
487
502
matcher = lambda sample : not (len (sample .kwargs ) > 0 )
@@ -881,12 +896,28 @@ def _where_input_wrangler(
881
896
TorchLibOpInfo (
882
897
"log_softmax" ,
883
898
special_ops .aten_special_log_softmax ,
899
+ trace_only = True ,
884
900
tolerance = {torch .float32 : (3.7e-5 , 1.8e-4 ), torch .float16 : (4e-4 , 6e-3 )},
885
- ).xfail (
901
+ )
902
+ .xfail (
903
+ dtypes = (torch .float16 ,),
904
+ reason = "fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438" ,
905
+ test_class_name = "TestOutputConsistencyFullGraph" ,
906
+ )
907
+ .xfail (
886
908
variant_name = "with_dtype" ,
887
909
dtypes = (torch .float16 ,),
888
910
reason = "fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438" ,
889
911
test_class_name = "TestOutputConsistencyFullGraph" ,
912
+ )
913
+ .skip (
914
+ matcher = lambda sample : len (sample .input .shape ) == 0 ,
915
+ reason = "fixme: LogSoftMax does not support empty tensor as input" ,
916
+ )
917
+ .skip (
918
+ variant_name = "with_dtype" ,
919
+ matcher = lambda sample : len (sample .input .shape ) == 0 ,
920
+ reason = "fixme: LogSoftMax does not support empty tensor as input" ,
890
921
),
891
922
TorchLibOpInfo ("log2" , core_ops .aten_log2 ),
892
923
TorchLibOpInfo ("logaddexp" , core_ops .aten_logaddexp ),
@@ -1361,12 +1392,28 @@ def _where_input_wrangler(
1361
1392
TorchLibOpInfo (
1362
1393
"softmax" ,
1363
1394
core_ops .aten_softmax ,
1395
+ trace_only = True ,
1364
1396
tolerance = {torch .float32 : (3.7e-5 , 1.8e-4 ), torch .float16 : (3e-4 , 4e-4 )},
1365
- ).xfail (
1397
+ )
1398
+ .xfail (
1399
+ dtypes = (torch .float16 ,),
1400
+ reason = "fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438" ,
1401
+ test_class_name = "TestOutputConsistencyFullGraph" ,
1402
+ )
1403
+ .xfail (
1366
1404
variant_name = "with_dtype" ,
1367
1405
dtypes = (torch .float16 ,),
1368
1406
reason = "fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438" ,
1369
1407
test_class_name = "TestOutputConsistencyFullGraph" ,
1408
+ )
1409
+ .skip (
1410
+ matcher = lambda sample : len (sample .input .shape ) == 0 ,
1411
+ reason = "fixme: SoftMax does not support empty tensor as input" ,
1412
+ )
1413
+ .skip (
1414
+ variant_name = "with_dtype" ,
1415
+ matcher = lambda sample : len (sample .input .shape ) == 0 ,
1416
+ reason = "fixme: SoftMax does not support empty tensor as input" ,
1370
1417
),
1371
1418
TorchLibOpInfo ("nn.functional.softplus" , nn_ops .aten_softplus ).xfail (
1372
1419
dtypes = (torch .float16 ,),
0 commit comments