diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl index 228e2e8f870..dfb5f1f2f9c 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl @@ -29,16 +29,20 @@ ${layout_declare_tensor(2, "r", "t_qmat2", "int8", STORAGE)} ${layout_declare_tensor(3, "r", "t_scales", DTYPE, STORAGE)} $if STORAGE == "buffer": - ${layout_declare_ubo(4, "ivec4", "out_sizes")} - ${layout_declare_ubo(5, "ivec4", "out_strides")} - ${layout_declare_ubo(6, "int", "out_numel")} - ${layout_declare_ubo(7, "ivec4", "mat1_sizes")} - ${layout_declare_ubo(8, "ivec4", "mat1_strides")} - ${layout_declare_ubo(9, "ivec4", "qmat2_strides")} - ${layout_declare_ubo(10, "ivec4", "scales_strides")} + layout(push_constant) uniform restrict Block { + ivec4 out_sizes; + ivec4 out_strides; + ivec4 mat1_sizes; + ivec4 mat1_strides; + ivec4 qmat2_strides; + ivec4 scales_strides; + int out_numel; + }; $else: - ${layout_declare_ubo(4, "ivec3", "out_limits")} - ${layout_declare_ubo(5, "ivec4", "mat1_sizes")} + layout(push_constant) uniform restrict Block { + ivec3 out_limits; + ivec4 mat1_sizes; + }; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; @@ -83,42 +87,40 @@ void main() { #else // USING_TEXTURE -#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require - void main() { - const u16vec2 out_pos = u16vec2( - gl_GlobalInvocationID.x, - gl_GlobalInvocationID.y); + const ivec2 out_pos = ivec2( + gl_GlobalInvocationID.x % out_limits.x, + gl_GlobalInvocationID.x / out_limits.x); - if (out_pos.x >= out_limits.x || out_pos.y >= out_limits.y) { + if (out_pos.y >= out_limits.y) { return; } - const uint16_t qmat2_pos_x = out_pos.x; + const int qmat2_pos_x = out_pos.x; VEC4_T outtex = VEC4_T(0); - const VEC4_T scales = load_texel(t_scales, u16vec3(out_pos.x, 0, 0)); + const VEC4_T scales = load_texel(t_scales, ivec3(out_pos.x, 0, 0)); VEC4_T 
mat1_tex; VEC4_T mat2_tex[4]; for ( - uint16_t i = uint16_t(0), x = uint16_t(0); - i < uint16_t(mat1_sizes.x); - i += uint16_t(4), x++) + int i = 0, x = 0; + i < mat1_sizes.x; + i += 4, x++) { - mat1_tex = load_texel(t_mat1, u16vec3(x, out_pos.y, 0)); + mat1_tex = load_texel(t_mat1, ivec3(x, out_pos.y, 0)); - mat2_tex[0] = load_texel(t_qmat2, u16vec3(out_pos.x, i, 0)); - mat2_tex[1] = load_texel(t_qmat2, u16vec3(out_pos.x, i + uint16_t(1), 0)); - mat2_tex[2] = load_texel(t_qmat2, u16vec3(out_pos.x, i + uint16_t(2), 0)); - mat2_tex[3] = load_texel(t_qmat2, u16vec3(out_pos.x, i + uint16_t(3), 0)); + mat2_tex[0] = load_texel(t_qmat2, ivec3(out_pos.x, i, 0)); + mat2_tex[1] = load_texel(t_qmat2, ivec3(out_pos.x, i + 1, 0)); + mat2_tex[2] = load_texel(t_qmat2, ivec3(out_pos.x, i + 2, 0)); + mat2_tex[3] = load_texel(t_qmat2, ivec3(out_pos.x, i + 3, 0)); outtex += mat1_tex.x * mat2_tex[0] + mat1_tex.y * mat2_tex[1] + mat1_tex.z * mat2_tex[2] + mat1_tex.w * mat2_tex[3]; } outtex *= scales; - write_texel(t_out, u16vec3(out_pos, 0), outtex); + write_texel(t_out, ivec3(out_pos, 0), outtex); } #endif diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp index 49085ff4e06..5054b2e5e9c 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp @@ -98,47 +98,25 @@ void add_q_8w_linear_node( add_dtype_suffix(kernel_name, graph.dtype_of(out_W_packed)); add_storage_type_suffix(kernel_name, graph.storage_type_of(out_W_packed)); - vkapi::ParamsBindList ubos({}); + std::vector<PushConstantDataInfo> pcs; if (graph.is_buffer_storage(out_W_packed)) { - ubos.append( - {graph.sizes_ubo(out_W_packed), - graph.strides_ubo(out_W_packed), - graph.numel_ubo(out_W_packed), - graph.sizes_ubo(mat1_W_packed), - graph.strides_ubo(mat1), - graph.strides_ubo(q_mat2), - graph.strides_ubo(scales)}); + pcs = { + graph.sizes_pc_of(out_W_packed), + 
graph.strides_pc_of(out_W_packed), + graph.sizes_pc_of(mat1_W_packed), + graph.strides_pc_of(mat1), + graph.strides_pc_of(q_mat2), + graph.strides_pc_of(scales), + graph.numel_pc_of(out_W_packed)}; } else { - ubos.append( - {graph.logical_limits_ubo(out_W_packed), - graph.sizes_ubo(mat1_W_packed)}); + pcs = { + graph.logical_limits_pc_of(out_W_packed), + graph.sizes_pc_of(mat1_W_packed)}; } - utils::uvec3 global_wg; - if (graph.is_buffer_storage(out)) { - global_wg = {static_cast<uint32_t>(graph.numel_of(out_W_packed)), 1, 1}; - } else { - global_wg = graph.logical_limits_of(out_W_packed); - } - - utils::uvec3 local_wg{8, 8, 1}; - int32_t out_W = graph.size_at<int32_t>(-1, out_W_packed); - - if (graph.is_buffer_storage(out_W_packed)) { - local_wg[0] = 64; - local_wg[1] = 1; - local_wg[2] = 1; - } else { - if (out_W % 8 != 0) { - if (out_W % 4 == 0) { - local_wg[0] = 4; - local_wg[1] = 16; - } else { - local_wg[0] = 2; - local_wg[1] = 32; - } - } - } + const utils::uvec3 global_wg = { + static_cast<uint32_t>(graph.numel_of(out_W_packed)), 1, 1}; + const utils::uvec3 local_wg{64, 1, 1}; graph.execute_nodes().emplace_back(new DispatchNode( graph, @@ -149,11 +127,13 @@ void add_q_8w_linear_node( {{out_W_packed, vkapi::MemoryAccessType::WRITE}, {{mat1_W_packed, q_mat2, scales}, vkapi::MemoryAccessType::READ}}, // Shader params buffers - ubos, + {}, // Specialization Constants {}, // Resizing Logic - resize_q_8w_linear_node, + resize_q_8w_linear_node, + {}, + pcs)); if (!graph.is_buffer_storage(out) && graph.packed_dim_of(out) != WHCN::kWidthDim) { viewFn(graph, {out_W_packed, graph.add_none(), out});