[FEA] Add FlashAttention backward pass implementation, test cases and benchmark#49
Conversation
- Add `fmha_fwd_kernel_with_lse`: Forward kernel that saves log-sum-exp (LSE) to context for the backward pass, compatible with torch.autograd
- Add `fmha_bwd_preprocess_kernel`: Computes Delta = rowsum(O * dO)
- Add `fmha_bwd_dkdv_kernel`: Computes dK and dV gradients with GQA support
- Add `fmha_bwd_dq_kernel`: Computes dQ gradient
- Add `FlashAttentionFunction`: torch.autograd.Function wrapper that saves LSE during forward and uses it in backward
- Add `tile_fmha_with_backward`: Public API for training with autograd support
- Add `tile_fmha_functional`: Auto-selects inference or training mode
- Add autotune configs for backward kernels
- Add `Test_FMHA_Backward` class with test cases for:
  - Regular shapes (power of 2): various batch/head/seq/dim combinations
  - Irregular shapes (non-power of 2): odd/prime sequence lengths
  - Corner cases: single batch, single head, many heads, short sequences
  - Forward with LSE matches reference
  - Functional API inference and training modes
- Add `Test_FMHA_Backward_GQA` class for Grouped Query Attention tests:
  - GQA ratios: 2:1, 4:1, 8:1 (Llama-style)
  - Multi-Query Attention (all Q heads share 1 KV head)
- Add `Test_FMHA_Backward_Numerical` for torch.autograd.gradcheck
- All test functions follow the `test_op_*` naming convention
- Add `bench_attention_backward.py`: Benchmarks fwd/bwd/fwd+bwd modes
- Compares CuTile against PyTorch SDPA backends (Flash, MemEff, Math)
- Tests d=64 and d=128 head dimensions
- Tests causal and non-causal modes
- Plot names include `-GBps` suffix for CI recognition
Hi @Weili-0234, thank you for this great contribution! The current performance data looks promising. Could you clarify whether these numbers are from the causal or non-causal kernel? To get a fuller picture, could you provide a more comprehensive comparison against the Triton tutorial's Fused Attention? Specifically, it would be helpful to see the benchmarks for: d={64, 128} and Causal={True, False}. |
/ok to test ac49a28 |
Thank you @qelk123 for the kind suggestion; I have updated the reported results accordingly, clearly denoting causality. I also noticed the test_ops CI failed because autotuning took too long (exceeding the 10 min limit). Should I just submit a version with a smaller tuning space? That would make the current implementation more under-tuned and less performant, though. What do you suggest? As for the comparison with Triton, I can share a self-contained script below if you think it's helpful. The only reason I didn't include the official Triton tutorial version of flash attention is that I don't want to introduce an extra file dependency into the TileGym repo.
Hi @Weili-0234, thanks for the comparison data! We'd be happy to accept your contribution once the CI passes. To resolve the test-op CI timeout, please implement the following:

1. Introduce `DISABLE_AUTOTUNE`: Add this environment variable to your implementation. When enabled, it should bypass the autotuning process and default to the 0th configuration.
2. CI Configuration: Enable `DISABLE_AUTOTUNE` for the test_op pytest suite to save time, but do not enable it for benchmark to ensure performance results remain accurate.

Looking forward to the update!
Hi @xjmxyt, thanks for the kind suggestion. I just implemented the `DISABLE_AUTOTUNE` environment variable as requested. Would you be able to trigger the CI again? Happy to address anything else if needed. Thanks!
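For reference, the requested gating (bypass autotuning and fall back to the 0th configuration when `DISABLE_AUTOTUNE` is set) can be sketched like this; the module and config names here are illustrative, not the repo's actual ones:

```python
# Hypothetical sketch of DISABLE_AUTOTUNE gating; BWD_CONFIGS and
# select_configs are illustrative names, not the repo's actual API.
import os

# Example autotune search space for the backward kernels (tile sizes only).
BWD_CONFIGS = [
    {"TILE_M": 64, "TILE_N": 64},
    {"TILE_M": 128, "TILE_N": 64},
    {"TILE_M": 128, "TILE_N": 128},
]


def select_configs(configs):
    """Return only the 0th config when DISABLE_AUTOTUNE=1, skipping the sweep."""
    if os.environ.get("DISABLE_AUTOTUNE", "0") == "1":
        return configs[:1]
    return configs
```

The CI would then export `DISABLE_AUTOTUNE=1` for the test_op suite only, leaving the benchmark job to sweep the full space.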
/ok to test 0668aae |
Hi, thanks for re-triggering the CI! I'll fix the formatting issue in my next commit. Regarding the test-ops timeout failure: I checked the logs and found that the CI timed out while still running pre-existing tests, before reaching my new ones. Thanks for your patience in advance!
@Weili-0234 I see the logs; I think your tests have run.
Hi, thanks for your kind observation and suggestion. Sorry that I didn't look at the logs carefully. I've reduced the number of test cases as requested. Would you be able to trigger the CI again? Happy to address anything else if needed. Thanks! |
/ok to test dc93e6e |
/ok to test eecde06 |
… benchmark (NVIDIA#49)

* feat: Add Flash Attention backward pass implementation
  - Add `fmha_fwd_kernel_with_lse`: Forward kernel that saves log-sum-exp (LSE) to context for backward pass, compatible with torch.autograd
  - Add `fmha_bwd_preprocess_kernel`: Computes Delta = rowsum(O * dO)
  - Add `fmha_bwd_dkdv_kernel`: Computes dK and dV gradients with GQA support
  - Add `fmha_bwd_dq_kernel`: Computes dQ gradient
  - Add `FlashAttentionFunction`: torch.autograd.Function wrapper that saves LSE during forward and uses it in backward
  - Add `tile_fmha_with_backward`: Public API for training with autograd support
  - Add `tile_fmha_functional`: Auto-selects inference or training mode
  - Add autotune configs for backward kernels
* test: Add comprehensive tests for Flash Attention backward pass
  - Add `Test_FMHA_Backward` class covering regular shapes (power of 2), irregular shapes (odd/prime sequence lengths), corner cases (single batch, single head, many heads, short sequences), forward with LSE matching reference, and functional API inference/training modes
  - Add `Test_FMHA_Backward_GQA` class for Grouped Query Attention: GQA ratios 2:1, 4:1, 8:1 (Llama-style) and Multi-Query Attention (all Q heads share 1 KV head)
  - Add `Test_FMHA_Backward_Numerical` for torch.autograd.gradcheck
  - All test functions follow the `test_op_*` naming convention
* bench: Add Flash Attention backward benchmarks
  - Add `bench_attention_backward.py`: Benchmarks fwd/bwd/fwd+bwd modes
  - Compares CuTile against PyTorch SDPA backends (Flash, MemEff, Math)
  - Tests d=64 and d=128 head dimensions, causal and non-causal modes
  - Plot names include `-GBps` suffix for CI recognition
* add formatting via ./format.sh
* additional formatting
* Introduce DISABLE_AUTOTUNE as requested by PR NVIDIA#49 reviewers
* fix flash attn bwd formatting
* reduce test cases to avoid CI timeout
Description
This PR adds backward pass support for FlashAttention, enabling training with the fused attention kernels. The implementation saves the log-sum-exp (LSE) in the autograd context during the forward pass, making it compatible with PyTorch's autograd. The backward pass computes the gradients dQ, dK, and dV using the standard Flash Attention backward algorithm.
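The autograd wiring described above can be sketched in eager PyTorch (this stands in for the cuTile kernels; class and variable names are illustrative). The key point is that the forward saves only the LSE, so the backward can recompute the softmax without ever storing the full attention matrix:

```python
# Hedged sketch of the LSE-saving autograd pattern; eager PyTorch stand-in
# for the fused kernels, not the PR's actual implementation.
import math
import torch


class FlashAttentionSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v):
        scale = 1.0 / math.sqrt(q.shape[-1])
        s = q @ k.transpose(-2, -1) * scale       # attention scores
        lse = torch.logsumexp(s, dim=-1)          # saved instead of softmax(s)
        p = torch.exp(s - lse.unsqueeze(-1))      # softmax recovered from LSE
        o = p @ v
        ctx.save_for_backward(q, k, v, o, lse)
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, o, lse = ctx.saved_tensors
        scale = 1.0 / math.sqrt(q.shape[-1])
        s = q @ k.transpose(-2, -1) * scale
        p = torch.exp(s - lse.unsqueeze(-1))      # recomputed from LSE
        dv = p.transpose(-2, -1) @ do
        dp = do @ v.transpose(-2, -1)
        delta = (o * do).sum(dim=-1, keepdim=True)  # Delta = rowsum(O * dO)
        ds = p * (dp - delta)
        dq = ds @ k * scale
        dk = ds.transpose(-2, -1) @ q * scale
        return dq, dk, dv
```

In the PR, the recomputation and the dK/dV and dQ matmuls run in separate fused kernels, with `Delta` materialized once by the preprocess kernel.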
Related Issue
NVIDIA/cutile-python#15
Implementation
Code Changes
attention.py: Added training-mode forward and backward kernels

- `fmha_fwd_kernel_with_lse`: Forward kernel that saves LSE to context for backward
- `fmha_bwd_preprocess_kernel`: Computes Delta = rowsum(O * dO)
- `fmha_bwd_dkdv_kernel`: Computes dK and dV gradients with GQA support
- `fmha_bwd_dq_kernel`: Computes dQ gradient
- `FlashAttentionFunction`: `torch.autograd.Function` wrapper
- `tile_fmha_with_backward`: Public API for training with autograd support
- `tile_fmha_functional`: Auto-selects inference or training mode

Gradient Math
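With $S = QK^\top \cdot \mathrm{scale}$, $P = \mathrm{softmax}(S)$, and $O = PV$, the standard FlashAttention backward identities used here are:

$$
\begin{aligned}
dV &= P^\top\, dO \\
\Delta &= \operatorname{rowsum}(O \circ dO) \\
dP &= dO\, V^\top \\
dS &= P \circ (dP - \Delta) \\
dQ &= dS\, K \cdot \mathrm{scale} \\
dK &= dS^\top Q \cdot \mathrm{scale}
\end{aligned}
$$

This is why the preprocess kernel materializes $\Delta = \operatorname{rowsum}(O \circ dO)$ once, before the dK/dV and dQ kernels run.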
Testing
Correctness Verification
Similar to `tests/ops/test_rmsnorm_backward.py`, test classes inherit from `common.PyTestCase` and methods are named `test_op_*`. The `reference_backward()` helper computes reference gradients via `torch.nn.functional.scaled_dot_product_attention` with autograd; each `test_op_*` method then uses `torch.testing.assert_close()` to compare the cuTile implementation's dQ/dK/dV against the PyTorch reference. GQA tests additionally expand K/V heads and sum gradients to match the grouped structure.

Coverage
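The reference-comparison pattern described above can be sketched as a self-contained approximation (helper names mirror the description; shapes and tolerances are illustrative, and a plain softmax implementation stands in for the cuTile kernels):

```python
# Illustrative version of the reference_backward() testing pattern;
# naive_backward stands in for the cuTile kernels under test.
import torch
import torch.nn.functional as F


def reference_backward(q, k, v, do):
    """Reference gradients via SDPA with autograd."""
    q, k, v = (t.detach().requires_grad_(True) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v)
    out.backward(do)
    return q.grad, k.grad, v.grad


def naive_backward(q, k, v, do):
    """Stand-in for the implementation under test: plain softmax attention."""
    q, k, v = (t.detach().requires_grad_(True) for t in (q, k, v))
    scale = q.shape[-1] ** -0.5
    p = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
    (p @ v).backward(do)
    return q.grad, k.grad, v.grad


torch.manual_seed(0)
q, k, v, do = (torch.randn(2, 4, 16, 64) for _ in range(4))
for ref, got in zip(reference_backward(q, k, v, do), naive_backward(q, k, v, do)):
    torch.testing.assert_close(ref, got, atol=1e-4, rtol=1e-3)
```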
`torch.autograd.gradcheck`

Run tests:
Performance
Benchmarks on RTX 5070 Ti (16GB, CUDA 13.1).
Performance is reported without aggressive autotuning.
Autotuning for cuTile only covers tile sizes (`TILE_M`, `TILE_N`); `num_ctas` and `occupancy` are left to compiler defaults.

Comparison against PyTorch SDPA backends:
- `SDPBackend.FLASH_ATTENTION` - PyTorch's Flash Attention backend (cuDNN-based)
- `SDPBackend.EFFICIENT_ATTENTION` - Memory-efficient attention (xFormers-based)
- `SDPBackend.MATH` - Standard PyTorch math implementation (no fusion)

Forward (d=64, causal=True) - float16 (TFLOPS)
Forward (d=64, causal=False) - float16 (TFLOPS)
Backward (d=64, causal=True) - float16 (TFLOPS)
Backward (d=64, causal=False) - float16 (TFLOPS)
Forward + Backward (d=64, causal=True) - float16 (TFLOPS)
Forward + Backward (d=64, causal=False) - float16 (TFLOPS)
Forward (d=128, causal=True) - float16 (TFLOPS)
Forward (d=128, causal=False) - float16 (TFLOPS)
Backward (d=128, causal=True) - float16 (TFLOPS)
Backward (d=128, causal=False) - float16 (TFLOPS)
Forward + Backward (d=128, causal=True) - float16 (TFLOPS)
Forward + Backward (d=128, causal=False) - float16 (TFLOPS)
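The TFLOPS figures above presumably follow the FLOP-counting convention of the Triton fused-attention tutorial: two matmuls (QKᵀ and PV) of 2·S·S·D FLOPs each per batch-head in the forward, halved for causal masking, a 2.5× factor for backward, and therefore 3.5× for forward+backward. A sketch under that assumption:

```python
# Assumed FLOP accounting (Triton fused-attention tutorial convention);
# not taken from bench_attention_backward.py itself.
def attention_tflops(batch, heads, seqlen, head_dim, ms, mode="fwd", causal=True):
    """Convert a measured runtime in milliseconds to TFLOPS."""
    # 2 matmuls x (2 * S * S * D) FLOPs per (batch, head) pair.
    flops = 2.0 * 2.0 * batch * heads * seqlen * seqlen * head_dim
    if causal:
        flops *= 0.5   # only the lower triangle is computed
    if mode == "bwd":
        flops *= 2.5   # recomputation + dQ/dK/dV matmuls
    elif mode == "fwd+bwd":
        flops *= 3.5
    return flops * 1e-12 / (ms * 1e-3)


# Example: a 10 ms forward at B=4, H=32, S=4096, D=64.
print(attention_tflops(4, 32, 4096, 64, 10.0, mode="fwd", causal=True))
```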
Run benchmarks:
Additional Comparison with Flash Attention in Triton (d=64, causal=True) - float16 (TFLOPS)
Reference: Fused Attention in Triton
Forward
Backward
Forward + Backward
Additional Comparison with Flash Attention in Triton (d=64, causal=False) - float16 (TFLOPS)
Forward
Backward
Forward + Backward
Usage
CI Configuration
Checklist
- Code formatted (`./format.sh`)