7 changes: 7 additions & 0 deletions docs/developer-guide/glsl-extension.md
@@ -219,6 +219,13 @@ afp lfp2afp(lfp v);
afpvec4 lfp2afpvec4(lfpvec4 v);
```

- local variable to local memory

```c
lfp afp2lfp(afp v);
lfpvec4 afp2lfpvec4(afpvec4 v);
```

Note: The common usage of local memory is to read from global memory first, store the value in local memory, and then read it back into local variables for subsequent use. Therefore, the conversion functions provided here cover storage type to local type and the conversions between local type and arithmetic type.
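For instance, a minimal sketch of this pattern (the buffer names, shared array, and indices below are illustrative, and the usual ncnn layout/buffer declarations and local size are omitted):

```c
// hypothetical compute shader fragment using the conversion macros above
shared lfp sum_tile[64];

void main()
{
    const uint lx = gl_LocalInvocationID.x;
    const int gx = int(gl_GlobalInvocationID.x);

    // load from global memory and compute in the arithmetic type
    afp v = buffer_ld1(bottom_blob_data, gx) * afp(2.f);

    // local variable -> local memory
    sum_tile[lx] = afp2lfp(v);

    barrier();

    // local memory -> local variable, read a neighbouring invocation's value
    afp n = lfp2afp(sum_tile[lx ^ 1]);

    buffer_st1(top_blob_data, gx, v + n);
}
```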

# misc functions
7 changes: 7 additions & 0 deletions docs/developer-guide/glsl-extension.zh.md
@@ -220,6 +220,13 @@ afp lfp2afp(lfp v);
afpvec4 lfp2afpvec4(lfpvec4 v);
```

- local variable to local memory

```c
lfp afp2lfp(afp v);
lfpvec4 afp2lfpvec4(afpvec4 v);
```

Note: The common usage of local memory is to read from global memory first, store the value in local memory, and then read it back into local variables for subsequent use. Therefore, the conversion functions provided here cover storage type to local type and the conversions between local type and arithmetic type.

# misc functions
19 changes: 18 additions & 1 deletion src/gpu.cpp
@@ -4903,8 +4903,10 @@ int compile_spirv_module(const char* comp_data, int comp_data_size, const Option
custom_defines.append("buffer_sm1(buf,i)", "buf[i]");
custom_defines.append("buffer_sm4(buf,i)", "buf[i]");

custom_defines.append("lfp2afp(v)", "v");
custom_defines.append("lfp2afp(v)", "float(v)");
custom_defines.append("afp2lfp(v)", "bfloat16_t(v)");
custom_defines.append("lfp2afpvec4(v)", "vec4(v)");
custom_defines.append("afp2lfpvec4(v)", "bf16vec4(v)");
}
else if (opt.use_bf16_packed)
{
@@ -4925,60 +4927,75 @@ int compile_spirv_module(const char* comp_data, int comp_data_size, const Option
if (support_fp16_uniform)
{
custom_defines.append("lfp2afp(v)", "uintBitsToFloat(uint(v)<<16)");
custom_defines.append("afp2lfp(v)", "uint16_t(floatBitsToUint(v)>>16)");
}
else
{
custom_defines.append("lfp2afp(v)", "v");
custom_defines.append("afp2lfp(v)", "v");
}
custom_defines.append("lfp2afpvec4(v)", "vec4(unpackBFloat2x16(v.x),unpackBFloat2x16(v.y))");
custom_defines.append("afp2lfpvec4(v)", "uvec2(packBFloat2x16(v.rg),packBFloat2x16(v.ba))");
}
else if (opt.use_fp16_storage && opt.use_fp16_uniform && opt.use_fp16_arithmetic)
{
custom_defines.append("buffer_sm1(buf,i)", "buf[i]");
custom_defines.append("buffer_sm4(buf,i)", "buf[i]");

custom_defines.append("lfp2afp(v)", "v");
custom_defines.append("afp2lfp(v)", "v");
custom_defines.append("lfp2afpvec4(v)", "v");
custom_defines.append("afp2lfpvec4(v)", "v");
}
else if (opt.use_fp16_storage && opt.use_fp16_arithmetic)
{
custom_defines.append("buffer_sm1(buf,i)", "float(buf[i])");
custom_defines.append("buffer_sm4(buf,i)", "pack64(halfBitsToUint16(buf[i]))");

custom_defines.append("lfp2afp(v)", "float16_t(v)");
custom_defines.append("afp2lfp(v)", "float(v)");
custom_defines.append("lfp2afpvec4(v)", "uint16BitsToHalf(unpack16(v))");
custom_defines.append("afp2lfpvec4(v)", "pack64(halfBitsToUint16(v))");
}
else if (opt.use_fp16_packed && opt.use_fp16_arithmetic)
{
custom_defines.append("buffer_sm1(buf,i)", "unpackHalf2x16(buf[(i)/2])[(i)%2]");
custom_defines.append("buffer_sm4(buf,i)", "buf[i]");

custom_defines.append("lfp2afp(v)", "float16_t(v)");
custom_defines.append("afp2lfp(v)", "float(v)");
custom_defines.append("lfp2afpvec4(v)", "f16vec4(unpackFloat2x16(v.x),unpackFloat2x16(v.y))");
custom_defines.append("afp2lfpvec4(v)", "uvec2(packFloat2x16(v.rg),packFloat2x16(v.ba))");
}
else if (opt.use_fp16_storage)
{
custom_defines.append("buffer_sm1(buf,i)", "float(buf[i])");
custom_defines.append("buffer_sm4(buf,i)", "uvec2(packHalf2x16(vec4(buf[i]).rg),packHalf2x16(vec4(buf[i]).ba))");

custom_defines.append("lfp2afp(v)", "v");
custom_defines.append("afp2lfp(v)", "float(v)");
custom_defines.append("lfp2afpvec4(v)", "vec4(unpackHalf2x16(v.x),unpackHalf2x16(v.y))");
custom_defines.append("afp2lfpvec4(v)", "uvec2(packHalf2x16(v.rg),packHalf2x16(v.ba))");
}
else if (opt.use_fp16_packed)
{
custom_defines.append("buffer_sm1(buf,i)", "unpackHalf2x16(buf[(i)/2])[(i)%2]");
custom_defines.append("buffer_sm4(buf,i)", "buf[i]");

custom_defines.append("lfp2afp(v)", "v");
custom_defines.append("afp2lfp(v)", "v");
custom_defines.append("lfp2afpvec4(v)", "vec4(unpackHalf2x16(v.x),unpackHalf2x16(v.y))");
custom_defines.append("afp2lfpvec4(v)", "uvec2(packHalf2x16(v.rg),packHalf2x16(v.ba))");
}
else
{
custom_defines.append("buffer_sm1(buf,i)", "buf[i]");
custom_defines.append("buffer_sm4(buf,i)", "buf[i]");

custom_defines.append("lfp2afp(v)", "v");
custom_defines.append("afp2lfp(v)", "v");
custom_defines.append("lfp2afpvec4(v)", "v");
custom_defines.append("afp2lfpvec4(v)", "v");
}

if (opt.use_bf16_storage)
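As a hedged illustration of what the definitions above mean for shader code: on the use_fp16_packed && use_fp16_arithmetic path, lfpvec4 is effectively a uvec2 holding four packed half floats, so a shared tile round-trips through the pack/unpack intrinsics shown in the defines (the fragment below is hypothetical and omits the layout declarations):

```c
// assumes the fp16 packed + fp16 arithmetic configuration
shared lfpvec4 tile[64];

void main()
{
    const uint lx = gl_LocalInvocationID.x;

    afpvec4 v = afpvec4(1.f, 2.f, 3.f, 4.f);

    // expands to uvec2(packFloat2x16(v.rg), packFloat2x16(v.ba))
    tile[lx] = afp2lfpvec4(v);

    barrier();

    // expands to f16vec4(unpackFloat2x16(v.x), unpackFloat2x16(v.y))
    afpvec4 w = lfp2afpvec4(tile[(lx + 1) % 64]);
}
```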
150 changes: 127 additions & 23 deletions src/layer/vulkan/gemm_vulkan.cpp
@@ -15,6 +15,8 @@ Gemm_vulkan::Gemm_vulkan()

pipeline_gemm = 0;

use_subgroup_ops = false;

use_cooperative_matrix = false;
coopmat_M = 0;
coopmat_N = 0;
@@ -81,6 +83,14 @@ int Gemm_vulkan::create_pipeline(const Option& opt)
use_bf16_cooperative_matrix = true;
}

const int subgroup_size = vkdev->info.subgroup_size();
use_subgroup_ops = opt.use_subgroup_ops && (vkdev->info.support_subgroup_ops() & (VK_SUBGROUP_FEATURE_BASIC_BIT | VK_SUBGROUP_FEATURE_SHUFFLE_BIT));
if (subgroup_size < 4 || subgroup_size > 128)
{
// sanitize weird subgroup_size
use_subgroup_ops = false;
}

if (use_cooperative_matrix)
{
int M = constantM ? constantM : 1024;
@@ -137,7 +147,7 @@ int Gemm_vulkan::create_pipeline(const Option& opt)
pipeline_gemm->set_local_size_xyz(coopmat_subgroup_size * UNROLL_WG_M * UNROLL_WG_N, 1, 1);
pipeline_gemm->create(LayerShaderType::gemm_cm, opt, specializations);
}
else
else if (opt.use_shader_local_memory)
{
std::vector<vk_specialization_type> specializations(15);
specializations[0].f = alpha;
@@ -156,25 +166,98 @@ int Gemm_vulkan::create_pipeline(const Option& opt)
specializations[13].i = output_elemtype;
specializations[14].i = output_transpose;

Mat local_size_xyz;
// if (shape_packed.dims == 2)
// {
// local_size_xyz.w = std::min(8, shape_packed.w);
// local_size_xyz.h = std::min(8, shape_packed.h);
// local_size_xyz.c = 1;
// }

// pack1
// if (shape.dims == 0 || elempack == 1)
pipeline_gemm = new Pipeline(vkdev);
pipeline_gemm->set_local_size_xyz(8, 8, 1);
pipeline_gemm->create(LayerShaderType::gemm, opt, specializations);
}
else if (use_subgroup_ops)
{
if (subgroup_size == 128)
{
pipeline_gemm = new Pipeline(vkdev);
pipeline_gemm->set_optimal_local_size_xyz(local_size_xyz);
if (opt.use_shader_local_memory)
{
pipeline_gemm->set_local_size_xyz(8, 8, 1);
}
pipeline_gemm->create(LayerShaderType::gemm, opt, specializations);
UNROLL_SG_M = 16;
UNROLL_SG_N = 8;
UNROLL_SG_K = 8;
}
if (subgroup_size == 64)
{
UNROLL_SG_M = 8;
UNROLL_SG_N = 8;
UNROLL_SG_K = 8;
}
if (subgroup_size == 32)
{
UNROLL_SG_M = 8;
UNROLL_SG_N = 4;
UNROLL_SG_K = 4;
}
if (subgroup_size == 16)
{
UNROLL_SG_M = 4;
UNROLL_SG_N = 4;
UNROLL_SG_K = 4;
}
if (subgroup_size == 8)
{
UNROLL_SG_M = 4;
UNROLL_SG_N = 2;
UNROLL_SG_K = 2;
}
if (subgroup_size == 4)
{
UNROLL_SG_M = 2;
UNROLL_SG_N = 2;
UNROLL_SG_K = 2;
}

std::vector<vk_specialization_type> specializations(18);
specializations[0].f = alpha;
specializations[1].f = beta;
specializations[2].i = transA;
specializations[3].i = transB;
specializations[4].i = constantA;
specializations[5].i = constantB;
specializations[6].i = constantC;
specializations[7].u32 = constantM;
specializations[8].u32 = constantN;
specializations[9].u32 = constantK;
specializations[10].i = constant_broadcast_type_C;
specializations[11].i = output_N1M;
specializations[12].i = output_elempack;
specializations[13].i = output_elemtype;
specializations[14].i = output_transpose;
specializations[15].u32 = UNROLL_SG_M;
specializations[16].u32 = UNROLL_SG_N;
specializations[17].u32 = UNROLL_SG_K;

pipeline_gemm = new Pipeline(vkdev);
pipeline_gemm->set_subgroup_size(subgroup_size);
pipeline_gemm->set_local_size_xyz(subgroup_size, 1, 1);
pipeline_gemm->create(LayerShaderType::gemm_sg, opt, specializations);
}
else
{
std::vector<vk_specialization_type> specializations(15);
specializations[0].f = alpha;
specializations[1].f = beta;
specializations[2].i = transA;
specializations[3].i = transB;
specializations[4].i = constantA;
specializations[5].i = constantB;
specializations[6].i = constantC;
specializations[7].i = constantM;
specializations[8].i = constantN;
specializations[9].i = constantK;
specializations[10].i = constant_broadcast_type_C;
specializations[11].i = output_N1M;
specializations[12].i = output_elempack;
specializations[13].i = output_elemtype;
specializations[14].i = output_transpose;

Mat local_size_xyz;

pipeline_gemm = new Pipeline(vkdev);
pipeline_gemm->set_optimal_local_size_xyz(local_size_xyz);
pipeline_gemm->create(LayerShaderType::gemm, opt, specializations);
}

if (opt.lightmode)
Expand All @@ -192,6 +275,8 @@ int Gemm_vulkan::destroy_pipeline(const Option& /*opt*/)
delete pipeline_gemm;
pipeline_gemm = 0;

use_subgroup_ops = false;

use_cooperative_matrix = false;
coopmat_M = 0;
coopmat_N = 0;
@@ -362,15 +447,34 @@ int Gemm_vulkan::forward(const std::vector<VkMat>& bottom_blobs, std::vector<VkM

cmd.record_pipeline(pipeline_gemm, bindings, constants, dispatcher);
}
else
else if (opt.use_shader_local_memory)
{
VkMat dispatcher;
dispatcher.w = (N + 3) / 4;
dispatcher.h = (M + 3) / 4;
dispatcher.c = 1;
cmd.record_pipeline(pipeline_gemm, bindings, constants, dispatcher);
}
else if (use_subgroup_ops)
{
const Pipeline* pipeline = pipeline_gemm;
const int subgroup_size = vkdev->info.subgroup_size();

const int blocks_x = (M + (UNROLL_SG_M * 4 - 1)) / (UNROLL_SG_M * 4);
const int blocks_y = (N + (UNROLL_SG_N * 4 - 1)) / (UNROLL_SG_N * 4);

VkMat dispatcher;
dispatcher.w = (N + 1) / 2;
dispatcher.h = (M + 1) / 2;
dispatcher.w = (blocks_x * blocks_y) * subgroup_size;
dispatcher.h = 1;
dispatcher.c = 1;
cmd.record_pipeline(pipeline, bindings, constants, dispatcher);
cmd.record_pipeline(pipeline_gemm, bindings, constants, dispatcher);
}
else
{
VkMat dispatcher;
dispatcher.w = (N + 3) / 4;
dispatcher.h = (M + 3) / 4;
dispatcher.c = 1;
cmd.record_pipeline(pipeline_gemm, bindings, constants, dispatcher);
}

int out_elempack = 1;
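To sanity-check the new subgroup dispatch math with concrete numbers (the sizes below are picked for illustration, not taken from the PR): with M = N = 512 on a device whose subgroup size is 64, create_pipeline selects UNROLL_SG_M = UNROLL_SG_N = 8, and forward() then launches one subgroup (one workgroup of subgroup_size invocations) per 32x32 output block:

```cpp
#include <cstdio>

int main()
{
    // illustrative values: M = N = 512, subgroup_size = 64 -> UNROLL_SG_M = UNROLL_SG_N = 8
    const int M = 512, N = 512, subgroup_size = 64;
    const int UNROLL_SG_M = 8, UNROLL_SG_N = 8;

    // each block covers a (UNROLL_SG_M * 4) x (UNROLL_SG_N * 4) = 32 x 32 tile of the output
    const int blocks_x = (M + (UNROLL_SG_M * 4 - 1)) / (UNROLL_SG_M * 4); // 16
    const int blocks_y = (N + (UNROLL_SG_N * 4 - 1)) / (UNROLL_SG_N * 4); // 16

    // one subgroup per block, flattened into a 1D dispatch along x
    printf("dispatcher.w = %d\n", blocks_x * blocks_y * subgroup_size); // 16384
    printf("dispatcher.h = 1, dispatcher.c = 1\n");
    return 0;
}
```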
3 changes: 3 additions & 0 deletions src/layer/vulkan/gemm_vulkan.h
@@ -35,6 +35,9 @@ class Gemm_vulkan : public Gemm

Pipeline* pipeline_gemm;

// subgroup
bool use_subgroup_ops;

// cooperative matrix
bool use_cooperative_matrix;
int coopmat_M;
8 changes: 4 additions & 4 deletions src/layer/vulkan/sdpa_vulkan.cpp
@@ -340,8 +340,8 @@ int SDPA_vulkan::forward(const std::vector<VkMat>& bottom_blobs, std::vector<VkM
else
{
VkMat dispatcher;
dispatcher.w = (N + 1) / 2;
dispatcher.h = (M + 1) / 2;
dispatcher.w = (N + 3) / 4;
dispatcher.h = (M + 3) / 4;
dispatcher.c = B;

cmd.record_pipeline(pipeline_sdpa_qk_cross, bindings, constants, dispatcher);
@@ -415,8 +415,8 @@ int SDPA_vulkan::forward(const std::vector<VkMat>& bottom_blobs, std::vector<VkM
else
{
VkMat dispatcher;
dispatcher.w = (N + 1) / 2;
dispatcher.h = (M + 1) / 2;
dispatcher.w = (N + 3) / 4;
dispatcher.h = (M + 3) / 4;
dispatcher.c = B;

cmd.record_pipeline(pipeline_sdpa_qkv_cross, bindings, constants, dispatcher);