| 1 | +# Flattened GPT Forward, Removed Post Attention Lambdas, Added Transpose Kernels |
| 2 | + |
| 3 | +This PR has gone through a number of revisions as we've been catching up the repo with recent PRs, and as I continued to play with the new parallel residuals feature. |
| 4 | + |
| 5 | +There are three main changes included: |
| 6 | +1. Removed the Block and MLP classes |
| 7 | +2. Removed the post attention lambdas on the single lane layers |
| 8 | +3. A few kernel improvements |
| 9 | + |
| 10 | +``` |
| 11 | + Runs Steps Time μ Time σ Time +/- Loss μ Loss σ Loss +/- p |
| 12 | +Baseline 4 1490 89.7995 0.0621 0.0000 3.2782 0.0016 0.0000 0.0601 |
| 13 | +This PR 4 1490 88.8340 0.0445 -0.9655 3.2788 0.0005 0.0006 0.0097 |
| 14 | +
|
| 15 | +Baseline: |
| 16 | + losses = [3.2764, 3.2799, 3.2773, 3.2793] |
| 17 | + times = [89.8030, 89.7110, 89.8350, 89.8490] |
| 18 | +
|
| 19 | +This PR: |
| 20 | + losses = [3.2783, 3.2789, 3.2795, 3.2786] |
| 21 | + times = [88.8210, 88.8090, 88.8060, 88.9000] |
| 22 | +``` |
| 23 | + |
| 24 | +## Flattened GPT.forward |
| 25 | + |
| 26 | +The parallel residuals PR was implemented within GPT.forward rather than in Block. I took this further and removed the Block and MLP classes (and the module list) altogether, and I believe this accounts for roughly half the speed improvement in this PR. I don't know exactly why removing them is faster, so it may be possible to restore those classes and move the logic into Block while keeping the speedup. I'm partial to inlining code, though, so I didn't look into it. :) |
| 27 | + |
| 28 | +## Post Attention Lambdas |
| 29 | + |
| 30 | +I tried a number of systems and architectural changes to the parallel residual streams. Several gave a speed-up, but at the cost of some validation loss. A systems speed-up is rarely large enough to pay for the extra training steps needed to win that loss back. |
| 31 | + |
| 32 | +The only change I made which ended up being included here was to remove the post attention lambdas from the single lane layers, where they seemed like they might be redundant with the sa lambdas. |
| 33 | + |
| 34 | +All of the experiments below may be worth trying again, since they may behave differently in the context of other changes to the code. The baseline changed a couple of times over the course of these experiments, and in a couple of cases I saw different outcomes depending on the baseline. |
| 35 | + |
| 36 | +Here's everything I tried: |
| 37 | + |
| 38 | +1. Moving the x0/bigram injection back before attention. |
| 39 | + * The previous PR moved this injection to after attention instead of before. |
| 40 | + * Switching it back increased loss slightly but improved speed. |
| 41 | + * I ended up reverting this in order to get under 3.28. |
| 42 | +2. Adding the x0/bigram injection to layer 6. |
| 43 | + * It wasn't being injected at layer 6, which seemed like maybe an oversight. |
| 44 | + * Adding it increased both loss and time. |
| 45 | +3. Pre-multiplying the post mlp lambdas into the MLP output projections on the single lane layers. |
| 46 | + * Similar to number 1, this gave a small speed up at the cost of a small increase in loss. |
| 47 | + * I had to abandon it to get under 3.28. |
| 48 | +4. Pre-multiplying the post attention lambdas into the attention output projection. |
| 49 | + * This was faster, but hurt loss significantly. |
| 50 | + * Dropping them altogether seemed to be a better solution. |
| 51 | + * In my last batch of experiments, I tried adding them back to help improve loss, but they actually increased it this time. |
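As a rough illustration of what "pre-multiplying" a lambda into a projection means in items 3 and 4, here is a minimal pure-Python sketch. It assumes the lambda is a single scalar and that folding means scaling the weight matrix before the matmul. The names `lam`, `W`, `x`, `scale_after` and `scale_folded` are hypothetical and not taken from the training script.

```python
# Hypothetical names throughout: `lam` stands in for a post-attention (or
# post-MLP) lambda, `W` for an output projection, `x` for an activation vector.

def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def scale_after(W, x, lam):
    """Project first, then scale the output (one extra element-wise op)."""
    return [lam * y for y in matvec(W, x)]

def scale_folded(W, x, lam):
    """Fold the scalar into the weights, then project once."""
    W_scaled = [[lam * w for w in row] for row in W]
    return matvec(W_scaled, x)

W = [[1.0, 2.0], [3.0, 4.0]]
x = [2.0, -1.0]
lam = 0.5

after = scale_after(W, x, lam)
folded = scale_folded(W, x, lam)
print(after, folded)  # identical here because every value is exactly representable
```

The two forms agree exactly in this toy case. In low-precision training they round at different points, and the lambda's gradient flows through a different path, which may be part of why the folded variants behaved differently in practice.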
| 52 | + |
| 53 | +Other stuff: |
| 54 | +* I did not try "untying" the residual lambdas. |
| 55 | +* The post mlp lambdas on the final layer seem to be ~equal, so we may be able to remove them and save some time. |
| 56 | + |
| 57 | +## Kernel Changes |
| 58 | + |
| 59 | +This PR includes a few kernel improvements: |
| 60 | +1. Transpose copy and transpose add kernels to speed up the lm head and embedding copies in the optimizer. |
| 61 | +2. Eliminated a backwards select kernel on MLP banks |
| 62 | +3. Transposed polar express kernel (from my previous abandoned PR) |
| 63 | +4. Fused Nesterov momentum calculation into polar express |
| 64 | + |
| 65 | + |
| 66 | +### 1 - Transpose Copy |
| 67 | + |
| 68 | +Transposing the LM head (in [PR200](https://github.com/KellerJordan/modded-nanogpt/pull/200)) eliminated an expensive gradient accumulation kernel that PyTorch would run every step, but the benefit was partially negated because we still had to do some element-wise operations to artificially keep the embeddings and LM head tied. |
| 69 | + |
| 70 | +My rationale at the time was that this was fine because it would be overlapped with compute, but the trace files have had some glaring communication gaps caused by these steps. I learned it's because the kernels were using all of the SMs, preventing the GPUs from being able to work on communication at the same time. (The fact that communication speed relates to SM availability seems like an interesting insight.) |
| 71 | + |
| 72 | +Claude was able to write efficient "transpose add" and "transpose copy" kernels which reduced the cost of these steps significantly. |
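For readers unfamiliar with the terms, the sketch below shows only the semantics I understand "transpose copy" and "transpose add" to have: writing or accumulating a matrix into a destination stored in the transposed layout, one tile at a time. It is plain Python, not the Triton kernels from this PR, and the function names, tile size, and example tensors are all hypothetical.

```python
def _tiles(rows, cols, tile):
    """Yield (i, j) index pairs tile by tile, mimicking a blocked kernel grid."""
    for i0 in range(0, rows, tile):
        for j0 in range(0, cols, tile):
            for i in range(i0, min(i0 + tile, rows)):
                for j in range(j0, min(j0 + tile, cols)):
                    yield i, j

def transpose_copy(dst, src, tile=2):
    """dst = src^T, written in place. dst must be shaped cols x rows."""
    for i, j in _tiles(len(src), len(src[0]), tile):
        dst[j][i] = src[i][j]

def transpose_add(dst, src, tile=2):
    """dst += src^T, accumulated in place. dst must be shaped cols x rows."""
    for i, j in _tiles(len(src), len(src[0]), tile):
        dst[j][i] += src[i][j]

# Hypothetical tied weights: embedding is (vocab x dim), LM head is (dim x vocab).
embed = [[1, 2, 3], [4, 5, 6]]
lm_head = [[0, 0], [0, 0], [0, 0]]
transpose_copy(lm_head, embed)

# Hypothetical gradients: fold the LM head gradient into the embedding gradient.
head_grad = [[10, 20], [30, 40], [50, 60]]
embed_grad = [[1, 1, 1], [1, 1, 1]]
transpose_add(embed_grad, head_grad)

print(lm_head)
print(embed_grad)
```

A real kernel would also choose its launch grid so that it does not occupy every SM, which is the property that matters for overlapping with communication. This sketch does not model that.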
| 73 | + |
| 74 | +<img width="1912" height="671" alt="image" src="transpose_kernels.png" /> |
| 75 | + |
| 76 | +Side note--I finally got around to making a tutorial for my trace-reading workflow: |
| 77 | +https://www.youtube.com/watch?v=vTdLpaI5gMQ |
| 78 | + |
| 79 | +### 2 - "Select Backwards" Kernels |
| 80 | + |
| 81 | +I noticed recently that my/our efforts to avoid the select-backwards kernels caused by the parameter banks haven't been entirely working--the trace file is still full of them. |
| 82 | + |
| 83 | +Claude was able to eliminate these by changing the way in which the weights are accessed. |
| 84 | + |
| 85 | +Curiously, one of the three fixes had a clear, consistent negative impact despite being mathematically equivalent. It must be due to a difference in the compiler's choice of kernel, and the order of operations in it. I had to leave it out. |
| 86 | + |
| 87 | +Another of the three seemed like it might be hurting loss. That result was less conclusive, but I left it out as well. |
| 88 | + |
| 89 | +We were able to fix what seemed to be the biggest offender though, surrounding the MLP weights, without impacting loss. |
| 90 | + |
| 91 | +### 3 - Polar Express Transpose |
| 92 | + |
| 93 | +(Copied from my closed PR) |
| 94 | + |
| 95 | +I've also included an improvement to the Polar Express kernel. It was written to work on "wide" matrices, but we've had our MLP weights stored vertically. The current Polar Express code handles this by transposing the matrix, but the algebra can be re-ordered instead to allow for working directly on "tall" matrices: |
| 96 | + |
| 97 | +``` |
| 98 | +# Polar express |
| 99 | +X = g.bfloat16() |
| 100 | +
|
| 101 | +X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6) |
| 102 | +
|
| 103 | +if g.size(-2) > g.size(-1): # Tall matrix |
| 104 | + for a, b, c in polar_express_coeffs[:ns_steps]: |
| 105 | + A = X.mT @ X |
| 106 | + B = b * A + c * (A @ A) |
| 107 | + X = a * X + X @ B |
| 108 | +
|
| 109 | +else: # Wide matrix (original math) |
| 110 | + for a, b, c in polar_express_coeffs[:ns_steps]: |
| 111 | + A = X @ X.mT |
| 112 | + B = b * A + c * (A @ A) |
| 113 | + X = a * X + B @ X |
| 114 | +
|
| 115 | +return X |
| 116 | +``` |
| 117 | + |
| 118 | +I had Claude write an XTX variant for this. |
| 119 | + |
| 120 | +This change doesn't appear to impact the speed of the matrix multiplication itself. Instead, it's relevant to the element-wise momentum kernel preceding it. The transposed polar express allows for a much faster non-transposed element-wise momentum kernel. |
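To sanity-check the re-ordered algebra, the sketch below runs the "tall" update on a small matrix and the "wide" update on its transpose, then confirms the results match. It uses pure Python lists rather than torch, and `COEFFS` is a placeholder set of values chosen for illustration, not the PR's `polar_express_coeffs`.

```python
import math

def matmul(A, B):
    """Plain triple-loop matrix product of two list-of-rows matrices."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B)))
             for j in range(len(B[0]))] for i in range(len(A))]

def transpose(A):
    return [list(col) for col in zip(*A)]

def poly(A, b, c):
    """B = b * A + c * (A @ A), shared by both variants."""
    AA = matmul(A, A)
    return [[b * A[i][j] + c * AA[i][j] for j in range(len(A))]
            for i in range(len(A))]

def step_tall(X, a, b, c):
    """Tall variant: A = X^T X, then X <- a X + X B."""
    B = poly(matmul(transpose(X), X), b, c)
    XB = matmul(X, B)
    return [[a * X[i][j] + XB[i][j] for j in range(len(X[0]))]
            for i in range(len(X))]

def step_wide(X, a, b, c):
    """Wide variant (original math): A = X X^T, then X <- a X + B X."""
    B = poly(matmul(X, transpose(X)), b, c)
    BX = matmul(B, X)
    return [[a * X[i][j] + BX[i][j] for j in range(len(X[0]))]
            for i in range(len(X))]

# Placeholder coefficients, NOT the real polar_express_coeffs.
COEFFS = [(3.4445, -4.7750, 2.0315)] * 2

X_tall = [[0.3, -0.1], [0.2, 0.4], [-0.5, 0.1]]  # 3 x 2, more rows than columns
tall, wide = X_tall, transpose(X_tall)
for a, b, c in COEFFS:
    tall = step_tall(tall, a, b, c)
    wide = step_wide(wide, a, b, c)

wide_T = transpose(wide)
matches = all(math.isclose(t, w, rel_tol=1e-12, abs_tol=1e-12)
              for rt, rw in zip(tall, wide_T) for t, w in zip(rt, rw))
print(matches)
```

The equivalence relies on `B` being symmetric, so that `(B @ X.mT).mT == X @ B`. Either way the Gram matrix `A` has the size of the smaller dimension.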
| 121 | + |
| 122 | +### 4 - Fused Momentum Kernel |
| 123 | + |
| 124 | +In the nanochat variant of the optimizer, Karpathy combined all of the NorMuon steps (from Nesterov momentum through the weight update) into a single compiled helper function, which results in more operations being fused. |
| 125 | + |
| 126 | +However, when I tried that on our optimizer, I saw the same fusion / reduction in kernel count, but I was getting slower times. It could relate to us using a Triton kernel for polar express, or to our additional precision-enhancing mantissa storage technique. |
| 127 | + |
| 128 | +This time, I settled for fusing just the Nesterov momentum and polar express, and that worked. |
| 129 | + |
| 130 | +I didn't try fusing the NorMuon variance reduction helper with the cautious weight decay + weight update helper. That might work as well. |
| 131 | + |
| 132 | +### Kernel Tests |
| 133 | + |
| 134 | +Just an AI disclosure: The above includes three new Triton kernels which Claude wrote and I didn't review / wouldn't know how to. Instead, I asked Claude to also write tests for them which passed, and of course the training seems to be working fine as well. |
| 135 | + |
| 136 | +## Loss Impacts without Math Changes |
| 137 | + |
| 138 | +A small side rant--I've encountered this enough times now to recognize it as commonplace in our model rather than an anomaly. Systems changes which don't affect the math can easily lead to increased validation loss. It seems that the placement of kernel boundaries and the order of operations in those kernels has the potential to create rounding errors that accumulate and have meaningful impact. |
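A tiny worked example of the effect, under the assumption that a fused kernel keeps its intermediate in higher precision and rounds to bfloat16 once, while two separate kernels round the intermediate when it is written out in between. The `to_bf16` helper is a hypothetical software emulation of round-to-nearest-even, and the inputs are hand-picked to expose the difference.

```python
import struct

def to_bf16(x):
    """Round a float to the nearest bfloat16 value (ties to even).

    Emulation only: handles ordinary finite values, not NaN/inf/overflow.
    """
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round-to-nearest-even on the low 16 bits
    bits &= 0xFFFF0000                   # keep sign, exponent, top 7 mantissa bits
    return struct.unpack(">f", struct.pack(">I", bits))[0]

a = b = 1.0078125  # 1 + 2**-7, exactly representable in bfloat16
c = 0.00390625     # 2**-8, exactly representable in bfloat16

# "Fused": multiply and add in high precision, round once at the end.
fused = to_bf16(a * b + c)

# "Unfused": the product is rounded at the kernel boundary, then the add rounds again.
unfused = to_bf16(to_bf16(a * b) + c)

print(fused, unfused)  # 1.0234375 1.015625
```

The math is the same in both cases, yet the results differ by one bfloat16 step. Repeated over many layers and training steps, differences of this kind could plausibly add up to a visible change in loss.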
| 139 | + |
| 140 | +On the plus side, I think this just as often has worked in my favor, where I make a systems optimization and loss 'inexplicably' improves as well. |
| 141 | + |
| 142 | +Just seemed worth sharing for the benefit of anyone else who tries something that increases loss even though it shouldn't! |