#include <array>
#include <chrono>
#include <cstdlib>
#include <future>
#include <random>

#include "gpu.h" // createContext, createTensor, createKernel, dispatchKernel,
                 // wait, resetCommandBuffer, toCPU

#include "llmc/reference_impls.h" // for the CPU reference implementation
#include "utils/array_utils.h"    // show, isclose, randn, randint
#include "utils/logging.h"        // LOG
#include "experimental/wgsl.h"    // loopUnrolling

using namespace gpu;

// Implements the matrix transpose kernels from
// https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc,
// ported to WGSL.
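//
// Three variants are selectable at runtime through the TEST_VERSION
// environment variable (parsed in main below):
//   TEST_VERSION=1  naive one-element-per-invocation transpose
//   TEST_VERSION=2  shared-memory tiled transpose (default)
//   TEST_VERSION=3  CPU reference implementation
// Example invocation (binary name is illustrative; use your build's target):
//   TEST_VERSION=2 ./transpose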

static const char *kShaderTranspose1 = R"(
@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
@compute @workgroup_size({{workgroupSize}})
fn main(
  @builtin(global_invocation_id) globalID : vec3<u32>) {
  let bRow: u32 = globalID.x;
  let bCol: u32 = globalID.y;
  B[bRow * {{M}} + bCol] = A[bCol * {{N}} + bRow];
}
)";

inline KernelCode createTranspose1(const char *shaderTemplate,
                                   const size_t M, const size_t N,
                                   const Shape &workgroupSize = {256, 1, 1},
                                   NumType precision = kf32) {
  std::string codeString(shaderTemplate);
  replaceAll(codeString, {{"{{workgroupSize}}", toString(workgroupSize)},
                          {"{{precision}}", toString(precision)},
                          {"{{M}}", toString(M)},
                          {"{{N}}", toString(N)}});
  return {codeString, workgroupSize};
}
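
// For illustration, createTranspose1(kShaderTranspose1, 4, 8, {16, 16, 1})
// expands the {{...}} placeholders to the following WGSL (assuming toString
// renders the Shape as "16, 16, 1" and kf32 as "f32"):
//
//   @group(0) @binding(0) var<storage, read_write> A: array<f32>;
//   @group(0) @binding(1) var<storage, read_write> B: array<f32>;
//   @compute @workgroup_size(16, 16, 1)
//   fn main(
//     @builtin(global_invocation_id) globalID : vec3<u32>) {
//     let bRow: u32 = globalID.x;
//     let bCol: u32 = globalID.y;
//     B[bRow * 4 + bCol] = A[bCol * 8 + bRow];
//   }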

// Shared-memory cache-blocking
static const char *kShaderTranspose2 = R"(
@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
var<workgroup> tile: array<{{precision}}, {{BN}} * {{BM}}>;
@compute @workgroup_size({{workgroupSize}})
fn main(
  @builtin(local_invocation_id) localID : vec3<u32>,
  @builtin(workgroup_id) groupID: vec3<u32>) {
  let bRow: u32 = groupID.x * {{BN}};
  let bCol: u32 = groupID.y * {{BM}};

  let aPtr = bCol * {{N}} + bRow;
  let bPtr = bRow * {{M}} + bCol;
  let numThread: u32 = ({{BM}} * {{BN}}) / ({{TM}} * {{TN}});

  // Cooperatively load a BM x BN tile of A into workgroup memory,
  // transposing it on the way in.
  for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) {
    for (var resIdxN: u32 = 0; resIdxN < {{TN}}; resIdxN++) {
      let idx: u32 = localID.x + numThread * (resIdxN + {{TN}} * resIdxM);
      let loadRow: u32 = idx / {{BN}};
      let loadCol: u32 = idx % {{BN}};
      tile[loadCol * {{BN}} + loadRow] = A[aPtr + loadRow * {{N}} + loadCol];
    }
  }

  workgroupBarrier();

  // Write the already-transposed tile back out to B with contiguous stores.
  for (var resIdxN: u32 = 0; resIdxN < {{TN}}; resIdxN++) {
    for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) {
      let idx: u32 = localID.x + numThread * (resIdxM + {{TM}} * resIdxN);
      let loadRow: u32 = idx / {{BM}};
      let loadCol: u32 = idx % {{BM}};
      B[bPtr + loadRow * {{M}} + loadCol] = tile[loadRow * {{BM}} + loadCol];
    }
  }
}
)";

inline KernelCode createTranspose2(const char *shaderTemplate,
                                   const size_t M, const size_t N,
                                   const size_t BM, const size_t BN,
                                   const size_t TM, const size_t TN,
                                   const Shape &workgroupSize = {256, 1, 1},
                                   NumType precision = kf32) {
  assert(BM % TM == 0);
  assert(BN % TN == 0);
  assert(M % BM == 0);
  assert(N % BN == 0);
  std::string codeString(shaderTemplate);
  replaceAll(codeString, {{"{{workgroupSize}}", toString(workgroupSize)},
                          {"{{precision}}", toString(precision)},
                          {"{{M}}", toString(M)},
                          {"{{N}}", toString(N)},
                          {"{{BM}}", toString(BM)},
                          {"{{BN}}", toString(BN)},
                          {"{{TM}}", toString(TM)},
                          {"{{TN}}", toString(TN)}});
  // loopUnrolling(codeString) from experimental/wgsl.h could be applied here;
  // it is left disabled for now.
  return {codeString, workgroupSize};
}

void initData(size_t M, size_t N, std::unique_ptr<float[]> &inputPtr) {
  std::mt19937 gen(314159);
  randn(inputPtr.get(), M * N, gen);
  LOG(kDefLog, kInfo, "%s", show<float>(inputPtr.get(), M, N, "Input").c_str());
}

Kernel selectTranspose(Context &ctx, int version,
                       const Bindings</* input, output */ 2> &bindings,
                       size_t M, size_t N) {
  Kernel kernel;
  if (version == 1) {
    Shape wgSize = {16, 16, 1};
    LOG(kDefLog, kInfo, "wgSize: %s", toString(wgSize).c_str());
    KernelCode transpose =
        createTranspose1(kShaderTranspose1, M, N, /*wgSize*/ wgSize); // The input shape is M x N
    kernel = createKernel(ctx, transpose, bindings,
                          /*nWorkgroups*/ cdiv({N, M, 1}, wgSize)); // The output shape is N x M
  } else if (version == 2) {
    static constexpr size_t BM = 64;
    static constexpr size_t BK = 16;
    static constexpr size_t BN = 64;
    static constexpr size_t TM = BM / BK;
    static constexpr size_t TN = BN / BK;
    Shape wgSize = {(BM / TM) * (BN / TN), 1, 1}; // This is the same as BK * BK.
    Shape nWorkgroups = {cdiv(N, BN), cdiv(M, BM), 1};
    LOG(kDefLog, kInfo, "M: %zu, N: %zu", M, N);
    LOG(kDefLog, kInfo, "BM: %zu, BK: %zu, BN: %zu, TM: %zu, TN: %zu",
        BM, BK, BN, TM, TN);
    LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str());
    LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
    KernelCode transpose = createTranspose2(kShaderTranspose2, M, N, BM, BN, TM, TN,
                                            /*wgSize*/ wgSize, kf32);
    kernel = createKernel(ctx, transpose, bindings,
                          /*nWorkgroups*/ nWorkgroups);
  } else if (version == 3) {
    // CPU reference path: no GPU kernel is created; the default-constructed
    // Kernel is returned and never dispatched (see runTest).
    LOG(kDefLog, kInfo, "Skipping kernel creation for the CPU reference path");
  }
  return kernel;
}
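
// For illustration, a direct invocation (using the tensors created in
// runTest below):
//   Kernel k = selectTranspose(ctx, /*version=*/2, {input, output}, M, N);
// followed by dispatchKernel / wait, as in the timing loop of runTest.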

void runTest(int version, size_t M, size_t N,
             std::unique_ptr<float[]> &inputPtr,
             std::unique_ptr<float[]> &outputPtr) {
  bool isCPU = version == 3;

  // Allocate GPU buffers and copy data
  Context ctx = createContext();
  Tensor input = createTensor(ctx, Shape{M, N}, kf32, inputPtr.get());
  Tensor output = createTensor(ctx, Shape{N, M}, kf32);

  constexpr size_t nIter = 50;

  // Initialize the kernel and bind the GPU buffers
  LOG(kDefLog, kInfo, "Creating Kernel");
  Kernel kernel = selectTranspose(ctx, version, {input, output}, M, N);

  // Dispatch kernel execution
  LOG(kDefLog, kInfo, "Dispatching Kernel version %d, %zu iterations ...",
      version, nIter);

  // Pre-allocate promises and futures for async dispatch
  // TODO(avh): implement a pooling mechanism for promises/futures in gpu.h
  std::array<std::promise<void>, nIter> promises;
  std::array<std::future<void>, nIter> futures;
  for (size_t i = 0; i < nIter; i++) {
    futures[i] = promises[i].get_future();
  }

  // Dispatch the kernel nIter times
  auto start = std::chrono::high_resolution_clock::now();
  for (size_t i = 0; i < nIter; i++) {
    if (!isCPU) {
      dispatchKernel(ctx, kernel, promises[i]);
      wait(ctx, futures[i]);
      resetCommandBuffer(ctx.device, kernel);
    } else {
      transpose(inputPtr.get(), outputPtr.get(), M, N);
    }
  }
  auto end = std::chrono::high_resolution_clock::now();

  // Report performance.
  // Use microseconds for a more accurate time measurement.
  auto duration =
      std::chrono::duration_cast<std::chrono::microseconds>(end - start);
  // Effective bandwidth, counting one pass over the M * N floats per
  // iteration.
  float gbps = sizeof(float) * M * N /
               (static_cast<double>(duration.count()) / 1000000.0) /
               1000000000.0 * static_cast<float>(nIter);
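
  // Worked example of the formula above: in the large test, M = 4096 and
  // N = 2 * 4096, so one pass over the matrix is
  //   4 bytes * 4096 * 8192 = 134,217,728 bytes ~= 0.134 GB.
  // If the 50 dispatches were to take 0.05 s in total (1 ms per dispatch,
  // an illustrative figure, not a measurement), the reported rate would be
  //   0.134 GB * 50 / 0.05 s ~= 134 GB/s.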

  LOG(kDefLog, kInfo, "Copying result to CPU");
  if (!isCPU) {
    toCPU(ctx, output, outputPtr.get(), M * N * sizeof(float));
  }
  LOG(kDefLog, kInfo, "%s",
      show<float>(outputPtr.get(), N, M, "Output").c_str());

  LOG(kDefLog, kInfo,
      "\n\n===================================================================="
      "============\nExecution Time: (M = %zu, N = %zu) x %zu iterations "
      ":\n%.3f "
      "milliseconds / dispatch ~ %.2f "
      "GB/s\n================================================================"
      "================\n\n",
      M, N, nIter,
      duration.count() / static_cast<double>(nIter) / 1000.0 /* us -> ms */,
      gbps);
}

int main() {
  const char *version_str = std::getenv("TEST_VERSION");
  int version = version_str == nullptr ? 2 : std::atoi(version_str);
  // 1 == naive transpose
  // 2 == shared-memory tiling
  // 3 == CPU reference

  size_t M, N; // Matrix dimensions
  static constexpr int kTestSize = 2;
  if constexpr (kTestSize == 0) {
    // Tiny test
    M = 16;
    N = 32;
  } else if constexpr (kTestSize == 1) {
    // Small test
    M = 256;
    N = 512;
  } else {
    // Large test
    M = 4096;
    N = 2 * 4096;
  }

  std::unique_ptr<float[]> inputPtr = std::make_unique<float[]>(M * N);
  std::unique_ptr<float[]> outputPtr = std::make_unique<float[]>(N * M);

  initData(M, N, inputPtr);
  runTest(version, M, N, inputPtr, outputPtr);

  LOG(kDefLog, kInfo, "Done.");
  return 0;
}