From 2a1a75aca5c9a008cdfecdf1a85bb01c7039e58a Mon Sep 17 00:00:00 2001 From: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com> Date: Thu, 2 Jan 2025 12:16:25 -0800 Subject: [PATCH] [ET-VK] Using shared variable to store calculated output pose to free up registers and improve performance. This diff introduces a shared variable to store calculated output pose in conv2d_pw op to free up registers and improve performance. The code changes include adding a shared variable to hold calculated positions and modifying the existing code to use the shared variable. Differential Revision: [D67742567](https://our.internmc.facebook.com/intern/diff/D67742567/) [ghstack-poisoned] --- .../vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl index af7c22bb5ad..2393ed33450 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl @@ -34,6 +34,9 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +// shared memory to hold calculated positions, this would reduce register usage thus improving performance. +shared u16vec2 pos_shared[gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z * TILE_SIZE * TILE_SIZE]; + /* * Computes a 2D pointwise convolution of an NxN output tile. Calculating an * output tile for pointwise convolution is more efficient because the kernel @@ -41,6 +44,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; */ void main() { const uint16_t out_limits_y_scaled = uint16_t((out_limits.y + TILE_SIZE - 1) / TILE_SIZE); + const uint shared_mem_stride = gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z; const u16vec3 gpos = u16vec3( gl_GlobalInvocationID.x / (out_limits_y_scaled * out_limits.z), @@ -58,6 +62,7 @@ void main() { for (int x = 0; x < TILE_SIZE; ++x) { pos[i] = u16vec2( gpos.x * TILE_SIZE + x, gpos.y * TILE_SIZE + y); + pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex] = pos[i]; i++; } } @@ -73,7 +78,7 @@ void main() { // the top-left element is in a region added by padding. u16vec2 ipos[TILE_SIZE * TILE_SIZE]; for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) { - ipos[i] = pos[i].xy * u16vec2(stride) - u16vec2(padding); + ipos[i] = pos[i] * u16vec2(stride) - u16vec2(padding); } vec4 sum[TILE_SIZE * TILE_SIZE]; @@ -138,8 +143,9 @@ void main() { } for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) { - if (all(lessThan(u16vec3(pos[i], gpos.z), out_limits))) { - imageStore(t_out, u16vec3(pos[i], gpos.z), op(sum[i], out_min, out_max)); + const u16vec2 pos = pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex]; + if (all(lessThan(u16vec3(pos, gpos.z), out_limits))) { + imageStore(t_out, u16vec3(pos, gpos.z), op(sum[i], out_min, out_max)); } } }