Skip to content

Commit 09e591d

Browse files
authored
[ET-VK] Adding boolean parameters to add_copy_offset_node to specify index calculation function in copy op's shader.
Differential Revision: D71343588 Pull Request resolved: #9343
1 parent b667d83 commit 09e591d

File tree

6 files changed

+52
-18
lines changed

6 files changed

+52
-18
lines changed

backends/vulkan/runtime/graph/ops/glsl/copy_offset.glsl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,21 +35,29 @@ const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);
3535
${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")}
3636
const lowp ivec4 in_axis_map = unhash_axis_map(in_layout);
3737

38+
${layout_declare_spec_const(C, "int", "batch_index_function", "0")}
39+
3840
void main() {
3941
const ivec3 pos = ivec3(gl_GlobalInvocationID);
4042

4143
if (any(greaterThanEqual(pos, range))) {
4244
return;
4345
}
4446

45-
const ivec3 in_pos = pos + src_offset.xyz;
47+
ivec3 in_pos = pos + src_offset.xyz;
4648
ivec3 out_pos = pos + dst_offset.xyz;
47-
48-
// If source channel size is specified compose output z based on channel and batch index
4949
if (src_offset.w > 0) {
50-
const int channel_index = in_pos.z % src_offset.w;
51-
const int batch_index = in_pos.z / src_offset.w;
52-
out_pos.z = channel_index + dst_offset.z + batch_index * dst_offset.w;
50+
if (batch_index_function == 1) {
51+
// batch index is calculated using source channel size
52+
const int channel_index = pos.z % src_offset.w;
53+
const int batch_index = pos.z / src_offset.w;
54+
out_pos.z = channel_index + dst_offset.z + batch_index * dst_offset.w;
55+
} else if (batch_index_function == 2) {
56+
// batch index is calculated using destination channel size
57+
const int channel_index = pos.z % dst_offset.w;
58+
const int batch_index = pos.z / dst_offset.w;
59+
in_pos.z = channel_index + src_offset.z + batch_index * src_offset.w;
60+
}
5361
}
5462

5563
write_texel_lpos(

backends/vulkan/runtime/graph/ops/impl/Cat.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ void add_cat_default_node(
8080
// concatenating channels
8181
src_offset[3] = is_concat_channel ? in_channel_size : 0;
8282
add_copy_offset_node(
83-
graph, input_ref, range, src_offset, dst_offset, out);
83+
graph, input_ref, range, src_offset, dst_offset, out, true, false);
8484
dst_offset[dim_xyz_index] +=
8585
is_concat_channel ? in_channel_size : range[dim_xyz_index];
8686
}

backends/vulkan/runtime/graph/ops/impl/Copy.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ void add_copy_offset_node(
2525
const ivec3& range,
2626
const ivec4& src_offset,
2727
const ivec4& dst_offset,
28-
const ValueRef out) {
28+
const ValueRef out,
29+
bool calc_out_pos_using_src_chnl,
30+
bool calc_in_pos_using_dst_chnl) {
2931
vTensorPtr t_in = graph.get_tensor(in);
3032
vTensorPtr t_out = graph.get_tensor(out);
3133

@@ -49,7 +51,11 @@ void add_copy_offset_node(
4951
// Parameter buffers
5052
{},
5153
// Specialization Constants
52-
{graph.hashed_layout_of(out), graph.hashed_layout_of(in)},
54+
{graph.hashed_layout_of(out),
55+
graph.hashed_layout_of(in),
56+
(calc_out_pos_using_src_chnl ? 1
57+
: calc_in_pos_using_dst_chnl ? 2
58+
: 0)},
5359
nullptr,
5460
{},
5561
{
@@ -256,7 +262,8 @@ void add_copy_offset_node(
256262
ivec4 src_offset = {src[0], src[1], src[2], 0};
257263
ivec4 dst_offset = {dst[0], dst[1], dst[2], 0};
258264

259-
add_copy_offset_node(graph, in, range, src_offset, dst_offset, out);
265+
add_copy_offset_node(
266+
graph, in, range, src_offset, dst_offset, out, false, false);
260267
}
261268

262269
void copy_offset(ComputeGraph& graph, const std::vector<ValueRef>& args) {

backends/vulkan/runtime/graph/ops/impl/Copy.h

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,28 @@ namespace vkcompute {
2222
// It is possible to have input and output to point to the same image
2323
// object. But when the source range and destination range overlap, the behavior
2424
// is undefined.
25+
//
26+
// boolean flags calc_out_pos_using_src_chnl and calc_in_pos_using_dst_chnl
27+
// can be used to specify an indexing function in the shader
28+
// If calc_out_pos_using_src_chnl is set to true channel and batch index will be
29+
// calculated based on source channel size and will be used to determine
30+
// destination texel position.
31+
//
32+
// If calc_in_pos_using_dst_chnl is set to truechannel and batch index will be
33+
// calculated based on destination channel size and will be used to determine
34+
// source texel position.
35+
//
36+
// If both are true calc_out_pos_using_src_chnl is picked. If both are false no
37+
// index calculation happens.
2538
void add_copy_offset_node(
2639
ComputeGraph& graph,
2740
const ValueRef in,
2841
const utils::ivec3& range,
2942
const utils::ivec4& src_offset,
3043
const utils::ivec4& dst_offset,
31-
const ValueRef out);
44+
const ValueRef out,
45+
bool calc_out_pos_using_src_chnl,
46+
bool calc_in_pos_using_dst_chnl);
3247

3348
// add_copy_packed_dim_offset_node behaves similar to add_copy_node, except that
3449
// its used when copying packed dimension, if tensor is width or height packed.

backends/vulkan/runtime/graph/ops/impl/Repeat.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,8 @@ void add_repeat_node(
151151
utils::ivec4 src_offset{0, 0, 0, 0};
152152
utils::ivec4 dst_offset{0, 0, 0, 0};
153153

154-
add_copy_offset_node(graph, in, running_range, src_offset, dst_offset, out);
154+
add_copy_offset_node(
155+
graph, in, running_range, src_offset, dst_offset, out, false, false);
155156

156157
} else {
157158
add_repeat_channel_node(graph, in, channel_repeat, out, running_range);
@@ -166,7 +167,7 @@ void add_repeat_node(
166167
utils::ivec4 dst_offset{i * dim_at<kWidth4D>(in_sizes), 0, 0, 0};
167168

168169
add_copy_offset_node(
169-
graph, out, running_range, src_offset, dst_offset, out);
170+
graph, out, running_range, src_offset, dst_offset, out, true, false);
170171
}
171172

172173
running_range[0] = running_range[0] * width_repeat;
@@ -180,7 +181,7 @@ void add_repeat_node(
180181
utils::ivec4 dst_offset = {0, i * dim_at<kHeight4D>(in_sizes), 0, 0};
181182

182183
add_copy_offset_node(
183-
graph, out, running_range, src_offset, dst_offset, out);
184+
graph, out, running_range, src_offset, dst_offset, out, true, false);
184185
}
185186

186187
running_range[1] = running_range[1] * height_repeat;
@@ -194,7 +195,7 @@ void add_repeat_node(
194195
utils::ivec4 dst_offset = {0, 0, i * running_range[2], 0};
195196

196197
add_copy_offset_node(
197-
graph, out, running_range, src_offset, dst_offset, out);
198+
graph, out, running_range, src_offset, dst_offset, out, true, false);
198199
}
199200

200201
running_range[2] = running_range[2] * batch_repeat;

backends/vulkan/runtime/graph/ops/impl/Split.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ void add_split_with_sizes_default_node(
5151
// output tensor's size matches with the split_size.
5252
vTensorPtr t_out = graph.get_tensor(out_ref);
5353
utils::ivec3 range = t_out->logical_limits();
54-
add_copy_offset_node(graph, in, range, src_offset, dst_offset, out_ref);
54+
add_copy_offset_node(
55+
graph, in, range, src_offset, dst_offset, out_ref, false, true);
5556

5657
src_offset[0] += range[0];
5758
}
@@ -62,7 +63,8 @@ void add_split_with_sizes_default_node(
6263
for (ValueRef out_ref : *out_list) {
6364
vTensorPtr t_out = graph.get_tensor(out_ref);
6465
utils::ivec3 range = t_out->logical_limits();
65-
add_copy_offset_node(graph, in, range, src_offset, dst_offset, out_ref);
66+
add_copy_offset_node(
67+
graph, in, range, src_offset, dst_offset, out_ref, false, true);
6668

6769
src_offset[1] += range[1];
6870
}
@@ -73,7 +75,8 @@ void add_split_with_sizes_default_node(
7375
for (ValueRef out_ref : *out_list) {
7476
vTensorPtr t_out = graph.get_tensor(out_ref);
7577
utils::ivec3 range = t_out->logical_limits();
76-
add_copy_offset_node(graph, in, range, src_offset, dst_offset, out_ref);
78+
add_copy_offset_node(
79+
graph, in, range, src_offset, dst_offset, out_ref, false, true);
7780

7881
src_offset[2] += range[2];
7982
}

0 commit comments

Comments
 (0)