|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +""" |
| 8 | +Export and validate chunk_gated_delta_rule triton kernel on CUDA backend. |
| 9 | +
|
| 10 | +Requires: pip install flash-linear-attention |
| 11 | +
|
| 12 | +Usage: |
| 13 | + python -m pytest backends/cuda/tests/test_chunk_gated_delta_rule.py -v |
| 14 | +
|
| 15 | + # Standalone export (produces .pte + .ptd): |
| 16 | + python backends/cuda/tests/test_chunk_gated_delta_rule.py --output-dir /tmp/exports |
| 17 | +""" |
| 18 | + |
| 19 | +import argparse |
| 20 | +import os |
| 21 | +import subprocess |
| 22 | +import sys |
| 23 | +import tempfile |
| 24 | +import unittest |
| 25 | + |
| 26 | +import numpy as np |
| 27 | +import torch |
| 28 | +import torch.nn.functional as F |
| 29 | +from torch.export import export |
| 30 | + |
| 31 | +try: |
| 32 | + import fla # noqa: F401 |
| 33 | + |
| 34 | + HAS_FLA = True |
| 35 | +except ImportError: |
| 36 | + HAS_FLA = False |
| 37 | + |
| 38 | +if HAS_FLA: |
| 39 | + import executorch.backends.cuda.triton.kernels.chunk_gated_delta_rule # noqa: F401 |
| 40 | + |
| 41 | +from executorch.backends.cuda.cuda_backend import CudaBackend |
| 42 | +from executorch.backends.cuda.cuda_partitioner import CudaPartitioner |
| 43 | +from executorch.exir import ( |
| 44 | + EdgeCompileConfig, |
| 45 | + ExecutorchBackendConfig, |
| 46 | + to_edge_transform_and_lower, |
| 47 | +) |
| 48 | +from executorch.exir.passes import MemoryPlanningPass |
| 49 | + |
| 50 | + |
| 51 | +B, T, H, K, V = 1, 128, 4, 64, 64 |
| 52 | + |
| 53 | +EXECUTORCH_ROOT = os.path.normpath(os.path.join(os.path.dirname(__file__), "../../..")) |
| 54 | +RUNNER_PATH = os.path.join(EXECUTORCH_ROOT, "cmake-out", "executor_runner") |
| 55 | + |
| 56 | +# Test configurations adapted from FLA's test_gated_delta.py test_chunk() |
| 57 | +# Format: (seed, gate_logit_normalizer, mask_p, nonzero_h0, description) |
| 58 | +FLA_TEST_CONFIGS = [ |
| 59 | + # Basic configs varying gate normalizer |
| 60 | + (42, 1.0, 0.0, False, "basic_norm1"), |
| 61 | + (123, 0.1, 0.0, False, "strong_gate"), |
| 62 | + (7, 10.0, 0.0, False, "weak_gate"), |
| 63 | + # Non-zero initial state |
| 64 | + (42, 1.0, 0.0, True, "nonzero_h0_norm1"), |
| 65 | + (99, 0.1, 0.0, True, "nonzero_h0_strong"), |
| 66 | + (55, 10.0, 0.0, True, "nonzero_h0_weak"), |
| 67 | + # Sparse gating (50% of gates masked to zero) |
| 68 | + (42, 1.0, 0.5, False, "sparse_gate_50pct"), |
| 69 | + (77, 0.1, 0.5, True, "sparse_strong_h0"), |
| 70 | + # Different random patterns |
| 71 | + (0, 1.0, 0.0, False, "seed0"), |
| 72 | + (100, 1.0, 0.0, True, "seed100_h0"), |
| 73 | + (2024, 0.5, 0.0, False, "norm0.5"), |
| 74 | + (999, 5.0, 0.3, True, "norm5_sparse30_h0"), |
| 75 | + # Edge-ish values |
| 76 | + (13, 0.01, 0.0, False, "very_strong_gate"), |
| 77 | + (31, 100.0, 0.0, False, "very_weak_gate"), |
| 78 | + (64, 1.0, 0.9, True, "sparse_90pct_h0"), |
| 79 | +] |
| 80 | + |
| 81 | + |
| 82 | +class ChunkGatedDeltaModel(torch.nn.Module): |
| 83 | + def forward(self, q, k, v, g, beta, initial_state): |
| 84 | + q = F.normalize(q, p=2, dim=-1) |
| 85 | + k = F.normalize(k, p=2, dim=-1) |
| 86 | + o, final_state = torch.ops.triton.chunk_gated_delta_rule( |
| 87 | + q, k, v, g, beta, initial_state |
| 88 | + ) |
| 89 | + return o, final_state |
| 90 | + |
| 91 | + |
| 92 | +def _make_inputs_from_fla( |
| 93 | + seed, |
| 94 | + gate_logit_normalizer, |
| 95 | + mask_p=0.0, |
| 96 | + nonzero_h0=False, |
| 97 | + dtype=torch.bfloat16, |
| 98 | + device="cuda", |
| 99 | +): |
| 100 | + """Generate inputs following FLA test_chunk() conventions.""" |
| 101 | + torch.manual_seed(seed) |
| 102 | + q = torch.rand(B, T, H, K, dtype=dtype, device=device) |
| 103 | + k = torch.rand(B, T, H, K, dtype=dtype, device=device) |
| 104 | + v = torch.rand(B, T, H, V, dtype=dtype, device=device) |
| 105 | + beta = torch.rand(B, T, H, dtype=torch.float32, device=device).sigmoid().to(dtype) |
| 106 | + g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.float32, device=device)) |
| 107 | + g = (g / gate_logit_normalizer).to(dtype) |
| 108 | + if mask_p > 0: |
| 109 | + g = g * (torch.rand(B, T, H, dtype=dtype, device=device) > mask_p) |
| 110 | + if nonzero_h0: |
| 111 | + h0 = torch.randn(B, H, K, V, dtype=dtype, device=device) |
| 112 | + else: |
| 113 | + h0 = torch.zeros(B, H, K, V, dtype=dtype, device=device) |
| 114 | + return q, k, v, g, beta, h0 |
| 115 | + |
| 116 | + |
| 117 | +def _make_inputs(dtype=torch.bfloat16, device="cuda"): |
| 118 | + q = torch.randn(B, T, H, K, dtype=dtype, device=device) |
| 119 | + k = torch.randn(B, T, H, K, dtype=dtype, device=device) |
| 120 | + v = torch.randn(B, T, H, V, dtype=dtype, device=device) |
| 121 | + g = F.logsigmoid(torch.randn(B, T, H, dtype=dtype, device=device)) |
| 122 | + beta = torch.rand(B, T, H, dtype=dtype, device=device).sigmoid() |
| 123 | + initial_state = torch.randn(B, H, K, V, dtype=dtype, device=device) |
| 124 | + return q, k, v, g, beta, initial_state |
| 125 | + |
| 126 | + |
| 127 | +def _save_tensor(t, path): |
| 128 | + t_cpu = t.cpu().contiguous() |
| 129 | + with open(path, "wb") as f: |
| 130 | + f.write(bytes(t_cpu.untyped_storage())) |
| 131 | + |
| 132 | + |
| 133 | +def _load_output(path, shape, dtype): |
| 134 | + data = np.fromfile(path, dtype=np.uint8) |
| 135 | + return torch.frombuffer(bytearray(data), dtype=dtype).reshape(shape) |
| 136 | + |
| 137 | + |
| 138 | +def export_chunk_gated_delta(output_dir): |
| 139 | + model = ChunkGatedDeltaModel().eval() |
| 140 | + inputs = _make_inputs() |
| 141 | + |
| 142 | + with torch.no_grad(): |
| 143 | + ref_o, ref_s = model(*inputs) |
| 144 | + print(f"Eager output shape: {ref_o.shape}, final_state shape: {ref_s.shape}") |
| 145 | + |
| 146 | + with torch.no_grad(): |
| 147 | + ep = export(model, inputs, strict=True) |
| 148 | + print("Export OK") |
| 149 | + |
| 150 | + os.makedirs(output_dir, exist_ok=True) |
| 151 | + |
| 152 | + specs = [CudaBackend.generate_method_name_compile_spec("forward")] |
| 153 | + et_prog = to_edge_transform_and_lower( |
| 154 | + ep, |
| 155 | + partitioner=[CudaPartitioner(specs)], |
| 156 | + compile_config=EdgeCompileConfig( |
| 157 | + _check_ir_validity=False, _skip_dim_order=True |
| 158 | + ), |
| 159 | + ) |
| 160 | + et_program = et_prog.to_executorch( |
| 161 | + config=ExecutorchBackendConfig( |
| 162 | + extract_delegate_segments=True, |
| 163 | + do_quant_fusion_and_const_prop=True, |
| 164 | + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), |
| 165 | + ), |
| 166 | + ) |
| 167 | + |
| 168 | + pte_path = os.path.join(output_dir, "chunk_gated_delta.pte") |
| 169 | + with open(pte_path, "wb") as f: |
| 170 | + et_program.write_to_file(f) |
| 171 | + |
| 172 | + if hasattr(et_program, "_tensor_data") and et_program._tensor_data: |
| 173 | + et_program.write_tensor_data_to_file(output_dir) |
| 174 | + |
| 175 | + print(f"Saved to {pte_path} ({os.path.getsize(pte_path) / 1024:.0f} KB)") |
| 176 | + return pte_path |
| 177 | + |
| 178 | + |
| 179 | +def _run_cpp_runner(runner_path, pte_path, ptd_path, input_files, output_base): |
| 180 | + """Run executor_runner and return subprocess result.""" |
| 181 | + cmd = [ |
| 182 | + runner_path, |
| 183 | + f"--model_path={pte_path}", |
| 184 | + f"--data_path={ptd_path}", |
| 185 | + f"--inputs={','.join(input_files)}", |
| 186 | + f"--output_file={output_base}", |
| 187 | + ] |
| 188 | + result = subprocess.run(cmd, capture_output=True, text=True) |
| 189 | + return result |
| 190 | + |
| 191 | + |
| 192 | +@unittest.skipUnless(HAS_FLA, "flash-linear-attention not installed") |
| 193 | +class TestChunkGatedDeltaRule(unittest.TestCase): |
| 194 | + def setUp(self): |
| 195 | + if not torch.cuda.is_available(): |
| 196 | + self.skipTest("CUDA is not available") |
| 197 | + |
| 198 | + def test_eager(self): |
| 199 | + model = ChunkGatedDeltaModel().eval() |
| 200 | + inputs = _make_inputs() |
| 201 | + with torch.no_grad(): |
| 202 | + o, s = model(*inputs) |
| 203 | + self.assertEqual(o.shape, torch.Size([B, T, H, V])) |
| 204 | + self.assertEqual(s.shape, torch.Size([B, H, K, V])) |
| 205 | + self.assertEqual(o.dtype, torch.bfloat16) |
| 206 | + self.assertEqual(s.dtype, torch.float32) |
| 207 | + |
| 208 | + def test_eager_fla_configs(self): |
| 209 | + """Run FLA-style test configurations and verify against naive reference.""" |
| 210 | + from fla.ops.gated_delta_rule.naive import naive_recurrent_gated_delta_rule |
| 211 | + |
| 212 | + model = ChunkGatedDeltaModel().eval() |
| 213 | + for seed, norm, mask_p, nonzero_h0, desc in FLA_TEST_CONFIGS: |
| 214 | + with self.subTest(desc=desc): |
| 215 | + inputs = _make_inputs_from_fla(seed, norm, mask_p, nonzero_h0) |
| 216 | + q, k, v, g, beta, h0 = inputs |
| 217 | + |
| 218 | + with torch.no_grad(): |
| 219 | + o_ours, s_ours = model(q, k, v, g, beta, h0) |
| 220 | + |
| 221 | + o_ref, s_ref = naive_recurrent_gated_delta_rule( |
| 222 | + q=F.normalize(q, p=2, dim=-1), |
| 223 | + k=F.normalize(k, p=2, dim=-1), |
| 224 | + v=v, |
| 225 | + beta=beta, |
| 226 | + g=g, |
| 227 | + initial_state=h0, |
| 228 | + output_final_state=True, |
| 229 | + ) |
| 230 | + |
| 231 | + o_diff = (o_ours.float() - o_ref.float()).abs().max().item() |
| 232 | + s_diff = (s_ours.float() - s_ref.float()).abs().max().item() |
| 233 | + self.assertLess(o_diff, 0.01, f"{desc}: output diff {o_diff}") |
| 234 | + self.assertLess(s_diff, 0.01, f"{desc}: state diff {s_diff}") |
| 235 | + |
| 236 | + def test_eager_matches_fla(self): |
| 237 | + from fla.ops.gated_delta_rule import chunk_gated_delta_rule as fla_impl |
| 238 | + |
| 239 | + torch.manual_seed(42) |
| 240 | + inputs = _make_inputs() |
| 241 | + q, k, v, g, beta, h0 = inputs |
| 242 | + |
| 243 | + q_norm = F.normalize(q, p=2, dim=-1) |
| 244 | + k_norm = F.normalize(k, p=2, dim=-1) |
| 245 | + with torch.no_grad(): |
| 246 | + o_ours, _ = torch.ops.triton.chunk_gated_delta_rule( |
| 247 | + q_norm, k_norm, v, g, beta, h0 |
| 248 | + ) |
| 249 | + o_ref, _ = fla_impl( |
| 250 | + q, |
| 251 | + k, |
| 252 | + v, |
| 253 | + g, |
| 254 | + beta, |
| 255 | + initial_state=h0, |
| 256 | + output_final_state=True, |
| 257 | + use_qk_l2norm_in_kernel=True, |
| 258 | + ) |
| 259 | + |
| 260 | + self.assertLess((o_ours.float() - o_ref.float()).abs().max().item(), 0.01) |
| 261 | + |
| 262 | + def test_export_cuda(self): |
| 263 | + with tempfile.TemporaryDirectory() as tmpdir: |
| 264 | + pte_path = export_chunk_gated_delta(tmpdir) |
| 265 | + self.assertTrue(os.path.exists(pte_path)) |
| 266 | + self.assertGreater(os.path.getsize(pte_path), 0) |
| 267 | + |
| 268 | + def test_e2e_cpp_runner(self): |
| 269 | + self.assertTrue( |
| 270 | + os.path.exists(RUNNER_PATH), |
| 271 | + f"executor_runner not found at {RUNNER_PATH}. " |
| 272 | + "Build with: cmake --build cmake-out --target executor_runner", |
| 273 | + ) |
| 274 | + """Export, run executor_runner with FLA test inputs, compare with eager.""" |
| 275 | + model = ChunkGatedDeltaModel().eval() |
| 276 | + |
| 277 | + with tempfile.TemporaryDirectory() as tmpdir: |
| 278 | + export_dir = os.path.join(tmpdir, "export") |
| 279 | + pte_path = export_chunk_gated_delta(export_dir) |
| 280 | + ptd_path = os.path.join(export_dir, "aoti_cuda_blob.ptd") |
| 281 | + |
| 282 | + for seed, norm, mask_p, nonzero_h0, desc in FLA_TEST_CONFIGS: |
| 283 | + with self.subTest(desc=desc): |
| 284 | + inputs = _make_inputs_from_fla(seed, norm, mask_p, nonzero_h0) |
| 285 | + q, k, v, g, beta, h0 = inputs |
| 286 | + |
| 287 | + with torch.no_grad(): |
| 288 | + ref_o, ref_s = model(q, k, v, g, beta, h0) |
| 289 | + |
| 290 | + run_dir = os.path.join(tmpdir, f"run_{desc}") |
| 291 | + os.makedirs(run_dir) |
| 292 | + |
| 293 | + input_files = [] |
| 294 | + for i, tensor in enumerate(inputs): |
| 295 | + path = os.path.join(run_dir, f"{i}.bin") |
| 296 | + _save_tensor(tensor, path) |
| 297 | + input_files.append(path) |
| 298 | + |
| 299 | + output_base = os.path.join(run_dir, "output") |
| 300 | + result = _run_cpp_runner( |
| 301 | + RUNNER_PATH, pte_path, ptd_path, input_files, output_base |
| 302 | + ) |
| 303 | + self.assertEqual( |
| 304 | + result.returncode, |
| 305 | + 0, |
| 306 | + f"{desc}: executor_runner failed:\n{result.stderr}", |
| 307 | + ) |
| 308 | + |
| 309 | + cpp_o = _load_output( |
| 310 | + f"{output_base}-0.bin", |
| 311 | + (B, T, H, V), |
| 312 | + torch.bfloat16, |
| 313 | + ) |
| 314 | + cpp_s = _load_output( |
| 315 | + f"{output_base}-1.bin", |
| 316 | + (B, H, K, V), |
| 317 | + torch.float32, |
| 318 | + ) |
| 319 | + |
| 320 | + o_diff = (cpp_o.float() - ref_o.cpu().float()).abs().max().item() |
| 321 | + s_diff = (cpp_s.float() - ref_s.cpu().float()).abs().max().item() |
| 322 | + self.assertLess(o_diff, 0.01, f"{desc}: output diff {o_diff}") |
| 323 | + self.assertLess(s_diff, 0.1, f"{desc}: state diff {s_diff}") |
| 324 | + |
| 325 | + |
| 326 | +if __name__ == "__main__": |
| 327 | + parser = argparse.ArgumentParser() |
| 328 | + parser.add_argument("--output-dir", default=None) |
| 329 | + args, remaining = parser.parse_known_args() |
| 330 | + |
| 331 | + if args.output_dir: |
| 332 | + export_chunk_gated_delta(args.output_dir) |
| 333 | + else: |
| 334 | + sys.argv = [sys.argv[0]] + remaining |
| 335 | + unittest.main() |
0 commit comments