Skip to content

Commit bb22d4a

Browse files
Add 2D-tiling matmul
1 parent fb56cf0 commit bb22d4a

File tree

1 file changed

+133
-2
lines changed

1 file changed

+133
-2
lines changed

examples/matmul/run.cpp

Lines changed: 133 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,120 @@ fn main(
233233
}
234234
)";
235235

236+
/* 2D block-tiling
237+
*
238+
*/
239+
static const char *kShaderMatmul4 = R"(
240+
@group(0) @binding(0) var<storage, read_write> a: array<{{precision}}>;
241+
@group(0) @binding(1) var<storage, read_write> b: array<{{precision}}>;
242+
@group(0) @binding(2) var<storage, read_write> c: array<{{precision}}>;
243+
var<workgroup> tileA: array<{{precision}}, {{BM}} * {{BK}}>;
244+
var<workgroup> tileB: array<{{precision}}, {{BN}} * {{BK}}>;
245+
246+
@compute @workgroup_size({{workgroupSize}})
247+
fn main(
248+
@builtin(global_invocation_id) globalID : vec3<u32>,
249+
@builtin(local_invocation_id) localID : vec3<u32>,
250+
@builtin(workgroup_id) groupid : vec3<u32>) {
251+
252+
var threadResults: array<{{precision}}, {{TM}} * {{TN}}>;
253+
var localM: array<{{precision}}, {{TM}}>;
254+
var localN: array<{{precision}}, {{TN}}>;
255+
256+
let cRow: u32 = groupid.x;
257+
let cCol: u32 = groupid.y;
258+
let numThread: u32 = ({{BM}} * {{BN}}) / ({{TM}} * {{TN}});
259+
260+
// position of the first c element computed by the thread
261+
let threadRow: u32 = (localID.x / ({{BN}} / {{TN}})) * {{TM}};
262+
let threadCol: u32 = (localID.x % ({{BN}} / {{TN}})) * {{TN}};
263+
264+
let numIterA: u32 = {{BM}} * {{BK}} / ({{BM}} * {{BN}} / ({{TM}} * {{TN}}));
265+
let numIterB: u32 = {{BK}} * {{BN}} / ({{BM}} * {{BN}} / ({{TM}} * {{TN}}));
266+
267+
// aPtr and bPtr are the starting positions of the tiles in a and b,
268+
// incremented in the bkidx loop.
269+
// cPtr is the starting position of the tile in c which is fixed.
270+
271+
var aPtr = cRow * {{BM}} * {{K}};
272+
var bPtr = cCol * {{BN}} * {{K}};
273+
let cPtr = cRow * {{BM}} * {{N}} + cCol * {{BN}};
274+
275+
for (var bkidx = 0; bkidx < {{K}}; bkidx += {{BK}}) {
276+
277+
// Load tile
278+
// Load BM x BK by numThread(BM * BN / (TM * TN))
279+
// The number of iteration == BM * BK / (BM * BN / (TM * TN))
280+
for (var i: u32 = 0; i < numIterA; i++) {
281+
let loadColA: u32 = (localID.x + i * numThread) % {{BK}};
282+
let loadRowA: u32 = (localID.x + i * numThread) / {{BK}};
283+
tileA[loadRowA * {{BK}} + loadColA] = a[aPtr + loadRowA * {{K}} + loadColA];
284+
}
285+
// Load BK x BN by numThread(BM * BN / (TM * TN))
286+
// The number of iteration == BK * BN / (BM * BN / (TM * TN))
287+
for (var i: u32 = 0; i < numIterB; i++) {
288+
let loadColB: u32 = (localID.x + i * numThread) % {{BK}};
289+
let loadRowB: u32 = (localID.x + i * numThread) / {{BK}};
290+
tileB[loadRowB * {{BK}} + loadColB] = b[bPtr + loadRowB * {{K}} + loadColB];
291+
}
292+
293+
aPtr += {{BK}};
294+
bPtr += {{BK}};
295+
296+
workgroupBarrier();
297+
// Compute tile
298+
for (var dotIdx: u32 = 0; dotIdx < {{BK}}; dotIdx = dotIdx + 1) {
299+
for (var i: u32 = 0; i < {{TM}}; i++) {
300+
localM[i] = tileA[(threadRow + i) * {{BK}} + dotIdx];
301+
}
302+
for (var i: u32 = 0; i < {{TN}}; i++) {
303+
localN[i] = tileB[(threadCol + i) * {{BK}} + dotIdx];
304+
}
305+
for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) {
306+
for (var resIdxN: u32 = 0; resIdxN < {{TN}}; resIdxN++) {
307+
threadResults[resIdxM * {{TN}} + resIdxN] += localM[resIdxM] * localN[resIdxN];
308+
}
309+
}
310+
}
311+
workgroupBarrier();
312+
}
313+
314+
for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) {
315+
for (var resIdxN: u32 = 0; resIdxN < {{TN}}; resIdxN++) {
316+
c[cPtr + (threadRow + resIdxM) * {{N}} + threadCol + resIdxN] = threadResults[resIdxM * {{TN}} + resIdxN];
317+
}
318+
}
319+
}
320+
)";
321+
322+
inline ShaderCode createMatmul4(const char *shaderTemplate, const size_t M,
323+
const size_t K, const size_t N, const size_t BM,
324+
const size_t BK, const size_t BN,
325+
const size_t TM, const size_t TN,
326+
const Shape &workgroupSize = {256, 1, 1},
327+
NumType precision = kf32) {
328+
assert(BM % TM == 0);
329+
assert(BN % TN == 0);
330+
assert(K % BK == 0);
331+
assert(M % BM == 0);
332+
assert(N % BN == 0);
333+
// # threads = tile A size == tile B size == # threads for computing C
334+
//assert(/* tile A size */ BM * BK == /* tile B size */ BK * BN);
335+
//assert(/* tile A size */ BM * BK == /* # of threads for C */ BM * BN / (TM * TN));
336+
std::string codeString(shaderTemplate);
337+
replaceAll(codeString, {{"{{workgroupSize}}", toString(workgroupSize)},
338+
{"{{precision}}", toString(precision)},
339+
{"{{M}}", toString(M)},
340+
{"{{K}}", toString(K)},
341+
{"{{N}}", toString(N)},
342+
{"{{BM}}", toString(BM)},
343+
{"{{BK}}", toString(BK)},
344+
{"{{BN}}", toString(BN)},
345+
{"{{TM}}", toString(TM)},
346+
{"{{TN}}", toString(TN)}});
347+
return ShaderCode{codeString, workgroupSize};
348+
}
349+
236350
inline ShaderCode createNoOp(const char *shaderTemplate,
237351
const Shape &workgroupSize = {256, 1, 1},
238352
NumType precision = kf32) {
@@ -304,6 +418,22 @@ Kernel selectMatmul(Context &ctx, int version,
304418
kernel = createKernel(ctx, matmul, bindings,
305419
/*nWorkgroups*/ nWorkgroups);
306420
} else if (version == 4) {
421+
static constexpr size_t BM = 64;
422+
static constexpr size_t BK = 16;
423+
static constexpr size_t BN = 64;
424+
static constexpr size_t TM = BM / BK;
425+
static constexpr size_t TN = BN / BK;
426+
Shape wgSize = {(BM / TM) * (BN / TN), 1, 1}; // This is the same as BK * BK.
427+
Shape nWorkgroups = {cdiv(M, BM), cdiv(N, BN), 1};
428+
LOG(kDefLog, kInfo, "M: %d, K: %d, N: %d", M, K, N);
429+
LOG(kDefLog, kInfo, "BM: %d, BK: %d, BN: %d, TM: %d, TN: %d", BM, BK, BN, TM, TN);
430+
LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str());
431+
LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
432+
ShaderCode matmul = createMatmul4(kShaderMatmul4, M, K, N, BM, BK, BN, TM, TN,
433+
/*wgSize*/ wgSize);
434+
kernel = createKernel(ctx, matmul, bindings,
435+
/*nWorkgroups*/ nWorkgroups);
436+
} else if (version == 5) {
307437
Shape wgSize = {256, 1, 1};
308438
Shape nWorkgroups = cdiv({M, N, 1}, {16, 16, 1});
309439
ShaderCode matmul = createNoOp(kShaderNoOp, /*wgsize*/ wgSize);
@@ -371,10 +501,11 @@ void runTest(int version, size_t M, size_t K, size_t N,
371501
}
372502

373503
int main() {
374-
int version = 3; // 1 == naive matmul
504+
int version = 4; // 1 == naive matmul
375505
// 2 == tiling
376506
// 3 == 1D blocktiling
377-
// 4 == No-Op
507+
// 4 == 2D blocktiling
508+
// 5 == No-Op
378509
size_t M, K, N; // Matrix dimensions
379510
static constexpr int kTestSize = 2;
380511
if constexpr (kTestSize == 0) {

0 commit comments

Comments
 (0)