@@ -233,6 +233,120 @@ fn main(
233
233
}
234
234
)" ;
235
235
236
/* 2D block-tiling matmul (WGSL shader template).
 *
 * Each workgroup produces one BM x BN tile of `c` (workgroup_id.x selects the
 * tile row, workgroup_id.y the tile column). Each invocation accumulates a
 * TM x TN patch of that tile in `threadResults`.
 *
 * For every BK-wide step along K the workgroup first stages a BM x BK slice
 * of `a` and a BN x BK slice of `b` in workgroup memory (`tileA`, `tileB`),
 * synchronizes with workgroupBarrier(), and then accumulates from the tiles.
 *
 * Layout: `a` is addressed as a[row * K + k] and `b` as b[col * K + k], i.e.
 * `b` is read with row stride K (an N x K array). `c` is written with row
 * stride N.
 *
 * Placeholders of the form {{name}} must be replaced by literals before the
 * shader is compiled: precision, workgroupSize, K, N, BM, BK, BN, TM, TN.
 *
 * Preconditions implied by the index arithmetic below:
 *   - the number of invocations per workgroup is (BM * BN) / (TM * TN);
 *   - BM * BK and BN * BK must be multiples of that invocation count, because
 *     the per-invocation load counts (numIterA / numIterB) are computed with
 *     integer division;
 *   - K must be a multiple of BK (the bkidx loop advances in steps of BK).
 */
static const char *kShaderMatmul4 = R"(
@group(0) @binding(0) var<storage, read_write> a: array<{{precision}}>;
@group(0) @binding(1) var<storage, read_write> b: array<{{precision}}>;
@group(0) @binding(2) var<storage, read_write> c: array<{{precision}}>;
var<workgroup> tileA: array<{{precision}}, {{BM}} * {{BK}}>;
var<workgroup> tileB: array<{{precision}}, {{BN}} * {{BK}}>;

