Non-record: Fused Triton Megakernels — RMSNorm + LeakyReLU² (val_bpb 1.3560) #1192
dentity007 wants to merge 3 commits into openai:main
Conversation
Research Expansion: Ablation Results
Ran an overnight ablation study on DGX Spark GB10 to expand on this submission. Note: Triton kernels do not work on aarch64 (Spark is ARM), so these runs test the PyTorch-equivalent configurations; the actual kernel speedup would be additive on H100. 200 training steps, sp1024, no torch.compile.
Finding
MEGA-2 with d=640 beats MEGA-3 with 11 layers despite both adding compute. Wider is better than deeper for this model size. The practical implication: if megakernel fusion saves X% of training time, reinvest that time as width (more channels) rather than depth (more layers). This suggests an architecture-first optimization strategy: find the best width/depth tradeoff first, then layer on kernel fusion as a free speedup. Full raw data and logs: https://gist.github.com/dentity007/324ac35505c27acd18e7ffb468f4fa08
Community Review — Non-record: Fused Triton Megakernels — RMSNorm + LeakyReLU² (val_bpb 1.3560)
Compliance: LOOKS CLEAN — pure-neural submission, no TTT/SLOT/n-gram-cache
Analysis
PR #1192 — "Megakernels_FusedTriton" submission
Thanks for the audit, and appreciate the specific callouts. Quick note for context: the Triton kernels in this submission dispatch only on the eval pass, since torch.compile with fullgraph=True does not support custom Triton kernels inside compiled regions (see the findings below).
Non-record: Fused Triton Megakernels - RMSNorm + LeakyReLU Squared
val_bpb: 1.3560 | 1x RTX 5090 Ada 16GB, 600s wallclock | sp1024
Implements OpenAI's requested "Megakernels" research direction.
Architecture
Results
Key Findings
Fused kernels provide a small but real BPB improvement. The -0.0017 BPB gain comes from faster evaluation, allowing slightly more training iterations within the wallclock budget. This is a pure systems optimization, not an ML improvement.
RMSNorm fusion eliminates a kernel launch. The standard F.rms_norm involves multiple small operations. The fused Triton kernel does normalization in a single pass, reducing launch overhead.
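For concreteness, here is a minimal sketch of that kind of single-pass RMSNorm kernel. It is illustrative rather than the submission's actual train_gpt_megakernel.py code; the function names, block size, and launch shape are assumptions.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _rmsnorm_fwd(x_ptr, w_ptr, out_ptr, n_cols, eps, BLOCK_SIZE: tl.constexpr):
    # One program per row: load, reduce, normalize, scale -- a single launch
    # instead of the separate square/mean/rsqrt/mul kernels of the eager path.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols
    x = tl.load(x_ptr + row * n_cols + cols, mask=mask, other=0.0).to(tl.float32)
    rms = tl.sqrt(tl.sum(x * x, axis=0) / n_cols + eps)
    w = tl.load(w_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    tl.store(out_ptr + row * n_cols + cols, x / rms * w, mask=mask)

def fused_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    x = x.contiguous()          # row-major pointer arithmetic assumes contiguity
    n_cols = x.shape[-1]
    rows = x.numel() // n_cols
    out = torch.empty_like(x)
    # BLOCK_SIZE must be a power of two covering the whole row.
    _rmsnorm_fwd[(rows,)](x, weight, out, n_cols, eps,
                          BLOCK_SIZE=triton.next_power_of_2(n_cols))
    return out
```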
LeakyReLU squared is a good fusion target. The activation function involves three operations (leaky_relu, square, multiply). Fusing them avoids materializing intermediate tensors.
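A sketch of the fused activation, assuming the straightforward composition y = leaky_relu(x)² (the exact form and slope in the submission may differ):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _leaky_relu_sq_fwd(x_ptr, out_ptr, n_elems, slope, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elems
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    a = tl.where(x > 0, x, x * slope)           # leaky_relu
    tl.store(out_ptr + offs, a * a, mask=mask)  # square in-register, no HBM round-trip

def fused_leaky_relu_sq(x: torch.Tensor, slope: float = 0.01) -> torch.Tensor:
    x = x.contiguous()
    out = torch.empty_like(x)
    n = x.numel()
    _leaky_relu_sq_fwd[(triton.cdiv(n, 1024),)](x, out, n, slope, BLOCK_SIZE=1024)
    return out
```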
Training-time kernel use is limited by torch.compile. fullgraph=True mode does not support custom Triton kernels inside compiled regions. The kernels are only used for the eval pass.
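A sketch of what that eval-only dispatch can look like, reusing fused_rmsnorm from the sketch above. EvalFusedRMSNorm is a hypothetical name, not the submission's module, and torch.compiler.is_compiling() requires a recent PyTorch:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class EvalFusedRMSNorm(nn.Module):
    # Hypothetical wrapper: fused Triton path on the uncompiled eval pass,
    # stock F.rms_norm whenever we are training or being traced by torch.compile.
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training or torch.compiler.is_compiling():
            return F.rms_norm(x, (x.shape[-1],), self.weight, self.eps)
        return fused_rmsnorm(x, self.weight, self.eps)  # from the sketch above
```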
Comparison to Naive Baseline
Note: Both use the same 9-layer config, not the 11-layer record config. The absolute BPB difference from 1.2244 is due to different hyperparameters, not the kernels.
Reproduction
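A minimal invocation, assuming the script's built-in defaults reproduce the reported configuration: `python train_gpt_megakernel.py`.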
Discussion
Megakernels represent the "systems" side of the challenge. While -0.0017 BPB is small, it is free performance that stacks with any ML improvement. The bigger potential is fusing the full attention block (Q/K/V projection + RoPE + attention + output projection) into a single kernel, which could save significantly more launch overhead. This would require writing a more complex Triton kernel, but the parameter-golf model is small enough that launch overhead is a meaningful fraction of compute.
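To make "launch overhead is a meaningful fraction of compute" concrete, here is an illustrative micro-benchmark (not from the submission) that times the eager op against fused_rmsnorm from the sketch above using CUDA events; the tensor shape is an arbitrary example:

```python
import torch
import torch.nn.functional as F

def avg_us(fn, iters: int = 1000) -> float:
    # Average time per call in microseconds, measured with CUDA events.
    for _ in range(10):  # warmup
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) * 1e3 / iters  # ms total -> us per call

x = torch.randn(1024, 768, device="cuda")
w = torch.ones(768, device="cuda")
print("eager F.rms_norm:", avg_us(lambda: F.rms_norm(x, (768,), w)), "us")
print("fused_rmsnorm   :", avg_us(lambda: fused_rmsnorm(x, w)), "us")
```

The per-call delta multiplied by the number of norm/activation launches per step gives a rough upper bound on what further fusion can recover within the wallclock budget.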
Would welcome collaboration on more aggressive fusion targets.
Credits
Script: train_gpt_megakernel.py. Implements OpenAI's requested "Megakernels" direction from the README.