@@ -471,17 +471,32 @@ def _where_input_wrangler(
471471 ),
472472 TorchLibOpInfo ("ops.aten._log_softmax" , core_ops .aten__log_softmax ),
473473 TorchLibOpInfo (
474- "ops.aten._log_softmax_half" , core_ops .aten__log_softmax_half , trace_only = True
475- ).xfail (
474+ "ops.aten._log_softmax_half" ,
475+ core_ops .aten__log_softmax_half ,
476+ trace_only = True ,
477+ tolerance = {torch .float16 : (1e-3 , 1e-3 )},
478+ )
479+ .xfail (
476480 reason = "PyTorch does not implement _log_softmax for float16 on CPU" ,
477481 dtypes = (torch .float16 ,),
482+ enabled_if = version_utils .torch_older_than ("2.2" ),
483+ )
484+ .xfail (
485+ dtypes = (torch .float16 ,),
486+ reason = "fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438" ,
487+ test_class_name = "TestOutputConsistencyFullGraph" ,
478488 ),
479489 TorchLibOpInfo ("ops.aten._softmax" , core_ops .aten__softmax , trace_only = True ),
480- TorchLibOpInfo (
481- "ops.aten._softmax_half" , core_ops .aten__softmax_half , trace_only = True
482- ).xfail (
490+ TorchLibOpInfo ("ops.aten._softmax_half" , core_ops .aten__softmax_half , trace_only = True )
491+ .xfail (
483492 reason = "PyTorch does not implement _softmax for float16 on CPU" ,
484493 dtypes = (torch .float16 ,),
494+ enabled_if = version_utils .torch_older_than ("2.2" ),
495+ )
496+ .xfail (
497+ dtypes = (torch .float16 ,),
498+ reason = "fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438" ,
499+ test_class_name = "TestOutputConsistencyFullGraph" ,
485500 ),
486501 TorchLibOpInfo ("all_dim" , core_ops .aten_all_dim ).skip (
487502 matcher = lambda sample : not (len (sample .kwargs ) > 0 )
@@ -881,12 +896,28 @@ def _where_input_wrangler(
881896 TorchLibOpInfo (
882897 "log_softmax" ,
883898 special_ops .aten_special_log_softmax ,
899+ trace_only = True ,
884900 tolerance = {torch .float32 : (3.7e-5 , 1.8e-4 ), torch .float16 : (4e-4 , 6e-3 )},
885- ).xfail (
901+ )
902+ .xfail (
903+ dtypes = (torch .float16 ,),
904+ reason = "fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438" ,
905+ test_class_name = "TestOutputConsistencyFullGraph" ,
906+ )
907+ .xfail (
886908 variant_name = "with_dtype" ,
887909 dtypes = (torch .float16 ,),
888910 reason = "fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438" ,
889911 test_class_name = "TestOutputConsistencyFullGraph" ,
912+ )
913+ .skip (
914+ matcher = lambda sample : len (sample .input .shape ) == 0 ,
915+ reason = "fixme: LogSoftMax does not support empty tensor as input" ,
916+ )
917+ .skip (
918+ variant_name = "with_dtype" ,
919+ matcher = lambda sample : len (sample .input .shape ) == 0 ,
920+ reason = "fixme: LogSoftMax does not support empty tensor as input" ,
890921 ),
891922 TorchLibOpInfo ("log2" , core_ops .aten_log2 ),
892923 TorchLibOpInfo ("logaddexp" , core_ops .aten_logaddexp ),
@@ -1361,12 +1392,28 @@ def _where_input_wrangler(
13611392 TorchLibOpInfo (
13621393 "softmax" ,
13631394 core_ops .aten_softmax ,
1395+ trace_only = True ,
13641396 tolerance = {torch .float32 : (3.7e-5 , 1.8e-4 ), torch .float16 : (3e-4 , 4e-4 )},
1365- ).xfail (
1397+ )
1398+ .xfail (
1399+ dtypes = (torch .float16 ,),
1400+ reason = "fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438" ,
1401+ test_class_name = "TestOutputConsistencyFullGraph" ,
1402+ )
1403+ .xfail (
13661404 variant_name = "with_dtype" ,
13671405 dtypes = (torch .float16 ,),
13681406 reason = "fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438" ,
13691407 test_class_name = "TestOutputConsistencyFullGraph" ,
1408+ )
1409+ .skip (
1410+ matcher = lambda sample : len (sample .input .shape ) == 0 ,
1411+ reason = "fixme: SoftMax does not support empty tensor as input" ,
1412+ )
1413+ .skip (
1414+ variant_name = "with_dtype" ,
1415+ matcher = lambda sample : len (sample .input .shape ) == 0 ,
1416+ reason = "fixme: SoftMax does not support empty tensor as input" ,
13701417 ),
13711418 TorchLibOpInfo ("nn.functional.softplus" , nn_ops .aten_softplus ).xfail (
13721419 dtypes = (torch .float16 ,),
0 commit comments