diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl index af7c22bb5ad..2393ed33450 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl @@ -34,6 +34,9 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +// Shared memory holds the calculated positions; this should reduce register usage and thereby improve performance. +shared u16vec2 pos_shared[gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z * TILE_SIZE * TILE_SIZE]; + /* * Computes a 2D pointwise convolution of an NxN output tile. Calculating an * output tile for pointwise convolution is more efficient because the kernel @@ -41,6 +44,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; */ void main() { const uint16_t out_limits_y_scaled = uint16_t((out_limits.y + TILE_SIZE - 1) / TILE_SIZE); + const uint shared_mem_stride = gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z; const u16vec3 gpos = u16vec3( gl_GlobalInvocationID.x / (out_limits_y_scaled * out_limits.z), @@ -58,6 +62,7 @@ void main() { for (int x = 0; x < TILE_SIZE; ++x) { pos[i] = u16vec2( gpos.x * TILE_SIZE + x, gpos.y * TILE_SIZE + y); + pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex] = pos[i]; i++; } } @@ -73,7 +78,7 @@ void main() { // the top-left element is in a region added by padding. 
u16vec2 ipos[TILE_SIZE * TILE_SIZE]; for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) { - ipos[i] = pos[i].xy * u16vec2(stride) - u16vec2(padding); + ipos[i] = pos[i] * u16vec2(stride) - u16vec2(padding); } vec4 sum[TILE_SIZE * TILE_SIZE]; @@ -138,8 +143,9 @@ void main() { } for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) { - if (all(lessThan(u16vec3(pos[i], gpos.z), out_limits))) { - imageStore(t_out, u16vec3(pos[i], gpos.z), op(sum[i], out_min, out_max)); + const u16vec2 pos = pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex]; + if (all(lessThan(u16vec3(pos, gpos.z), out_limits))) { + imageStore(t_out, u16vec3(pos, gpos.z), op(sum[i], out_min, out_max)); } } }