Skip to content

Commit 989bfb1

Browse files
committed
vulkan: load each Q value once. optimize O reduction. more tuning
1 parent c747227 commit 989bfb1

File tree

2 files changed

+19
-18
lines changed

2 files changed

+19
-18
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1590,7 +1590,7 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
15901590

15911591
// number of rows/cols for flash attention shader
15921592
static constexpr uint32_t flash_attention_num_small_rows = 32;
1593-
static constexpr uint32_t scalar_flash_attention_num_small_rows = 4;
1593+
static constexpr uint32_t scalar_flash_attention_num_small_rows = 8;
15941594

15951595
static uint32_t get_fa_num_small_rows(bool scalar) {
15961596
return scalar ? scalar_flash_attention_num_small_rows : flash_attention_num_small_rows;
@@ -1601,7 +1601,7 @@ static std::array<uint32_t, 2> fa_rows_cols(bool scalar, uint32_t D, uint32_t cl
16011601

16021602
// small rows, large cols
16031603
if (small_rows || scalar) {
1604-
return {get_fa_num_small_rows(scalar), 64};
1604+
return {get_fa_num_small_rows(scalar), 32};
16051605
}
16061606

16071607
// small cols to reduce register count
@@ -1913,7 +1913,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
19131913
// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
19141914
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
19151915
const uint32_t D_lsb = D ^ (D & (D-1));
1916-
uint32_t D_split = std::min(std::min(device->subgroup_size, 16u), D_lsb / 4);
1916+
uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
19171917

19181918
// mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
19191919
GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
105105
}
106106

107107
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
108+
shared vec4 tmpshv4[gl_WorkGroupSize.x];
108109

109110
shared float16_t masksh[Bc][Br];
110111
shared vec4 Qf[Br][D / 4];
@@ -168,13 +169,15 @@ void main() {
168169

169170
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
170171

171-
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
172-
if (i * Br + r < N) {
173-
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
174-
Qf[r][d * D_split + d_tid] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d * D_split + d_tid]) * p.scale;
175-
}
172+
[[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
173+
uint32_t d = (idx + tid) % (D / 4);
174+
uint32_t r = (idx + tid) / (D / 4);
175+
if (r < Br && d < D / 4 &&
176+
i * Br + r < N) {
177+
Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
176178
}
177179
}
180+
barrier();
178181

179182
vec4 Of[Br][D_per_thread / 4];
180183
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
@@ -350,20 +353,18 @@ void main() {
350353
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
351354

352355
Of[r][d] = eMf * Of[r][d];
353-
[[unroll]] for (uint32_t c = 0; c < 4; ++c) {
354-
tmpsh[tid] = Of[r][d][c];
356+
tmpshv4[tid] = Of[r][d];
355357

356-
barrier();
357-
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
358-
if (tid < s) {
359-
Of[r][d][c] += tmpsh[tid + s];
360-
tmpsh[tid] = Of[r][d][c];
361-
}
362-
barrier();
358+
barrier();
359+
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
360+
if (tid < s) {
361+
Of[r][d] += tmpshv4[tid + s];
362+
tmpshv4[tid] = Of[r][d];
363363
}
364-
Of[r][d][c] = tmpsh[d_tid];
365364
barrier();
366365
}
366+
Of[r][d] = tmpshv4[d_tid];
367+
barrier();
367368
}
368369
}
369370

0 commit comments

Comments
 (0)