Skip to content

Commit 2958804

Browse files
mergennachin authored
and claude committed
Add chunk_gated_delta_rule triton kernel for CUDA backend
Use executor_runner (the portable generic runner) instead of a per-kernel C++ runner. Enable executor_runner in the llm CMake preset so it links the CUDA backend transitively via executorch_backends. Co-authored-by: Claude <noreply@anthropic.com>
1 parent 7134708 commit 2958804

File tree

5 files changed

+623
-1
lines changed

5 files changed

+623
-1
lines changed

.github/workflows/cuda.yml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,13 @@ jobs:
126126
cmake --workflow --preset default
127127
popd
128128
129-
# Run CUDA backend Python tests, overrides addopts so that we don't run all tests in pytest.ini
129+
# Install flash-linear-attention for chunk_gated_delta_rule triton kernel tests
130+
pip install "flash-linear-attention==0.4.2"
131+
132+
# Build executor_runner (needed by CUDA backend e2e tests)
133+
cmake --build cmake-out --target executor_runner
134+
135+
# Run all CUDA backend Python tests (including chunk_gated_delta e2e)
130136
python -m pytest backends/cuda/tests backends/cuda/passes/tests -v -o "addopts="
131137
132138
export-model-cuda-artifact:
Lines changed: 335 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,335 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Export and validate chunk_gated_delta_rule triton kernel on CUDA backend.
9+
10+
Requires: pip install flash-linear-attention
11+
12+
Usage:
13+
python -m pytest backends/cuda/tests/test_chunk_gated_delta_rule.py -v
14+
15+
# Standalone export (produces .pte + .ptd):
16+
python backends/cuda/tests/test_chunk_gated_delta_rule.py --output-dir /tmp/exports
17+
"""
18+
19+
import argparse
20+
import os
21+
import subprocess
22+
import sys
23+
import tempfile
24+
import unittest
25+
26+
import numpy as np
27+
import torch
28+
import torch.nn.functional as F
29+
from torch.export import export
30+
31+
try:
32+
import fla # noqa: F401
33+
34+
HAS_FLA = True
35+
except ImportError:
36+
HAS_FLA = False
37+
38+
if HAS_FLA:
39+
import executorch.backends.cuda.triton.kernels.chunk_gated_delta_rule # noqa: F401
40+
41+
from executorch.backends.cuda.cuda_backend import CudaBackend
42+
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
43+
from executorch.exir import (
44+
EdgeCompileConfig,
45+
ExecutorchBackendConfig,
46+
to_edge_transform_and_lower,
47+
)
48+
from executorch.exir.passes import MemoryPlanningPass
49+
50+
51+
# Tensor sizes shared by every test: q/k are (B, T, H, K), v is (B, T, H, V),
# g and beta are (B, T, H), and the recurrent state is (B, H, K, V).
B, T, H, K, V = 1, 128, 4, 64, 64

# Repo root (three directory levels above this test file) and the portable
# executor_runner binary expected from the default cmake-out build tree.
EXECUTORCH_ROOT = os.path.normpath(os.path.join(os.path.dirname(__file__), "../../.."))
RUNNER_PATH = os.path.join(EXECUTORCH_ROOT, "cmake-out", "executor_runner")

# Test configurations adapted from FLA's test_gated_delta.py test_chunk()
# Format: (seed, gate_logit_normalizer, mask_p, nonzero_h0, description)
# - seed: torch.manual_seed value for reproducible inputs
# - gate_logit_normalizer: divisor applied to the log-sigmoid gates
# - mask_p: fraction of gates randomly zeroed (0.0 = no masking)
# - nonzero_h0: whether the initial recurrent state is random vs. zeros
FLA_TEST_CONFIGS = [
    # Basic configs varying gate normalizer
    (42, 1.0, 0.0, False, "basic_norm1"),
    (123, 0.1, 0.0, False, "strong_gate"),
    (7, 10.0, 0.0, False, "weak_gate"),
    # Non-zero initial state
    (42, 1.0, 0.0, True, "nonzero_h0_norm1"),
    (99, 0.1, 0.0, True, "nonzero_h0_strong"),
    (55, 10.0, 0.0, True, "nonzero_h0_weak"),
    # Sparse gating (50% of gates masked to zero)
    (42, 1.0, 0.5, False, "sparse_gate_50pct"),
    (77, 0.1, 0.5, True, "sparse_strong_h0"),
    # Different random patterns
    (0, 1.0, 0.0, False, "seed0"),
    (100, 1.0, 0.0, True, "seed100_h0"),
    (2024, 0.5, 0.0, False, "norm0.5"),
    (999, 5.0, 0.3, True, "norm5_sparse30_h0"),
    # Edge-ish values
    (13, 0.01, 0.0, False, "very_strong_gate"),
    (31, 100.0, 0.0, False, "very_weak_gate"),
    (64, 1.0, 0.9, True, "sparse_90pct_h0"),
]
80+
81+
82+
class ChunkGatedDeltaModel(torch.nn.Module):
    """Minimal module wrapping the ``triton.chunk_gated_delta_rule`` op.

    L2-normalizes q and k over the last (feature) dimension before calling
    the op; returns the op's output and final recurrent state unchanged.
    """

    def forward(self, q, k, v, g, beta, initial_state):
        out, state = torch.ops.triton.chunk_gated_delta_rule(
            F.normalize(q, p=2, dim=-1),
            F.normalize(k, p=2, dim=-1),
            v,
            g,
            beta,
            initial_state,
        )
        return out, state
90+
91+
92+
def _make_inputs_from_fla(
93+
seed,
94+
gate_logit_normalizer,
95+
mask_p=0.0,
96+
nonzero_h0=False,
97+
dtype=torch.bfloat16,
98+
device="cuda",
99+
):
100+
"""Generate inputs following FLA test_chunk() conventions."""
101+
torch.manual_seed(seed)
102+
q = torch.rand(B, T, H, K, dtype=dtype, device=device)
103+
k = torch.rand(B, T, H, K, dtype=dtype, device=device)
104+
v = torch.rand(B, T, H, V, dtype=dtype, device=device)
105+
beta = torch.rand(B, T, H, dtype=torch.float32, device=device).sigmoid().to(dtype)
106+
g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.float32, device=device))
107+
g = (g / gate_logit_normalizer).to(dtype)
108+
if mask_p > 0:
109+
g = g * (torch.rand(B, T, H, dtype=dtype, device=device) > mask_p)
110+
if nonzero_h0:
111+
h0 = torch.randn(B, H, K, V, dtype=dtype, device=device)
112+
else:
113+
h0 = torch.zeros(B, H, K, V, dtype=dtype, device=device)
114+
return q, k, v, g, beta, h0
115+
116+
117+
def _make_inputs(dtype=torch.bfloat16, device="cuda"):
118+
q = torch.randn(B, T, H, K, dtype=dtype, device=device)
119+
k = torch.randn(B, T, H, K, dtype=dtype, device=device)
120+
v = torch.randn(B, T, H, V, dtype=dtype, device=device)
121+
g = F.logsigmoid(torch.randn(B, T, H, dtype=dtype, device=device))
122+
beta = torch.rand(B, T, H, dtype=dtype, device=device).sigmoid()
123+
initial_state = torch.randn(B, H, K, V, dtype=dtype, device=device)
124+
return q, k, v, g, beta, initial_state
125+
126+
127+
def _save_tensor(t, path):
128+
t_cpu = t.cpu().contiguous()
129+
with open(path, "wb") as f:
130+
f.write(bytes(t_cpu.untyped_storage()))
131+
132+
133+
def _load_output(path, shape, dtype):
134+
data = np.fromfile(path, dtype=np.uint8)
135+
return torch.frombuffer(bytearray(data), dtype=dtype).reshape(shape)
136+
137+
138+
def export_chunk_gated_delta(output_dir):
    """Export ChunkGatedDeltaModel to a CUDA-lowered .pte in *output_dir*.

    Runs the model eagerly once (shape sanity check), torch.export's it,
    lowers it through the CUDA partitioner, and serializes the program.
    Returns the path of the written .pte file.
    """
    model = ChunkGatedDeltaModel().eval()
    example_inputs = _make_inputs()

    with torch.no_grad():
        ref_o, ref_s = model(*example_inputs)
    print(f"Eager output shape: {ref_o.shape}, final_state shape: {ref_s.shape}")

    with torch.no_grad():
        exported = export(model, example_inputs, strict=True)
    print("Export OK")

    os.makedirs(output_dir, exist_ok=True)

    # Lower the exported program to the CUDA backend via the partitioner.
    method_specs = [CudaBackend.generate_method_name_compile_spec("forward")]
    lowered = to_edge_transform_and_lower(
        exported,
        partitioner=[CudaPartitioner(method_specs)],
        compile_config=EdgeCompileConfig(
            _check_ir_validity=False, _skip_dim_order=True
        ),
    )
    program = lowered.to_executorch(
        config=ExecutorchBackendConfig(
            extract_delegate_segments=True,
            do_quant_fusion_and_const_prop=True,
            memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
        ),
    )

    pte_path = os.path.join(output_dir, "chunk_gated_delta.pte")
    with open(pte_path, "wb") as f:
        program.write_to_file(f)

    # External tensor data (if any) goes into a sidecar file next to the .pte.
    if getattr(program, "_tensor_data", None):
        program.write_tensor_data_to_file(output_dir)

    print(f"Saved to {pte_path} ({os.path.getsize(pte_path) / 1024:.0f} KB)")
    return pte_path
177+
178+
179+
def _run_cpp_runner(runner_path, pte_path, ptd_path, input_files, output_base):
180+
"""Run executor_runner and return subprocess result."""
181+
cmd = [
182+
runner_path,
183+
f"--model_path={pte_path}",
184+
f"--data_path={ptd_path}",
185+
f"--inputs={','.join(input_files)}",
186+
f"--output_file={output_base}",
187+
]
188+
result = subprocess.run(cmd, capture_output=True, text=True)
189+
return result
190+
191+
192+
@unittest.skipUnless(HAS_FLA, "flash-linear-attention not installed")
class TestChunkGatedDeltaRule(unittest.TestCase):
    """Eager, export, and end-to-end (executor_runner) tests for the op.

    Fix: in ``test_e2e_cpp_runner`` the docstring previously sat as a dead
    string expression *after* the first assertion; it is now the method
    docstring, so test runners report it correctly.
    """

    def setUp(self):
        # Every test needs a GPU: the op is a triton kernel and the export
        # path targets the CUDA backend.
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available")

    def test_eager(self):
        """Eager smoke test: output/state shapes and dtypes."""
        model = ChunkGatedDeltaModel().eval()
        inputs = _make_inputs()
        with torch.no_grad():
            o, s = model(*inputs)
        self.assertEqual(o.shape, torch.Size([B, T, H, V]))
        self.assertEqual(s.shape, torch.Size([B, H, K, V]))
        self.assertEqual(o.dtype, torch.bfloat16)
        # Final state is expected in fp32 even with bf16 inputs.
        self.assertEqual(s.dtype, torch.float32)

    def test_eager_fla_configs(self):
        """Run FLA-style test configurations and verify against naive reference."""
        from fla.ops.gated_delta_rule.naive import naive_recurrent_gated_delta_rule

        model = ChunkGatedDeltaModel().eval()
        for seed, norm, mask_p, nonzero_h0, desc in FLA_TEST_CONFIGS:
            with self.subTest(desc=desc):
                inputs = _make_inputs_from_fla(seed, norm, mask_p, nonzero_h0)
                q, k, v, g, beta, h0 = inputs

                with torch.no_grad():
                    o_ours, s_ours = model(q, k, v, g, beta, h0)

                # The model normalizes q/k internally, so normalize here too
                # before handing them to the naive reference.
                o_ref, s_ref = naive_recurrent_gated_delta_rule(
                    q=F.normalize(q, p=2, dim=-1),
                    k=F.normalize(k, p=2, dim=-1),
                    v=v,
                    beta=beta,
                    g=g,
                    initial_state=h0,
                    output_final_state=True,
                )

                o_diff = (o_ours.float() - o_ref.float()).abs().max().item()
                s_diff = (s_ours.float() - s_ref.float()).abs().max().item()
                self.assertLess(o_diff, 0.01, f"{desc}: output diff {o_diff}")
                self.assertLess(s_diff, 0.01, f"{desc}: state diff {s_diff}")

    def test_eager_matches_fla(self):
        """Compare the registered op against FLA's own chunked implementation."""
        from fla.ops.gated_delta_rule import chunk_gated_delta_rule as fla_impl

        torch.manual_seed(42)
        inputs = _make_inputs()
        q, k, v, g, beta, h0 = inputs

        q_norm = F.normalize(q, p=2, dim=-1)
        k_norm = F.normalize(k, p=2, dim=-1)
        with torch.no_grad():
            o_ours, _ = torch.ops.triton.chunk_gated_delta_rule(
                q_norm, k_norm, v, g, beta, h0
            )
            # FLA normalizes q/k in-kernel when use_qk_l2norm_in_kernel=True,
            # so it receives the raw (unnormalized) q/k here.
            o_ref, _ = fla_impl(
                q,
                k,
                v,
                g,
                beta,
                initial_state=h0,
                output_final_state=True,
                use_qk_l2norm_in_kernel=True,
            )

        self.assertLess((o_ours.float() - o_ref.float()).abs().max().item(), 0.01)

    def test_export_cuda(self):
        """Export to .pte and check the artifact exists and is non-empty."""
        with tempfile.TemporaryDirectory() as tmpdir:
            pte_path = export_chunk_gated_delta(tmpdir)
            self.assertTrue(os.path.exists(pte_path))
            self.assertGreater(os.path.getsize(pte_path), 0)

    def test_e2e_cpp_runner(self):
        """Export, run executor_runner with FLA test inputs, compare with eager."""
        self.assertTrue(
            os.path.exists(RUNNER_PATH),
            f"executor_runner not found at {RUNNER_PATH}. "
            "Build with: cmake --build cmake-out --target executor_runner",
        )

        model = ChunkGatedDeltaModel().eval()

        with tempfile.TemporaryDirectory() as tmpdir:
            export_dir = os.path.join(tmpdir, "export")
            pte_path = export_chunk_gated_delta(export_dir)
            # NOTE(review): sidecar name assumed to match what the CUDA
            # backend emits for external tensor data — confirm.
            ptd_path = os.path.join(export_dir, "aoti_cuda_blob.ptd")

            for seed, norm, mask_p, nonzero_h0, desc in FLA_TEST_CONFIGS:
                with self.subTest(desc=desc):
                    inputs = _make_inputs_from_fla(seed, norm, mask_p, nonzero_h0)
                    q, k, v, g, beta, h0 = inputs

                    with torch.no_grad():
                        ref_o, ref_s = model(q, k, v, g, beta, h0)

                    run_dir = os.path.join(tmpdir, f"run_{desc}")
                    os.makedirs(run_dir)

                    # Serialize each input tensor as raw bytes for the runner.
                    input_files = []
                    for i, tensor in enumerate(inputs):
                        path = os.path.join(run_dir, f"{i}.bin")
                        _save_tensor(tensor, path)
                        input_files.append(path)

                    output_base = os.path.join(run_dir, "output")
                    result = _run_cpp_runner(
                        RUNNER_PATH, pte_path, ptd_path, input_files, output_base
                    )
                    self.assertEqual(
                        result.returncode,
                        0,
                        f"{desc}: executor_runner failed:\n{result.stderr}",
                    )

                    # The runner writes one file per output: <base>-<idx>.bin.
                    cpp_o = _load_output(
                        f"{output_base}-0.bin",
                        (B, T, H, V),
                        torch.bfloat16,
                    )
                    cpp_s = _load_output(
                        f"{output_base}-1.bin",
                        (B, H, K, V),
                        torch.float32,
                    )

                    o_diff = (cpp_o.float() - ref_o.cpu().float()).abs().max().item()
                    s_diff = (cpp_s.float() - ref_s.cpu().float()).abs().max().item()
                    self.assertLess(o_diff, 0.01, f"{desc}: output diff {o_diff}")
                    # State tolerance is deliberately looser than output's.
                    self.assertLess(s_diff, 0.1, f"{desc}: state diff {s_diff}")
324+
325+
326+
if __name__ == "__main__":
    # --output-dir switches to standalone export mode; otherwise defer to
    # unittest, forwarding any unrecognized CLI arguments to it.
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", default=None)
    known, passthrough = cli.parse_known_args()

    if known.output_dir:
        export_chunk_gated_delta(known.output_dir)
    else:
        sys.argv = sys.argv[:1] + passthrough
        unittest.main()

backends/cuda/triton/kernels/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,12 @@
99
__all__ = [
1010
"sdpa",
1111
]
12+
13+
try:
    # Optional kernel: importing the module registers chunk_gated_delta_rule.
    # NOTE(review): presumably the import fails when its dependencies (e.g.
    # flash-linear-attention) are missing — confirm; in that case the kernel
    # is simply left out of the public API.
    from executorch.backends.cuda.triton.kernels.chunk_gated_delta_rule import (  # noqa: F401
        chunk_gated_delta_rule,
    )

    __all__.append("chunk_gated_delta_rule")
except ImportError:
    pass

0 commit comments

Comments
 (0)