
Commit 818f55b

pytorchbot authored and trivedivivek committed
[ET-VK] Adding all tensor packing support to split op. (pytorch#9439)
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: pytorch#9345 by @trivedivivek
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/trivedivivek/66/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/trivedivivek/66/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/trivedivivek/65/orig
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/trivedivivek/66/orig
@diff-train-skip-merge
---------
Co-authored-by: Vivek Trivedi <[email protected]>
1 parent: 4557f7f

3 files changed, +58 -49 lines changed

backends/vulkan/op_registry.py

Lines changed: 2 additions & 2 deletions
@@ -528,8 +528,6 @@ def register_view_op(features: OpFeatures):
         exir_ops.edge.aten.index_select.default,
         exir_ops.edge.aten.select_copy.int,
         # Tensor combination
-        exir_ops.edge.aten.split_with_sizes_copy.default,
-        exir_ops.edge.aten.split.Tensor,
         exir_ops.edge.aten.repeat.default,
         # Tensor creation
         exir_ops.edge.aten.arange.start_step,
@@ -563,6 +561,8 @@ def register_ported_op(features: OpFeatures):
         exir_ops.edge.aten.permute_copy.default,
         # Tensor combination
         exir_ops.edge.aten.cat.default,
+        exir_ops.edge.aten.split_with_sizes_copy.default,
+        exir_ops.edge.aten.split.Tensor,
     ]
 )
 def register_ported_op_all_packed_dims(features: OpFeatures):
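For context, the two ops being re-registered have the following eager-mode semantics. This is an illustrative sketch in plain torch, not part of the PR:

import torch

x = torch.arange(24).reshape(2, 3, 4)

# aten.split.Tensor: fixed chunk size along a dim (the last chunk may be smaller)
chunks = torch.split(x, 2, dim=2)  # two outputs of shape (2, 3, 2)

# aten.split_with_sizes_copy: an explicit list of per-output sizes along a dim
parts = torch.split(x, [1, 3], dim=2)  # shapes (2, 3, 1) and (2, 3, 3)

Moving them out of the view-op bucket and into register_ported_op reflects that the Vulkan implementation no longer assumes a single (channels-packed) memory layout.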

backends/vulkan/runtime/graph/ops/impl/Split.cpp

Lines changed: 43 additions & 47 deletions
@@ -25,8 +25,6 @@ void add_split_with_sizes_default_node(
     ValueRef out_list_ref) {
   vTensorPtr t_in = graph.get_tensor(in);
 
-  VK_CHECK_COND(check_packed_dim_is(*t_in, WHCN::kChannelsDim));
-
   ValueListPtr out_list = graph.get_value_list(out_list_ref);
 
   DimIndex dim_index = normalize_to_dim_index(*t_in, dim);
@@ -38,62 +36,60 @@ void add_split_with_sizes_default_node(
     ValueRef out_ref = (*out_list)[split_idx];
 
     vTensorPtr t_out = graph.get_tensor(out_ref);
-    VK_CHECK_COND(check_packed_dim_is(*t_out, WHCN::kChannelsDim));
     VK_CHECK_COND(dim_at(*t_out, dim_index) == split_size);
   }
 
-  if (dim_index == kWidth4D) {
-    utils::ivec4 src_offset = utils::make_ivec4({0, 0, 0, 0}, false);
-    utils::ivec4 dst_offset = utils::make_ivec4({0, 0, 0, 0}, false);
+  const auto packed_dim = t_in->packed_dim();
+  const auto packed_dim_index = static_cast<DimIndex>(kWidth4D - packed_dim);
 
-    for (ValueRef out_ref : *out_list) {
-      // Doesn't need to use split_size since we have already verified that the
-      // output tensor's size matches with the split_size.
-      vTensorPtr t_out = graph.get_tensor(out_ref);
-      utils::ivec3 range = t_out->logical_limits();
-      add_copy_offset_node(
-          graph, in, range, src_offset, dst_offset, out_ref, false, true);
+  // Index of dimension to be concatenated in (w, h, c * b) coordinate system
+  const auto dim_xyz_index = std::min(2, -dim_index - 1);
 
-      src_offset[0] += range[0];
-    }
-  } else if (dim_index == kHeight4D) {
-    utils::ivec4 src_offset = utils::make_ivec4({0, 0, 0, 0}, false);
-    utils::ivec4 dst_offset = utils::make_ivec4({0, 0, 0, 0}, false);
+  utils::ivec4 src_offset = utils::make_ivec4({0, 0, 0, 0}, false);
+  utils::ivec4 dst_offset = utils::make_ivec4({0, 0, 0, 0}, false);
 
-    for (ValueRef out_ref : *out_list) {
-      vTensorPtr t_out = graph.get_tensor(out_ref);
-      utils::ivec3 range = t_out->logical_limits();
-      add_copy_offset_node(
-          graph, in, range, src_offset, dst_offset, out_ref, false, true);
+  const bool is_splitting_channel = (dim_index == kChannel4D);
 
-      src_offset[1] += range[1];
-    }
-  } else if (dim_index == kBatch4D) {
-    utils::ivec4 src_offset = utils::make_ivec4({0, 0, 0, 0}, false);
-    utils::ivec4 dst_offset = utils::make_ivec4({0, 0, 0, 0}, false);
+  // if splitting channels
+  if (is_splitting_channel) {
+    // set source offset w as channel size of the input tensor
+    src_offset[3] = dim_at(t_in->sizes(), kChannel4D);
+  }
 
-    for (ValueRef out_ref : *out_list) {
-      vTensorPtr t_out = graph.get_tensor(out_ref);
-      utils::ivec3 range = t_out->logical_limits();
+  for (ValueRef out_ref : *out_list) {
+    // Doesn't need to use split_size since we have already verified that the
+    // output tensor's size matches with the split_size.
+    vTensorPtr t_out = graph.get_tensor(out_ref);
+    const auto out_channel_size = dim_at(t_out->sizes(), kChannel4D);
+    utils::ivec3 range = t_out->logical_limits();
+
+    if (dim_index == packed_dim_index) {
+      // if splitting channels, use add_copy_channel_offset_node function as
+      // add_copy_packed_dim_offset_node does not support channel packing
+      if (is_splitting_channel) {
+        add_copy_channel_offset_node(
+            graph, in, out_channel_size, src_offset[2], dst_offset[2], out_ref);
+        src_offset[dim_xyz_index] += out_channel_size;
+      } else {
+        // dst_offset[3] is not used now but will be used in the future when
+        // add_copy_packed_dim_offset_node will support channel packing
+        //
+        // set destination offset w as channel size of the output tensor if
+        // splitting channel
+        dst_offset[3] = is_splitting_channel ? out_channel_size : 0;
+        add_copy_packed_dim_offset_node(
+            graph, in, range, src_offset, dst_offset, out_ref);
+        src_offset[dim_xyz_index] += dim_at(t_out->sizes(), packed_dim_index);
+      }
+    } else {
+      // set destination offset w as channel size of the output tensor if
+      // splitting channels
+      dst_offset[3] = is_splitting_channel ? out_channel_size : 0;
       add_copy_offset_node(
           graph, in, range, src_offset, dst_offset, out_ref, false, true);
-
-      src_offset[2] += range[2];
-    }
-  } else if (dim_index == kChannel4D) {
-    int32_t src_offset = 0;
-    int32_t dst_offset = 0;
-
-    for (ValueRef out_ref : *out_list) {
-      vTensorPtr t_out = graph.get_tensor(out_ref);
-      int32_t range = dim_at<kChannel4D>(t_out->sizes());
-      add_copy_channel_offset_node(
-          graph, in, range, src_offset, dst_offset, out_ref);
-      src_offset += range;
+      src_offset[dim_xyz_index] +=
+          is_splitting_channel ? out_channel_size : range[dim_xyz_index];
     }
-
-  } else {
-    VK_THROW("not ipmlemented");
   }
 }
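The rewrite collapses the four per-dimension branches (width, height, batch, channel) into one loop that advances a running source offset along the split dimension and emits one copy node per output, dispatching to the packed-dim or channel copy node as needed. The following is a minimal eager-mode sketch of that offset-walking loop under stated assumptions: the helper name is hypothetical, and torch.narrow stands in for the Vulkan copy nodes.

import torch

# Hypothetical reference for the rewritten loop: walk a running offset along
# the split dimension and copy one slice per output, mirroring how Split.cpp
# advances src_offset[dim_xyz_index] after emitting each copy node.
def split_with_sizes_reference(x, sizes, dim):
    assert sum(sizes) == x.size(dim)  # same invariant the VK_CHECK_COND enforces
    outs, offset = [], 0
    for size in sizes:
        outs.append(x.narrow(dim, offset, size).clone())
        offset += size
    return outs

# Sanity check against torch.split, using one of the new channel-split cases
x = torch.arange(7 * 13 * 4 * 8).reshape(7, 13, 4, 8)
for ref, expected in zip(
    split_with_sizes_reference(x, [3, 5, 2, 3], dim=1),
    torch.split(x, [3, 5, 2, 3], dim=1),
):
    assert torch.equal(ref, expected)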

backends/vulkan/test/op_tests/cases.py

Lines changed: 13 additions & 0 deletions
@@ -922,30 +922,41 @@ def get_split_with_sizes_inputs():
     Test = namedtuple("VkSliceTest", ["self", "sizes", "dim"])
     test_cases = [
         # Split on Width
+        Test(self=(S1, 7, 10, 11), sizes=[1, 3, 2, 5], dim=3),
         Test(self=(S1, 7, 10, 10), sizes=[1, 2, 3, 4], dim=3),
+        Test(self=(7, 10, 11), sizes=[1, 3, 2, 5], dim=2),
         Test(self=(7, 10, 10), sizes=[1, 2, 3, 4], dim=2),
+        Test(self=(7, 10, 11), sizes=[3, 8], dim=2),
         Test(self=(7, 10, 10), sizes=[1, 9], dim=2),
         Test(self=(10, 10), sizes=[1, 9], dim=1),
         Test(self=(10,), sizes=[1, 9], dim=0),
         # Split on Height
+        Test(self=(S1, 7, 11, 10), sizes=[1, 3, 2, 5], dim=2),
         Test(self=(S1, 7, 10, 10), sizes=[1, 2, 3, 4], dim=2),
+        Test(self=(7, 11, 10), sizes=[1, 3, 2, 5], dim=1),
         Test(self=(7, 10, 10), sizes=[1, 2, 3, 4], dim=1),
+        Test(self=(7, 11, 11), sizes=[3, 8], dim=1),
         Test(self=(7, 10, 10), sizes=[10], dim=1),
         Test(self=(7, 6, 10), sizes=[1, 1, 1, 1, 1, 1], dim=1),
         Test(self=(10, 10), sizes=[1, 2, 3, 4], dim=0),
         # Split on Batch
         Test(self=(10, 7, 10, 10), sizes=[3, 6, 1], dim=0),
         Test(self=(10, 7, 10, 10), sizes=[10], dim=0),
         # Split on Channel
+        Test(self=(7, 13, 4, 8), sizes=[3, 5, 2, 3], dim=1),
         Test(self=(7, 13, 4, 8), sizes=[3, 6, 1, 3], dim=1),
+        Test(self=(7, 13, 4, 8), sizes=[3, 2, 2, 5, 1], dim=1),
         Test(self=(7, 13, 4, 8), sizes=[3, 3, 3, 3, 1], dim=1),
+        Test(self=(13, 4, 8), sizes=[3, 5, 2, 1, 2], dim=0),
         Test(self=(13, 4, 8), sizes=[3, 3, 3, 3, 1], dim=0),
         Test(self=(13, 4, 8), sizes=[2, 9, 2], dim=0),
         Test(self=(13, 4, 8), sizes=[13], dim=0),
     ]
     test_suite = VkTestSuite([tuple(tc) for tc in test_cases])
 
     test_suite.layouts = [
+        "utils::kWidthPacked",
+        "utils::kHeightPacked",
         "utils::kChannelsPacked",
     ]
     test_suite.data_gen = "make_seq_tensor"
@@ -997,6 +1008,8 @@ def get_split_tensor_inputs():
     )
 
     test_suite.layouts = [
+        "utils::kWidthPacked",
+        "utils::kHeightPacked",
         "utils::kChannelsPacked",
     ]
     test_suite.data_gen = "make_seq_tensor"
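Adding kWidthPacked and kHeightPacked to test_suite.layouts runs every case under all three packed layouts, which is what exercises the new non-channels-packed paths. Each new case also preserves the invariant the kernel asserts per output, VK_CHECK_COND(dim_at(*t_out, dim_index) == split_size): the sizes must tile the split dimension exactly. A quick sketch of that check; S1 is a shared size constant defined elsewhere in cases.py, so its value below is a placeholder.

# S1 is a shared size constant defined elsewhere in cases.py; 11 is a
# placeholder value for illustration only.
S1 = 11

def check_case(shape, sizes, dim):
    # The split sizes must tile the split dimension exactly, matching the
    # per-output VK_CHECK_COND(dim_at(*t_out, dim_index) == split_size)
    # check in Split.cpp.
    assert sum(sizes) == shape[dim], (shape, sizes, dim)

check_case((S1, 7, 10, 11), [1, 3, 2, 5], dim=3)   # new width-split case
check_case((7, 13, 4, 8), [3, 2, 2, 5, 1], dim=1)  # new channel-split case
check_case((13, 4, 8), [3, 5, 2, 1, 2], dim=0)     # new 3-D channel-split case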
