@@ -483,9 +483,14 @@ def _where_input_wrangler(
         reason="PyTorch does not implement _softmax for float16 on CPU",
         dtypes=(torch.float16,),
     ),
-    TorchLibOpInfo("all_dim", core_ops.aten_all_dim).xfail(
-        matcher=lambda sample: not (len(sample.kwargs) > 0),
-        reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design",
+    TorchLibOpInfo("all_dim", core_ops.aten_all_dim).skip(
+        matcher=lambda sample: not (len(sample.kwargs) > 0)
+        or isinstance(sample.kwargs.get("dim"), tuple),
+        reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design. dim must be an integer",
+    ),
+    TorchLibOpInfo("all_dims", core_ops.aten_all_dims, trace_only=True).skip(
+        matcher=lambda sample: not isinstance(sample.kwargs.get("dim"), tuple),
+        reason="this overload requires dim to be a tuple",
     ),
     TorchLibOpInfo("allclose", core_ops.aten_allclose),
     TorchLibOpInfo(
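The two skip matchers are complements over the samples that carry a `dim` kwarg: integer-`dim` samples exercise `all_dim`, tuple-`dim` samples exercise the new `all_dims`, and kwarg-free samples are left to the base `all` test. A minimal sketch of that partition, using a hypothetical `Sample` stand-in for torch's OpInfo `SampleInput`:

```python
# Minimal sketch of how the two skip matchers above partition the "all"
# samples. Sample is a hypothetical stand-in for torch's SampleInput.
from dataclasses import dataclass, field


@dataclass
class Sample:
    kwargs: dict = field(default_factory=dict)


def skip_for_all_dim(sample):  # matcher from the all_dim entry
    return not (len(sample.kwargs) > 0) or isinstance(sample.kwargs.get("dim"), tuple)


def skip_for_all_dims(sample):  # matcher from the all_dims entry
    return not isinstance(sample.kwargs.get("dim"), tuple)


for s in [Sample(), Sample({"dim": 1}), Sample({"dim": (0, 1), "keepdim": True})]:
    routed = [name for name, skipped in [("all_dim", skip_for_all_dim(s)),
                                         ("all_dims", skip_for_all_dims(s))] if not skipped]
    print(s.kwargs, "->", routed)
# {} -> []                                  (covered by the base "all" test)
# {'dim': 1} -> ['all_dim']
# {'dim': (0, 1), 'keepdim': True} -> ['all_dims']
```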
@@ -561,8 +566,13 @@ def _where_input_wrangler(
         "any_dim",
         core_ops.aten_any_dim,
     ).skip(
-        matcher=lambda sample: not (len(sample.kwargs) > 0),
-        reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design",
+        matcher=lambda sample: not (len(sample.kwargs) > 0)
+        or isinstance(sample.kwargs.get("dim"), tuple),
+        reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design. dim must be an integer",
+    ),
+    TorchLibOpInfo("any_dims", core_ops.aten_any_dims, trace_only=True).skip(
+        matcher=lambda sample: not isinstance(sample.kwargs.get("dim"), tuple),
+        reason="this overload requires dim to be a tuple",
     ),
     TorchLibOpInfo("asin", core_ops.aten_asin),
     TorchLibOpInfo("asinh", core_ops.aten_asinh),
@@ -881,7 +891,9 @@ def _where_input_wrangler(
     TorchLibOpInfo("log2", core_ops.aten_log2),
     TorchLibOpInfo("logaddexp", core_ops.aten_logaddexp),
     TorchLibOpInfo("logaddexp2", core_ops.aten_logaddexp2),
-    TorchLibOpInfo("logcumsumexp", core_ops.aten_logcumsumexp),
+    TorchLibOpInfo(
+        "logcumsumexp", core_ops.aten_logcumsumexp, tolerance={torch.float16: (1e-2, 1e-1)}
+    ),
     TorchLibOpInfo("logdet", core_ops.aten_logdet),
     TorchLibOpInfo("logsumexp", core_ops.aten_logsumexp),
     TorchLibOpInfo("lt", core_ops.aten_lt),
@@ -2080,8 +2092,8 @@ def _where_input_wrangler(
     TorchLibOpInfo("zeros_like", core_ops.aten_zeros_like, trace_only=True),
 )

-ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim",))
-ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim",))
+ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim", "all_dims"))
+ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim", "any_dims"))
 ops_test_common.duplicate_opinfo(OPS_DB, "arange", ("arange_start", "arange_start_step"))
 ops_test_common.duplicate_opinfo(OPS_DB, "argmax", ("argmax_dim",))
 ops_test_common.duplicate_opinfo(OPS_DB, "argmin", ("argmin_dim",))
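`duplicate_opinfo` registers the torch OpInfo for `all`/`any` under the new names so that `all_dims`/`any_dims` are matched against the same sample inputs. Conceptually it does something like the following (a sketch inferred from how it is called here, not the actual `ops_test_common` source):

```python
# Sketch of what duplicate_opinfo does, inferred from its call sites; the
# real implementation lives in ops_test_common.
import copy


def duplicate_opinfo_sketch(ops_db, name, new_names):
    (original,) = (op for op in ops_db if op.name == name)
    for new_name in new_names:
        clone = copy.deepcopy(original)  # reuse the same sample generators
        clone.name = new_name
        ops_db.append(clone)
```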