
feat: replace RMSNorm backward with persistent CuTile kernel #60

Merged
hannahli-nv merged 6 commits into NVIDIA:main from aghilann:benchmark-rmsnorm-comparison
Feb 21, 2026

Conversation

@aghilann (Contributor) commented Feb 15, 2026

Description

Replaces the old one-row-per-block RMSNorm backward kernel with a persistent grid-stride kernel that fuses the dw accumulation into a compact (grid × TILE_N) partial-sum buffer instead of allocating an M×N temp buffer. This kernel exceeds the performance of Liger's Triton kernels and comes quite close to Quack's CuTeDSL kernel (which I assume is near peak performance).
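A minimal plain-Python sketch of the partial-sum scheme described above (the grid size G, names, and serial loops are illustrative only; the real kernel runs the G persistent programs in parallel on the GPU):

```python
import numpy as np

def dw_two_stage(xhat, dy, G=4):
    """Two-stage dw reduction: each of G persistent 'blocks' grid-strides
    over rows and accumulates dy * xhat into its private row of a (G, N)
    buffer; a final sum over the G rows yields dw. Scratch memory is
    G * N instead of M * N."""
    M, N = xhat.shape
    partial = np.zeros((G, N), dtype=np.float32)
    for b in range(G):                 # one iteration per persistent block
        for row in range(b, M, G):     # grid-stride loop over assigned rows
            partial[b] += dy[row] * xhat[row]
    return partial.sum(axis=0)         # stage 2: reduce partials to dw
```

The two-stage shape is what lets the kernel avoid the old M×N temp buffer while keeping the per-block accumulation race-free.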

Changes:

  • Backward kernel: New persistent _rms_bwd with grid-stride loop and fused dw accumulation
  • Forward kernels: Unchanged (gather + static persistent), except persistent now also stores rstd so backward works from both modes
  • Reference fix: rms_norm_backward_torch now casts to fp32 upfront, matching kernel precision
  • Removed unused aliases and imports
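The reference-precision fix can be illustrated with a hedged NumPy sketch of the RMSNorm backward math with the fp32 upcast applied up front (the function name, signature, and eps are illustrative, not the repo's actual rms_norm_backward_torch):

```python
import numpy as np

def rms_norm_bwd_ref(x, dy, w, eps=1e-6):
    # Cast to fp32 before any multiplication, so low-precision inputs
    # (bf16/fp16) don't lose bits in intermediate products, matching a
    # kernel that accumulates in fp32 throughout.
    x32, dy32, w32 = (np.asarray(t, dtype=np.float32) for t in (x, dy, w))
    N = x32.shape[-1]
    rstd = 1.0 / np.sqrt(np.mean(x32 * x32, axis=-1, keepdims=True) + eps)
    xhat = x32 * rstd                  # normalized input
    dxhat = dy32 * w32                 # gradient w.r.t. xhat
    dw = np.sum(dy32 * xhat, axis=0)   # column-sum over rows
    dx = rstd * (dxhat - xhat * np.mean(dxhat * xhat, axis=-1, keepdims=True))
    return dx, dw
```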

Backward kernel throughput (GB/s) — standalone, M=4096

| dtype | N | Old Version | New Version | PyTorch | Old→New | vs PyTorch |
|---|---|---|---|---|---|---|
| bf16 | 1024 | 1,534 | 3,020 | 482 | 2.0x | 6.3x |
| bf16 | 2048 | 2,274 | 6,709 | 549 | 2.9x | 12.2x |
| bf16 | 4096 | 2,823 | 6,100 | 538 | 2.2x | 11.3x |
| bf16 | 8192 | 3,652 | 8,008 | 552 | 2.2x | 14.5x |
| bf16 | 16384 | 3,762 | 4,135 | 573 | 1.1x | 7.2x |
| fp16 | 1024 | 1,561 | 3,197 | 483 | 2.0x | 6.6x |
| fp16 | 2048 | 2,368 | 4,454 | 553 | 1.9x | 8.1x |
| fp16 | 4096 | 2,933 | 6,131 | 540 | 2.1x | 11.4x |
| fp16 | 8192 | 3,578 | 7,987 | 553 | 2.2x | 14.4x |
| fp16 | 16384 | 4,020 | 4,332 | 574 | 1.1x | 7.5x |
| fp32 | 1024 | 2,268 | 4,940 | 955 | 2.2x | 5.2x |
| fp32 | 2048 | 2,923 | 7,357 | 1,099 | 2.5x | 6.7x |
| fp32 | 4096 | 3,901 | 7,140 | 1,030 | 1.8x | 6.9x |
| fp32 | 8192 | 3,667 | 9,298 | 1,051 | 2.5x | 8.8x |
| fp32 | 16384 | 2,634 | 6,943 | 1,090 | 2.6x | 6.4x |

CI Configuration

```yaml
config:
  build: true
  # valid options are "ops" and "benchmark"
  test: ["ops", "benchmark"]
```

Checklist

  • Code formatted and imports sorted via repo specifications (./format.sh)
  • Documentation updated (if needed)
  • CI configuration reviewed

copy-pr-bot (bot) commented Feb 15, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@aghilann aghilann changed the title feat: re-write RMSNorm backward kernel for greatly improved performance feat: replace RMSNorm backward with persistent CuTile kernel Feb 15, 2026
@aghilann (Author) commented Feb 15, 2026

Could I get a review when you get a chance, @hannahli-nv?

@xjmxyt (Collaborator) commented Feb 17, 2026

/ok to test 4d947d3

@aghilann (Author):

@hannahli-nv Sorry to double ping; I wanted to get this reviewed soon since I have a few more on the way with good performance :)

@hannahli-nv (Collaborator):

The existing Test_RMSNormBackward tests the standalone rms_norm_backward() with pytorch-computed rstd. It does not exercise the RMSNorm.backward() autograd path, nor does it verify that the forward kernels store rstd correctly. Could you please add a Test_RMSNormAutogradBackward class in test_rmsnorm_backward.py that runs rms_norm() forward with requires_grad=True inputs and validates gradients via autograd?
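A sketch of what such a test could look like, with a pure-torch RMSNorm standing in for the package's rms_norm() (the real test would call the CuTile-backed op on GPU tensors; all names and tolerances here are illustrative):

```python
import torch

def rms_norm_ref(x, w, eps=1e-6):
    # Stand-in for the package's rms_norm(); differentiable, so autograd
    # exercises the full forward + backward path.
    rstd = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return x * rstd * w

def test_autograd_backward():
    torch.manual_seed(0)
    x = torch.randn(16, 64, requires_grad=True)
    w = torch.randn(64, requires_grad=True)
    dy = torch.randn(16, 64)
    y = rms_norm_ref(x, w)          # real test: rms_norm(x, w)
    y.backward(dy)                  # drives the autograd backward path
    # Compare autograd results against the closed-form gradients.
    with torch.no_grad():
        rstd = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
        xhat = x * rstd
        dxhat = dy * w
        dx = rstd * (dxhat - xhat * (dxhat * xhat).mean(-1, keepdim=True))
        dw = (dy * xhat).sum(0)
    assert torch.allclose(x.grad, dx, atol=1e-5)
    assert torch.allclose(w.grad, dw, atol=1e-5)
```

Running the same check once per forward mode (gather and static persistent) would also verify that both modes store rstd correctly.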

```python
if not dy2.is_contiguous():
    dy2 = dy2.contiguous()

cfg = _bwd_cfg.get((M, N))
```
Collaborator:

It seems that the kernel launch logic was mostly copy-pasted between RMSNorm.backward() and the standalone rms_norm_backward() function. Is it possible to call rms_norm_backward() instead of inlining the same codes?

Contributor (Author):

My apologies, I should have caught this earlier.

@hannahli-nv (Collaborator):

bench_rmsnorm_backward.py still assumes the old M × N temp buffer instead of the partial-sum buffer. Could you please update it to reflect the new partial-sum buffer size, so that the reported GB/s improvement is computed correctly?
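One way the benchmark's byte accounting could change, purely as an illustration (the formula and parameter names are assumptions for this sketch, not the actual bench_rmsnorm_backward.py code):

```python
def bwd_bytes(M, N, elem_size, grid, tile_n, acc_size=4):
    """Return (old_total, new_total) bytes moved per backward pass.
    Assumed main I/O: read x and dy, write dx (M*N each), plus read w
    and write dw (N each). Only the scratch-buffer term changes."""
    io = 3 * M * N * elem_size + 2 * N * elem_size
    old_temp = M * N * acc_size          # old: full M x N fp32 scratch
    new_temp = grid * tile_n * acc_size  # new: (grid x TILE_N) partial sums
    return io + old_temp, io + new_temp
```

With the old accounting the scratch term grows with M, which inflates the byte count and therefore the GB/s figure relative to the new kernel's actual traffic.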

```python
)


@ct.kernel(occupancy=1)
```
Collaborator:

Per CONTRIBUTING.md, new kernels should carry the @experimental_kernel decorator. Please add it to the new _rms_bwd kernel, as the deleted kernel had it. We will remove the experimental marker once its functional correctness and performance have been fully validated. Thank you for your understanding.

@hannahli-nv (Collaborator):

> @hannahli-nv Sorry to double ping, wanted to get this reviewed soon since I have a few more on the way with good performance :)

Sorry for the delayed response, as I was out of the office over the past few days. I have added a few comments. Thanks again for your excellent optimization!

@aghilann (Author) commented Feb 21, 2026

Hey @hannahli-nv, I've made the requested changes. Thank you for the thorough review. I saw the updated CONTRIBUTING.md; I'm excited for the roadmap! Also, I ran this on a consumer Blackwell chip since I don't have a B200 with CUDA 13.1 on hand, so the benchmarks in the PR description may need to be updated after the change to the partial-buffer size calculation.

@aghilann aghilann requested a review from hannahli-nv February 21, 2026 07:20
@hannahli-nv (Collaborator):

/ok to test abb1556

root added 6 commits February 21, 2026 17:37
…on benchmark

Forward kernels (gather + static persistent) remain unchanged except
the persistent kernel now also stores rstd so backward works from both modes.

Backward: replaced old one-row-per-block approach (M×N temp buffer) with
Bastile's grid-stride persistent kernel (grid × TILE_N partial sums for dw).

- Both forward modes now support backward (previously only gather did)
- Removed unused ConstInt/ConstFloat/PAD_ZERO aliases, import math, experimental_kernel
- Added bench_rmsnorm_tilegym_vs_bastile.py comparison benchmark
- All 8 correctness tests pass, benchmark numbers unchanged
…cision

The rms_norm_backward_torch reference was computing x*dy in bf16/fp16
before casting to fp32, losing precision. The CuTile kernel correctly
operates in fp32 throughout. Fixed reference to cast to fp32 upfront
so both agree. All 13 tests now pass (5 experimental backward + 8 fwd+bwd).
@hannahli-nv force-pushed the benchmark-rmsnorm-comparison branch from abb1556 to 32c2a51 (February 21, 2026 09:37)
@hannahli-nv (Collaborator):

/ok to test 32c2a51

@hannahli-nv (Collaborator) left a review:

LGTM, thx!

@hannahli-nv hannahli-nv merged commit 932d623 into NVIDIA:main Feb 21, 2026
18 checks passed