280 changes: 130 additions & 150 deletions src/tilegym/ops/cutile/rms_norm.py
@@ -2,8 +2,6 @@
#
# SPDX-License-Identifier: MIT

import math

import cuda.tile as ct
import torch
import torch.nn as nn
@@ -14,127 +12,6 @@
from .utils import next_power_of_2


@experimental_kernel
@ct.kernel(occupancy=2)
def rms_norm_backward_kernel(
dx,
dy,
x,
weight,
Rstd,
temp_buffer,
TILE_SIZE: ct.Constant[int],
):
"""
Compute input gradients for RMSNorm backward pass.

Formula: dx_{m,i} = dy_{m,i} w_i / r_m - x_{m,i} / (N r_m^3) * sum_j dy_{m,j} w_j x_{m,j}
where:
- dy_{m,i} = dy[m,i] (upstream gradient)
- w_i = weight[i] (scale parameter)
- r_m = 1 / rstd[m] (RMS for row m)
- N = number of columns

See rms_norm_backward_annotated() for detailed derivation.

Each block handles exactly one row and processes all columns at once.
TILE_SIZE should be >= N (number of columns).
"""
row_idx = ct.bid(0)
M, N = x.shape

# Load entire row from input and gradient
input_row = ct.load(x, index=(row_idx, 0), shape=(1, TILE_SIZE), padding_mode=ct.PaddingMode.ZERO)
gradient_row = ct.load(dy, index=(row_idx, 0), shape=(1, TILE_SIZE), padding_mode=ct.PaddingMode.ZERO)

# Load reciprocal std (1D tensor [M]) and reshape for broadcasting
inv_std_row = ct.load(Rstd, index=(row_idx,), shape=(1,), padding_mode=ct.PaddingMode.ZERO)
inv_std_row = ct.reshape(inv_std_row, (1, 1)) # Reshape to [1, 1] for broadcasting

# Load weight vector and reshape for broadcasting
weight_vector = ct.load(weight, index=(0,), shape=(TILE_SIZE,), padding_mode=ct.PaddingMode.ZERO)
weight_vector = ct.reshape(weight_vector, (1, TILE_SIZE)) # Reshape to [1, TILE_SIZE] for broadcasting

# Compute sum_j dy_{m,j} w_j x_{m,j} for the correction term

c1 = input_row * gradient_row
c2 = c1 * inv_std_row

ct.store(temp_buffer, index=(row_idx, 0), tile=ct.astype(c2, temp_buffer.dtype))

weighted_gradient_product = c1 * weight_vector
weighted_gradient_sum = ct.sum(weighted_gradient_product, axis=1, keepdims=True) # [1, 1]

# Compute normalization correction: x_{m,i} / (N r_m^3) * sum_j dy_{m,j} w_j x_{m,j}
# Since inv_std_row = 1/r_m, we have r_m^3 = 1/(inv_std_row^3)
inv_std_cubed = inv_std_row * inv_std_row * inv_std_row # [1, 1]
norm_factor = ct.full((1, 1), N * 1.0, dtype=ct.float32) # [1, 1]
normalization_correction_coeff = input_row * inv_std_cubed / norm_factor # [1, TILE_SIZE]
normalization_correction = normalization_correction_coeff * weighted_gradient_sum # [1, TILE_SIZE]

# Compute direct term: dy_{m,i} w_i / r_m = gradient_row * weight_vector * inv_std_row
scaled_gradient = gradient_row * weight_vector * inv_std_row # [1, TILE_SIZE]

# Final dx: direct term minus normalization correction
input_gradient_row = scaled_gradient - normalization_correction # [1, TILE_SIZE]

# Convert back to the original dtype of dx
input_gradient_row = ct.astype(input_gradient_row, dx.dtype)

# Store the result back to dx
ct.store(dx, index=(row_idx, 0), tile=input_gradient_row)
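
For reference, the gradient formula in the docstring above follows from differentiating the forward pass y_{m,i} = x_{m,i} w_i / r_m with r_m = sqrt(mean_j x_{m,j}^2 + eps). A compact derivation in the same notation (a math sketch, not code from this change):

% Since eps is constant, \partial r_m / \partial x_{m,i} = x_{m,i} / (N r_m), hence:
\[
dx_{m,i} = \sum_j dy_{m,j} \frac{\partial y_{m,j}}{\partial x_{m,i}}
         = \frac{dy_{m,i}\, w_i}{r_m}
         - \frac{x_{m,i}}{N r_m^{3}} \sum_j dy_{m,j}\, w_j\, x_{m,j},
\qquad
dw_i = \sum_m \frac{dy_{m,i}\, x_{m,i}}{r_m}.
\]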


def rms_norm_backward(
x: torch.Tensor,
dy: torch.Tensor,
weight: torch.Tensor,
rstd: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
x = x.contiguous()
dy = dy.contiguous()
weight = weight.contiguous()
rstd = rstd.contiguous()

x_shape = x.shape

# Flatten to [M, N]
x = x.reshape(-1, x.shape[-1])
dy = dy.reshape(-1, dy.shape[-1])

M, N = x.shape

# Allocate outputs
dx = torch.empty_like(x)
dw = torch.empty_like(weight) # shape (N,)
temp_buffer = torch.empty(x.shape, device=x.device, dtype=torch.float32)

dx = dx.detach()
dw = dw.detach()

TILE_SIZE_N = next_power_of_2(N)

# dx (row-parallel) algorithm
# Also stores dy * x / rms into temp_buffer for each row
grid_dx = (M,)
ct.launch(
torch.cuda.current_stream(),
grid_dx,
rms_norm_backward_kernel,
(dx, dy, x, weight, rstd, temp_buffer, TILE_SIZE_N),
)

# Compute dw by summing temp_buffer over the batch dimension
# temp_buffer contains: dy_{b,j} * x_{b,j} / rms_b (shape [M, N])
# dw_j = sum_b(dy_{b,j} * x_{b,j} / rms_b)
# temp_buffer already holds dy * x * rstd, so we only need to sum over the row dim
# (a torch reduction here performs the same as cuTile); accumulate in float32 to avoid precision issues
dw = temp_buffer[:, :N].to(torch.float32).sum(dim=0).to(weight.dtype)

# Reshape dx back, dw already correct
return dx.view(*x_shape), dw


@ct.kernel
def rms_norm_kernel_gather(
x,
@@ -184,6 +61,7 @@ def rms_norm_kernel_static_persistent(
X, # Input tensor
Y, # Output tensor
W, # Weight tensor
Rstd, # rstd output (for backward)
TILE_SIZE_M: ct.Constant[int], # rows per tile
TILE_SIZE_N: ct.Constant[int], # columns per tile
eps: ct.Constant[float], # Epsilon value
@@ -238,6 +116,9 @@ def rms_norm_kernel_static_persistent(
variance_eps = ct.add(variance, eps_tensor)
rsqrt_var = ct.rsqrt(variance_eps)

# Store rstd for backward pass
ct.store(Rstd, index=(current_bid,), tile=ct.reshape(rsqrt_var, (TILE_SIZE_M,)), allow_tma=False)

# Step 5: Apply normalization
x_normalized = ct.mul(x, rsqrt_var)

@@ -265,6 +146,105 @@ def rms_norm_kernel_static_persistent(
)


@experimental_kernel
@ct.kernel(occupancy=1)
Review comment (Collaborator):
Per CONTRIBUTING.md, new kernels should carry the @experimental_kernel decorator. Please add the @experimental_kernel decorator to the new _rms_bwd kernel as was done with the content you deleted. We will remove the experimental marker once its functional correctness and performance have been fully validated. Thank you for your understanding.

def _rms_bwd(dx, dy, x, weight, Rstd, dw_partial, TILE_M: ct.Constant[int], TILE_N: ct.Constant[int]):
"""
Persistent RMSNorm backward — grid-stride loop with fused dw accumulation.

Each block accumulates its dw contribution into a (grid, TILE_N) partial
sum buffer, avoiding the old M×N temp_buffer allocation.

Only supports offset=0 (Gemma3 backward is not supported).
"""
bid = ct.bid(0)
M, N = x.shape[0], x.shape[1]
blocks = ct.num_blocks(0)
upper = (M + TILE_M - 1) // TILE_M

w = ct.astype(ct.load(weight, index=(0,), shape=(TILE_N,), padding_mode=ct.PaddingMode.ZERO), ct.float32)
w = ct.reshape(w, (1, TILE_N))
rcp = ct.full((TILE_M, 1), 1.0 / N, dtype=ct.float32)
dw_acc = ct.full((1, TILE_N), 0.0, dtype=ct.float32)

for i in range(bid, upper, blocks):
xt = ct.astype(
ct.load(x, index=(i, 0), shape=(TILE_M, TILE_N), padding_mode=ct.PaddingMode.ZERO, latency=10),
ct.float32,
)
dyt = ct.astype(
ct.load(dy, index=(i, 0), shape=(TILE_M, TILE_N), padding_mode=ct.PaddingMode.ZERO, latency=10),
ct.float32,
)
r = ct.reshape(
ct.load(Rstd, index=(i,), shape=(TILE_M,), padding_mode=ct.PaddingMode.ZERO),
(TILE_M, 1),
)
xhat = xt * r
wdy = dyt * w
c = ct.sum(xhat * wdy, axis=1, keepdims=True) * rcp
ct.store(dx, index=(i, 0), tile=ct.astype((wdy - xhat * c) * r, dx.dtype), allow_tma=False, latency=3)
dw_acc = dw_acc + ct.sum(dyt * xhat, axis=0, keepdims=True)

ct.store(dw_partial, index=(bid, 0), tile=dw_acc, allow_tma=False)
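
The fused dw accumulation above replaces the old M×N temp_buffer with a (grid, TILE_N) partial-sum buffer. A minimal PyTorch sketch of the same two-stage reduction (dw_two_stage and the chunking below are illustrative, not part of the kernel API):

import torch

def dw_two_stage(dy: torch.Tensor, x: torch.Tensor, rstd: torch.Tensor, grid: int) -> torch.Tensor:
    # Illustrative host-side mirror of the kernel's reduction strategy.
    M, _ = x.shape
    x_hat = x.float() * rstd.view(M, 1)  # normalized input, as in the kernel
    contrib = dy.float() * x_hat  # per-element dw contribution
    # Stage 1: each "block" reduces its slice of rows to one partial row,
    # as _rms_bwd accumulates dw_acc and stores it into dw_partial.
    partials = torch.stack([c.sum(dim=0) for c in torch.chunk(contrib, grid, dim=0)])
    # Stage 2: reduce the (<= grid, N) partials, as dwp.sum(0) does below.
    return partials.sum(dim=0)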


_bwd_cfg: dict = {} # (M, N) → (tile_m, tile_n, grid, N)


def _bwd_tiles(M, N):
"""Heuristic tile configuration for backward kernel."""
T = next_power_of_2(N)
if T > 4096:
tm = 1
elif T <= 2048 or (M >= 8192 and T <= 4096):
tm = 4
else:
tm = 1
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
tiles = (M + tm - 1) // tm
g = min(NUM_SMS, tiles)
if tiles <= 64:
g = min(g, 32)
return (tm, T, g, N)
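
For example, on a hypothetical 132-SM device, (M, N) = (8192, 4096) gives T = 4096, so the middle branch selects tm = 4; then tiles = 2048 and g = min(132, 2048) = 132 blocks.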


def rms_norm_backward(
x: torch.Tensor,
dy: torch.Tensor,
weight: torch.Tensor,
rstd: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Standalone backward pass using persistent CuTile kernel."""
x = x.contiguous()
dy = dy.contiguous()
weight = weight.contiguous()
rstd = rstd.contiguous()

x_shape = x.shape
x = x.reshape(-1, x.shape[-1])
dy = dy.reshape(-1, dy.shape[-1])
M, N = x.shape

cfg = _bwd_cfg.get((M, N))
if cfg is None:
cfg = _bwd_tiles(M, N)
_bwd_cfg[(M, N)] = cfg
tm, T, g, No = cfg

stream = torch.cuda.current_stream()

dx = torch.empty_like(x)
dwp = torch.empty((g, T), device=x.device, dtype=torch.float32)
ct.launch(stream, (g,), _rms_bwd, (dx, dy, x, weight, rstd, dwp, tm, T))

dw = dwp.sum(0)
if T != No:
dw = dw[:No]

return dx.view(*x_shape), dw.to(weight.dtype)
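
A quick sanity check of this path against autograd might look like the following sketch (shapes, eps, and tolerances are arbitrary assumptions, not values from this change):

import torch

x = torch.randn(1024, 2048, device="cuda", dtype=torch.float16, requires_grad=True)
w = torch.randn(2048, device="cuda", dtype=torch.float16, requires_grad=True)
dy = torch.randn(1024, 2048, device="cuda", dtype=torch.float16)

# Reference gradients via autograd on an fp32 RMSNorm (eps = 1e-6 assumed).
rstd = torch.rsqrt(x.float().pow(2).mean(dim=-1) + 1e-6)
(x.float() * rstd.unsqueeze(-1) * w.float()).backward(dy.float())

dx, dw = rms_norm_backward(x.detach(), dy, w.detach(), rstd.detach())
torch.testing.assert_close(dx, x.grad, rtol=2e-2, atol=2e-2)
torch.testing.assert_close(dw, w.grad, rtol=2e-2, atol=2e-2)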


class RMSNorm(torch.autograd.Function):
@staticmethod
def forward(
Expand Down Expand Up @@ -319,6 +299,9 @@ def forward(
else:
static_persistent = False

# Allocate rstd for backward (both paths now store it)
rstd = torch.empty((M,), dtype=torch.float32, device=x.device)

if static_persistent:
# Static persistent mode
if bias is not None:
@@ -341,27 +324,24 @@ def ceil_div(a, b):
ceil_div(M, TILE_SIZE_M) * ceil_div(N, TILE_SIZE_N),
)
grid = (grid_size,)
kernel_sp = rms_norm_kernel_static_persistent
ct.launch(
torch.cuda.current_stream(),
grid,
kernel_sp,
(x_arg, y, weight, TILE_SIZE_M, TILE_SIZE_N, eps, offset),
rms_norm_kernel_static_persistent,
(x_arg, y, weight, rstd, TILE_SIZE_M, TILE_SIZE_N, eps, offset),
)
else:
# Standard mode
if bias is not None:
raise NotImplementedError("Bias is not supported in standard CuTile RMSNorm")

rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
MAX_FUSED_SIZE = 4096 // x.element_size()
TILE_SIZE = min(MAX_FUSED_SIZE, next_power_of_2(N))
grid = (M,)
kernel = rms_norm_kernel_gather
ct.launch(
torch.cuda.current_stream(),
grid,
kernel,
rms_norm_kernel_gather,
(
x_arg,
weight,
@@ -374,30 +354,30 @@ def ceil_div(a, b):
),
)

# Save variables needed for backward pass
ctx.save_for_backward(x, weight, rstd)
ctx.TILE_SIZE = TILE_SIZE
ctx.eps = eps
ctx.offset = offset
# Always save for backward (both paths now produce rstd)
ctx.save_for_backward(x, weight, rstd)
ctx.TILE_SIZE = next_power_of_2(N)
ctx.eps = eps
ctx.offset = offset

return y.view(*x.shape)

@staticmethod
def backward(ctx, dy):
"""
Backward pass for RMSNorm.
Retrieves saved tensors and delegates to rms_norm_backward().
Persistent backward pass using grid-stride kernel.
Supports backward from both gather and static persistent forward modes.
"""
# Check if offset was used (backward not supported with non-zero offset)
if ctx.offset != 0.0:
raise NotImplementedError("Backward pass not implemented for CuTile RMSNorm with non-zero offset")
raise NotImplementedError(
f"Backward pass not implemented for CuTile RMSNorm with non-zero offset ({ctx.offset})"
)

x, weight, rstd = ctx.saved_tensors

# Call the standalone backward function
dx, dw = rms_norm_backward(x, dy, weight, rstd)

# Return gradients: (x, normalized_shape, weight, eps, bias, static_persistent, offset)
# Gradients: (x, normalized_shape, weight, eps, bias, static_persistent, offset)
return dx, None, dw, None, None, None, None


@@ -501,18 +481,18 @@ def rms_norm_backward_torch(
# Reshape rstd for broadcasting: (M,) -> (M, 1)
rstd = rstd.view(M, 1)

# Gradient w.r.t. weight: sum over batch dimension (accumulate in float32)
# Match kernel order: (x * dy) * rstd to match precision behavior
dw = ((x * dy) * rstd).sum(dim=0, dtype=torch.float32)
# Cast to fp32 up front so all intermediates are full precision
x_f = x.float()
dy_f = dy.float()
w_f = weight.float()

# Normalized x (before scaling by weight) - for dx computation
x_norm = x * rstd
# Gradient w.r.t. weight: dw = sum((x * rstd) * dy, dim=0)
x_norm = x_f * rstd
dw = (dy_f * x_norm).sum(dim=0)

# Gradient w.r.t. x (accumulate in float32)
dy_weighted = dy * weight
c1 = (dy_weighted * x_norm).sum(
dim=1, keepdim=True, dtype=torch.float32
) # ensure accumulates are done in float32 to avoid precision issues
# Gradient w.r.t. x
dy_weighted = dy_f * w_f
c1 = (dy_weighted * x_norm).sum(dim=1, keepdim=True)
dx = rstd * (dy_weighted - x_norm * c1 / N)

dx = dx.view(x_shape).to(x.dtype)
7 changes: 6 additions & 1 deletion tests/benchmark/experimental/bench_rmsnorm_backward.py
@@ -15,6 +15,7 @@

from tilegym.backend import is_backend_available
from tilegym.ops.cutile.rms_norm import TileRMSNorm
from tilegym.ops.cutile.rms_norm import _bwd_tiles
from tilegym.ops.cutile.rms_norm import rms_norm_backward

DEVICE = triton.runtime.driver.active.get_active_torch_device()
@@ -116,7 +117,11 @@ def run_backward():
dx_bytes = x.numel() * bytes_per_element # Write dx
dw_bytes = weight.numel() * bytes_per_element # Write dw

temp_buffer_bytes = x.numel() * 4 * 2 # always write + read float32
if backend == "cutile":
_, tile_n, grid, _ = _bwd_tiles(M, N)
temp_buffer_bytes = grid * tile_n * 4 * 2 # partial-sum buffer: write + read float32
else:
temp_buffer_bytes = 0

total_bytes = input_x_bytes + dy_bytes + weight_bytes + rstd_bytes + dx_bytes + dw_bytes + temp_buffer_bytes
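
For example, with (M, N) = (8192, 4096) and a hypothetical 132-block grid, the partial-sum buffer contributes 132 * 4096 * 8 ≈ 4.3 MB of traffic, versus 8192 * 4096 * 8 ≈ 268 MB for the old full-size float32 temp buffer.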
