feat: replace RMSNorm backward with persistent CuTile kernel #60

hannahli-nv merged 6 commits into NVIDIA:main
Conversation
Could I get a review when you get a chance @hannahli-nv !

/ok to test 4d947d3

@hannahli-nv Sorry to double ping, wanted to get this reviewed soon since I have a few more on the way with good performance :)
The existing Test_RMSNormBackward tests the standalone
src/tilegym/ops/cutile/rms_norm.py (outdated)

    if not dy2.is_contiguous():
        dy2 = dy2.contiguous()

    cfg = _bwd_cfg.get((M, N))
It seems that the kernel launch logic was mostly copy-pasted between RMSNorm.backward() and the standalone rms_norm_backward() function. Is it possible to call rms_norm_backward() instead of inlining the same code?

My apologies, I should have caught this earlier.
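A minimal sketch of the deduplication the reviewer is asking for. The names `rms_norm_backward` and the autograd-style `backward` come from the thread; the function bodies here are NumPy stand-ins for the real CuTile launch logic, not the project's implementation:

```python
import numpy as np

def rms_norm_backward(dy, x, w, rstd):
    """Standalone entry point owning the (placeholder) backward math.
    In the real code this is where the CuTile kernel launch lives."""
    xhat = x * rstd[:, None]                              # normalized input
    wdy = w * dy                                          # weight-scaled grad
    c = (xhat * wdy).mean(axis=-1, keepdims=True)         # per-row correction
    dx = (wdy - xhat * c) * rstd[:, None]                 # input gradient
    dw = (dy * xhat).sum(axis=0)                          # weight gradient
    return dx, dw

class RMSNormFn:
    @staticmethod
    def backward(saved, dy):
        x, w, rstd = saved
        # Single call site: delegate instead of re-inlining the launch logic.
        return rms_norm_backward(dy, x, w, rstd)
```

With this shape, any future change to the launch configuration only needs to be made in one place.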
bench_rmsnorm_backward.py still assumes the old M x N temp buffer instead of the partial-sum buffer. Could you please update it to reflect the new partial-sum buffer size? This will ensure the benchmark reports the correct GB/s improvement.
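The sizing change the comment refers to could be sketched as below. The `grid` and `tile_n` values are illustrative, not the kernel's actual configuration:

```python
def old_temp_elems(m, n):
    # Old approach: one partial-dw row per input row -> M x N elements.
    return m * n

def new_temp_elems(grid, tile_n):
    # Persistent kernel: grid x TILE_N partial sums, reduced into dw at the end.
    return grid * tile_n

# e.g. M = N = 4096 with a hypothetical 132-CTA grid:
# the temp buffer shrinks from 4096*4096 to 132*4096 elements,
# which is what the benchmark's bytes-moved accounting should use.
```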
| ) | ||
|
|
||
|
|
||
| @ct.kernel(occupancy=1) |
Per CONTRIBUTING.md, new kernels should carry the @experimental_kernel decorator. Please add it to the new _rms_bwd kernel, as was done in the code you deleted. We will remove the experimental marker once its functional correctness and performance have been fully validated. Thank you for your understanding.
Sorry for the delayed response, as I was out of the office over the past few days. I have added a few comments. Thanks again for your excellent optimization!
Hey @hannahli-nv, I've made the requested changes. Thank you for the thorough review. I saw the updated CONTRIBUTING.md, and I'm excited for the roadmap! Also, I was running this on a consumer Blackwell chip since I don't have a B200 with CUDA 13.1 on hand, so the benchmarks in the PR description might need to be updated due to the change in the partial-sum buffer size calculation.

/ok to test abb1556
…on benchmark

Forward kernels (gather + static persistent) remain unchanged, except the persistent kernel now also stores rstd so backward works from both modes. Backward: replaced the old one-row-per-block approach (M×N temp buffer) with Bastile's grid-stride persistent kernel (grid × TILE_N partial sums for dw).

- Both forward modes now support backward (previously only gather did)
- Removed unused ConstInt/ConstFloat/PAD_ZERO aliases, import math, experimental_kernel
- Added bench_rmsnorm_tilegym_vs_bastile.py comparison benchmark
- All 8 correctness tests pass, benchmark numbers unchanged
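The forward-stores-rstd change could be sketched as follows. This is a NumPy stand-in with an illustrative name, not the CuTile kernel itself:

```python
import numpy as np

def rms_norm_fwd(x, w, eps=1e-6):
    # Also return rstd alongside y, so backward can run from either
    # forward mode instead of only the gather path.
    rstd = 1.0 / np.sqrt((x * x).mean(axis=-1) + eps)
    y = (x * rstd[:, None]) * w
    return y, rstd
```

Saving rstd in forward trades a small amount of memory for skipping the recomputation of the per-row reciprocal standard deviation in backward.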
…cision

The rms_norm_backward_torch reference was computing x*dy in bf16/fp16 before casting to fp32, losing precision. The CuTile kernel correctly operates in fp32 throughout. Fixed the reference to cast to fp32 upfront so both agree. All 13 tests now pass (5 experimental backward + 8 fwd+bwd).
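The precision issue can be reproduced in a few lines of NumPy (a stand-in for the torch reference): multiplying in fp16 rounds every product before the cast, while casting upfront keeps the products exact in fp32.

```python
import numpy as np

n = 4096
x16 = np.full(n, 0.1, dtype=np.float16)   # 0.1 is inexact in fp16
dy16 = np.full(n, 0.1, dtype=np.float16)

# Old reference: multiply in fp16, then cast the rounded products to fp32.
late = (x16 * dy16).astype(np.float32).sum()
# Fixed reference: cast to fp32 upfront, then multiply and sum in fp32.
early = (x16.astype(np.float32) * dy16.astype(np.float32)).sum()
# fp64 ground truth.
exact = (x16.astype(np.float64) * dy16.astype(np.float64)).sum()
```

The fp32-upfront path tracks the fp64 result far more closely, which is why the reference now agrees with the kernel.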
abb1556 to 32c2a51
/ok to test 32c2a51
Description
Replaces the old one-row-per-block RMSNorm backward kernel with a persistent grid-stride kernel that fuses dw accumulation into a compact (grid × TILE_N) partial-sum buffer instead of allocating an M×N temp buffer. I wrote this kernel and was able to get it to exceed the performance of Liger's Triton kernels and get quite close to the performance of Quack's CuteDSL kernel (which I'm somewhat assuming is near peak performance).

Changes:

- _rms_bwd with grid-stride loop and fused dw accumulation
- Persistent forward kernel now also stores rstd so backward works from both modes
- rms_norm_backward_torch now casts to fp32 upfront, matching kernel precision

Backward kernel throughput (GB/s) — standalone, M=4096
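The partial-sum scheme described in the changes can be sketched serially in NumPy. The grid size and the outer loop standing in for parallel CTAs are illustrative, not the kernel's real configuration:

```python
import numpy as np

def dw_two_stage(dy, xhat, grid=8):
    """Each 'CTA' accumulates dw partials over its grid-stride rows; a final
    reduction over the (grid x N) buffer yields dw, replacing the M x N temp."""
    M, N = dy.shape
    partial = np.zeros((grid, N), dtype=np.float32)
    for cta in range(grid):                  # stands in for parallel CTAs
        for row in range(cta, M, grid):      # grid-stride loop over rows
            partial[cta] += dy[row] * xhat[row]
    return partial.sum(axis=0)               # second-stage reduction
```

Because the buffer scales with the grid size rather than with M, the backward pass moves far fewer temp-buffer bytes, which is where the throughput gain comes from.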
CI Configuration

Checklist

- Code formatted (./format.sh)