From b687df038ee8688cfaa1de665afa267cdf980ff5 Mon Sep 17 00:00:00 2001 From: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com> Date: Fri, 11 Apr 2025 12:51:39 -0700 Subject: [PATCH] [ET-VK] Minor improvement to permute op. This change simplifies the bounds check in the permute op's texel-gathering loop: instead of comparing all four input coordinates against the input sizes on every iteration, it compares only the index along the packed dimension against that dimension's size, which improves speed. Differential Revision: [D72866962](https://our.internmc.facebook.com/intern/diff/D72866962/) [ghstack-poisoned] --- .../vulkan/runtime/graph/ops/glsl/permute.glsl | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute.glsl b/backends/vulkan/runtime/graph/ops/glsl/permute.glsl index 8a8703becd9..716c42e8ede 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/permute.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/permute.glsl @@ -31,6 +31,8 @@ layout(push_constant) uniform PRECISION restrict Block { layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; layout(constant_id = 3) const int packed_dim = C_DIM; +#extension GL_EXT_control_flow_attributes : require + void main() { ivec3 pos = ivec3(gl_GlobalInvocationID); @@ -54,11 +56,16 @@ void main() { in_bchw_pos[out_ndims[2]] = pos.y; in_bchw_pos[out_ndims[3]] = pos.x; - for (int j = 0; j < 4; ++j) { + const int in_packed_dim_size = in_sizes[3 - out_ndims[in_packed_dim_bchw_index]]; + + [[unroll]] for (int j = 0, bchw_index = in_bchw_pos[out_ndims[in_packed_dim_bchw_index]]; j < 4; ++j, ++bchw_index) { // terminate the loop if trying to access input texture out of bounds - if (any(greaterThanEqual(in_bchw_pos.wzyx, in_sizes.xyzw))) { + if (bchw_index >= in_packed_dim_size) { break; } + // go to position in the input, that is mapped to the packed dim in the output + in_bchw_pos[out_ndims[in_packed_dim_bchw_index]] = bchw_index; + ivec3 fetch_pos; fetch_pos.xy = in_bchw_pos.wz; @@ -74,9 +81,6 @@ void main() { // fetch input texel VEC4_T inval = VEC4_T(load_texel(t_in, fetch_pos)); outval[j] = 
inval[in_packed_dim_lane_index]; - - // go to next position in the input, that is mapped to the packed dim in the output - in_bchw_pos[out_ndims[in_packed_dim_bchw_index]]++; } pos[packed_dim] = int(gl_GlobalInvocationID[packed_dim]);