@@ -105,6 +105,7 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
 }
 
 shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
+shared vec4 tmpshv4[gl_WorkGroupSize.x];
 
 shared float16_t masksh[Bc][Br];
 shared vec4 Qf[Br][D / 4];
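Note: the new tmpshv4 array sits next to the existing scalar tmpsh scratch buffer and gives every invocation a vec4 slot in shared memory. It is what lets the last hunk below stage and reduce the Of[r][d] accumulators as whole vectors instead of one component at a time.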
@@ -168,13 +169,15 @@ void main() {
 
     uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
 
-    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-        if (i * Br + r < N) {
-            [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
-                Qf[r][d * D_split + d_tid] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d * D_split + d_tid]) * p.scale;
-            }
+    [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
+        uint32_t d = (idx + tid) % (D / 4);
+        uint32_t r = (idx + tid) / (D / 4);
+        if (r < Br && d < D / 4 &&
+            i * Br + r < N) {
+            Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
         }
     }
+    barrier();
 
     vec4 Of[Br][D_per_thread / 4];
     [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
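Note: the Q load above changes from "each invocation loads its own (r, d_tid) slices" to a cooperative load in which the whole workgroup strides over the flattened Br x (D / 4) tile, recovering (r, d) from the flat index by division and modulo; the added barrier() keeps any invocation from reading Qf before every element has been written, and the r < Br test guards the tail iteration when the tile size is not a multiple of the workgroup size. The following is a minimal, hypothetical standalone sketch of that pattern, not code from this shader; the ROWS/COLS sizes, buffer bindings, and 64-wide workgroup are illustrative assumptions.

```glsl
#version 450
#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x = 64) in;

// Hypothetical tile size; the real shader uses Br x (D / 4) vec4 elements.
const uint ROWS = 16;
const uint COLS = 32;

layout(std430, binding = 0) readonly  buffer Src { vec4 src[]; };
layout(std430, binding = 1) writeonly buffer Dst { vec4 dst[]; };

shared vec4 tile[ROWS][COLS];

void main() {
    const uint tid = gl_LocalInvocationIndex;

    // Cooperative load: the workgroup strides over the flattened tile,
    // so consecutive invocations touch consecutive elements.
    [[unroll]] for (uint idx = 0; idx < ROWS * COLS; idx += gl_WorkGroupSize.x) {
        uint c = (idx + tid) % COLS;
        uint r = (idx + tid) / COLS;
        if (r < ROWS) {  // tail guard when ROWS * COLS is not a multiple of the workgroup size
            tile[r][c] = src[r * COLS + c];
        }
    }
    barrier();  // every element must be written before any invocation reads the tile

    // Trivial consumer so the example is complete: copy the tile back out.
    [[unroll]] for (uint idx = 0; idx < ROWS * COLS; idx += gl_WorkGroupSize.x) {
        uint c = (idx + tid) % COLS;
        uint r = (idx + tid) / COLS;
        if (r < ROWS) {
            dst[r * COLS + c] = tile[r][c];
        }
    }
}
```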
@@ -350,20 +353,18 @@ void main() {
         [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
 
             Of[r][d] = eMf * Of[r][d];
-            [[unroll]] for (uint32_t c = 0; c < 4; ++c) {
-                tmpsh[tid] = Of[r][d][c];
+            tmpshv4[tid] = Of[r][d];
 
-                barrier();
-                [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
-                    if (tid < s) {
-                        Of[r][d][c] += tmpsh[tid + s];
-                        tmpsh[tid] = Of[r][d][c];
-                    }
-                    barrier();
+            barrier();
+            [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
+                if (tid < s) {
+                    Of[r][d] += tmpshv4[tid + s];
+                    tmpshv4[tid] = Of[r][d];
                 }
-                Of[r][d][c] = tmpsh[d_tid];
                 barrier();
             }
+            Of[r][d] = tmpshv4[d_tid];
+            barrier();
         }
     }
 
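Note: with tmpshv4 in place, the reduction above drops the inner loop over the four components of Of[r][d]. Each invocation publishes its whole vec4 partial, the tree reduction halves the active lane count until D_split lanes remain while adding four components per step, and each invocation then reads back the combined vector for its d_tid slot, so this reduction needs roughly a quarter as many barrier rounds as before. Below is a self-contained, hypothetical sketch of the same shared-memory vec4 tree reduction; the D_SPLIT value, buffer layout, and 64-wide workgroup are assumptions for illustration only.

```glsl
#version 450
#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x = 64) in;

// Hypothetical: number of lanes that keep independent results after the
// reduction (the shader's D_split plays this role).
const uint D_SPLIT = 4;

layout(std430, binding = 0) readonly  buffer Pin  { vec4 partials[]; };
layout(std430, binding = 1) writeonly buffer Pout { vec4 results[]; };

shared vec4 scratch[gl_WorkGroupSize.x];

void main() {
    const uint tid = gl_LocalInvocationIndex;

    vec4 acc = partials[gl_GlobalInvocationID.x];
    scratch[tid] = acc;

    barrier();
    // Tree reduction over whole vec4 values: halve the active lane count
    // each step and stop once D_SPLIT lanes remain.
    [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= int(D_SPLIT); s >>= 1) {
        if (tid < s) {
            acc += scratch[tid + s];
            scratch[tid] = acc;
        }
        barrier();
    }

    // Lane i (for i < D_SPLIT) now holds the sum of the partials from every
    // lane with tid % D_SPLIT == i.
    if (tid < D_SPLIT) {
        results[gl_WorkGroupID.x * D_SPLIT + tid] = scratch[tid];
    }
}
```

After the loop, lane i (for i < D_SPLIT) holds the sum over all lanes congruent to i modulo D_SPLIT, which is presumably the grouping the shader relies on when it reads back tmpshv4[d_tid].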