@@ -922,30 +922,41 @@ def get_split_with_sizes_inputs():
922
922
Test = namedtuple ("VkSliceTest" , ["self" , "sizes" , "dim" ])
923
923
test_cases = [
924
924
# Split on Width
925
+ Test (self = (S1 , 7 , 10 , 11 ), sizes = [1 , 3 , 2 , 5 ], dim = 3 ),
925
926
Test (self = (S1 , 7 , 10 , 10 ), sizes = [1 , 2 , 3 , 4 ], dim = 3 ),
927
+ Test (self = (7 , 10 , 11 ), sizes = [1 , 3 , 2 , 5 ], dim = 2 ),
926
928
Test (self = (7 , 10 , 10 ), sizes = [1 , 2 , 3 , 4 ], dim = 2 ),
929
+ Test (self = (7 , 10 , 11 ), sizes = [3 , 8 ], dim = 2 ),
927
930
Test (self = (7 , 10 , 10 ), sizes = [1 , 9 ], dim = 2 ),
928
931
Test (self = (10 , 10 ), sizes = [1 , 9 ], dim = 1 ),
929
932
Test (self = (10 ,), sizes = [1 , 9 ], dim = 0 ),
930
933
# Split on Height
934
+ Test (self = (S1 , 7 , 11 , 10 ), sizes = [1 , 3 , 2 , 5 ], dim = 2 ),
931
935
Test (self = (S1 , 7 , 10 , 10 ), sizes = [1 , 2 , 3 , 4 ], dim = 2 ),
936
+ Test (self = (7 , 11 , 10 ), sizes = [1 , 3 , 2 , 5 ], dim = 1 ),
932
937
Test (self = (7 , 10 , 10 ), sizes = [1 , 2 , 3 , 4 ], dim = 1 ),
938
+ Test (self = (7 , 11 , 11 ), sizes = [3 , 8 ], dim = 1 ),
933
939
Test (self = (7 , 10 , 10 ), sizes = [10 ], dim = 1 ),
934
940
Test (self = (7 , 6 , 10 ), sizes = [1 , 1 , 1 , 1 , 1 , 1 ], dim = 1 ),
935
941
Test (self = (10 , 10 ), sizes = [1 , 2 , 3 , 4 ], dim = 0 ),
936
942
# Split on Batch
937
943
Test (self = (10 , 7 , 10 , 10 ), sizes = [3 , 6 , 1 ], dim = 0 ),
938
944
Test (self = (10 , 7 , 10 , 10 ), sizes = [10 ], dim = 0 ),
939
945
# Split on Channel
946
+ Test (self = (7 , 13 , 4 , 8 ), sizes = [3 , 5 , 2 , 3 ], dim = 1 ),
940
947
Test (self = (7 , 13 , 4 , 8 ), sizes = [3 , 6 , 1 , 3 ], dim = 1 ),
948
+ Test (self = (7 , 13 , 4 , 8 ), sizes = [3 , 2 , 2 , 5 , 1 ], dim = 1 ),
941
949
Test (self = (7 , 13 , 4 , 8 ), sizes = [3 , 3 , 3 , 3 , 1 ], dim = 1 ),
950
+ Test (self = (13 , 4 , 8 ), sizes = [3 , 5 , 2 , 1 , 2 ], dim = 0 ),
942
951
Test (self = (13 , 4 , 8 ), sizes = [3 , 3 , 3 , 3 , 1 ], dim = 0 ),
943
952
Test (self = (13 , 4 , 8 ), sizes = [2 , 9 , 2 ], dim = 0 ),
944
953
Test (self = (13 , 4 , 8 ), sizes = [13 ], dim = 0 ),
945
954
]
946
955
test_suite = VkTestSuite ([tuple (tc ) for tc in test_cases ])
947
956
948
957
test_suite .layouts = [
958
+ "utils::kWidthPacked" ,
959
+ "utils::kHeightPacked" ,
949
960
"utils::kChannelsPacked" ,
950
961
]
951
962
test_suite .data_gen = "make_seq_tensor"
@@ -997,6 +1008,8 @@ def get_split_tensor_inputs():
997
1008
)
998
1009
999
1010
test_suite .layouts = [
1011
+ "utils::kWidthPacked" ,
1012
+ "utils::kHeightPacked" ,
1000
1013
"utils::kChannelsPacked" ,
1001
1014
]
1002
1015
test_suite .data_gen = "make_seq_tensor"
0 commit comments