@compute @workgroup_size({{workgroupSize}})
fn main(
    @builtin(global_invocation_id) globalID : vec3<u32>,
    @builtin(local_invocation_id) localID : vec3<u32>,
    @builtin(workgroup_id) groupid : vec3<u32>) {

    var threadResults: array<{{precision}}, {{TM}} * {{TN}}>;
    var localM: array<{{precision}}, {{TM}}>;
    var localN: array<{{precision}}, {{TN}}>;

    let cRow: u32 = groupid.x;
    let cCol: u32 = groupid.y;
    let numThread: u32 = ({{BM}} * {{BN}}) / ({{TM}} * {{TN}});

    // position of the first c element computed by the thread
    let threadRow: u32 = (localID.x / ({{BN}} / {{TN}})) * {{TM}};
    let threadCol: u32 = (localID.x % ({{BN}} / {{TN}})) * {{TN}};

    let numIterA: u32 = {{BM}} * {{BK}} / ({{BM}} * {{BN}} / ({{TM}} * {{TN}}));
    let numIterB: u32 = {{BK}} * {{BN}} / ({{BM}} * {{BN}} / ({{TM}} * {{TN}}));

    // aPtr and bPtr are the starting positions of the tiles in a and b,
    // incremented in the bkidx loop.
    // cPtr is the starting position of the tile in c which is fixed.

    var aPtr = cRow * {{BM}} * {{K}};
    var bPtr = cCol * {{BN}} * {{K}};
    let cPtr = cRow * {{BM}} * {{N}} + cCol * {{BN}};

    for (var bkidx = 0; bkidx < {{K}}; bkidx += {{BK}}) {

      // Load tile
      // Load BM x BK by numThread(BM * BN / (TM * TN))
      // The number of iteration == BM * BK / (BM * BN / (TM * TN))
      for (var i: u32 = 0; i < numIterA; i++) {
        let loadColA: u32 = (localID.x + i * numThread) % {{BK}};
        let loadRowA: u32 = (localID.x + i * numThread) / {{BK}};
        tileA[loadRowA * {{BK}} + loadColA] = a[aPtr + loadRowA * {{K}} + loadColA];
      }
      // Load BK x BN by numThread(BM * BN / (TM * TN))
      // The number of iteration == BK * BN / (BM * BN / (TM * TN))
      for (var i: u32 = 0; i < numIterB; i++) {
        let loadColB: u32 = (localID.x + i * numThread) % {{BK}};
        let loadRowB: u32 = (localID.x + i * numThread) / {{BK}};
        tileB[loadRowB * {{BK}} + loadColB] = b[bPtr + loadRowB * {{K}} + loadColB];
      }

      aPtr += {{BK}};
      bPtr += {{BK}};

      workgroupBarrier();
      // Compute tile
      for (var dotIdx: u32 = 0; dotIdx < {{BK}}; dotIdx = dotIdx + 1) {
        for (var i: u32 = 0; i < {{TM}}; i++) {
          localM[i] = tileA[(threadRow + i) * {{BK}} + dotIdx];
        }
        for (var i: u32 = 0; i < {{TN}}; i++) {
          localN[i] = tileB[(threadCol + i) * {{BK}} + dotIdx];
        }
        for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) {
          for (var resIdxN: u32 = 0; resIdxN < {{TN}}; resIdxN++) {
            threadResults[resIdxM * {{TN}} + resIdxN] += localM[resIdxM] * localN[resIdxN];
          }
        }
      }
      workgroupBarrier();
    }

    for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) {
      for (var resIdxN: u32 = 0; resIdxN < {{TN}}; resIdxN++) {
        c[cPtr + (threadRow + resIdxM) * {{N}} + threadCol + resIdxN] = threadResults[resIdxM * {{TN}} + resIdxN];
      }
    }
}
)";
321
+
322
+ inline ShaderCode createMatmul4 (const char *shaderTemplate, const size_t M,
323
+ const size_t K, const size_t N, const size_t BM,
324
+ const size_t BK, const size_t BN,
325
+ const size_t TM, const size_t TN,
326
+ const Shape &workgroupSize = {256 , 1 , 1 },
327
+ NumType precision = kf32) {
328
+ assert (BM % TM == 0 );
329
+ assert (BN % TN == 0 );
330
+ assert (K % BK == 0 );
331
+ assert (M % BM == 0 );
332
+ assert (N % BN == 0 );
333
+ // # threads = tile A size == tile B size == # threads for computing C
334
+ // assert(/* tile A size */ BM * BK == /* tile B size */ BK * BN);
335
+ // assert(/* tile A size */ BM * BK == /* # of threads for C */ BM * BN / (TM * TN));
336
+ std::string codeString (shaderTemplate);
337
+ replaceAll (codeString, {{" {{workgroupSize}}" , toString (workgroupSize)},
338
+ {" {{precision}}" , toString (precision)},
339
+ {" {{M}}" , toString (M)},
340
+ {" {{K}}" , toString (K)},
341
+ {" {{N}}" , toString (N)},
342
+ {" {{BM}}" , toString (BM)},
343
+ {" {{BK}}" , toString (BK)},
344
+ {" {{BN}}" , toString (BN)},
345
+ {" {{TM}}" , toString (TM)},
346
+ {" {{TN}}" , toString (TN)}});
347
+ return ShaderCode{codeString, workgroupSize};
348
+ }
349
+
236
350
inline ShaderCode createNoOp (const char *shaderTemplate,
237
351
const Shape &workgroupSize = {256 , 1 , 1 },
238
352
NumType precision = kf32) {
@@ -304,6 +418,22 @@ Kernel selectMatmul(Context &ctx, int version,
304
418
kernel = createKernel (ctx, matmul, bindings,
305
419
/* nWorkgroups*/ nWorkgroups);
306
420
} else if (version == 4 ) {
421
+ static constexpr size_t BM = 64 ;
422
+ static constexpr size_t BK = 16 ;
423
+ static constexpr size_t BN = 64 ;
424
+ static constexpr size_t TM = BM / BK;
425
+ static constexpr size_t TN = BN / BK;
426
+ Shape wgSize = {(BM / TM) * (BN / TN), 1 , 1 }; // This is the same as BK * BK.
427
+ Shape nWorkgroups = {cdiv (M, BM), cdiv (N, BN), 1 };
428
+ LOG (kDefLog , kInfo , " M: %d, K: %d, N: %d" , M, K, N);
429
+ LOG (kDefLog , kInfo , " BM: %d, BK: %d, BN: %d, TM: %d, TN: %d" , BM, BK, BN, TM, TN);
430
+ LOG (kDefLog , kInfo , " wgSize: ( %s )" , toString (wgSize).c_str ());
431
+ LOG (kDefLog , kInfo , " nWorkgroups: ( %s )" , toString (nWorkgroups).c_str ());
432
+ ShaderCode matmul = createMatmul4 (kShaderMatmul4 , M, K, N, BM, BK, BN, TM, TN,
433
+ /* wgSize*/ wgSize);
434
+ kernel = createKernel (ctx, matmul, bindings,
435
+ /* nWorkgroups*/ nWorkgroups);
436
+ } else if (version == 5 ) {
307
437
Shape wgSize = {256 , 1 , 1 };
308
438
Shape nWorkgroups = cdiv ({M, N, 1 }, {16 , 16 , 1 });
309
439
ShaderCode matmul = createNoOp (kShaderNoOp , /* wgsize*/ wgSize);
@@ -371,10 +501,11 @@ void runTest(int version, size_t M, size_t K, size_t N,
371
501
}
372
502
373
503
int main () {
374
- int version = 3 ; // 1 == naive matmul
504
+ int version = 4 ; // 1 == naive matmul
375
505
// 2 == tiling
376
506
// 3 == 1D blocktiling
377
- // 4 == No-Op
507
+ // 4 == 2D blocktiling
508
+ // 5 == No-Op
378
509
size_t M, K, N; // Matrix dimensions
379
510
static constexpr int kTestSize = 2 ;
380
511
if constexpr (kTestSize == 0 ) {
0 commit comments