
[FEA] Add FlashAttention backward pass implementation, test cases and benchmark #49

Merged
xjmxyt merged 10 commits into NVIDIA:main from Weili-0234:feat/flash-attn-bwd
Feb 11, 2026

Conversation


@Weili-0234 Weili-0234 commented Feb 5, 2026

Description

This PR adds backward pass support for FlashAttention, enabling training with the fused attention kernels. The forward pass saves the log-sum-exp (LSE) in the autograd context, making the implementation compatible with PyTorch's autograd. The backward pass then computes the gradients dQ, dK, and dV using the standard FlashAttention backward algorithm.

Related Issue

NVIDIA/cutile-python#15

Implementation

Code Changes

attention.py: Added training-mode forward and backward kernels

  • fmha_fwd_kernel_with_lse: Forward kernel that saves LSE to context for backward
  • fmha_bwd_preprocess_kernel: Computes Delta = rowsum(O * dO)
  • fmha_bwd_dkdv_kernel: Computes dK and dV gradients with GQA support
  • fmha_bwd_dq_kernel: Computes dQ gradient
  • FlashAttentionFunction: torch.autograd.Function wrapper
  • tile_fmha_with_backward: Public API for training with autograd support
  • tile_fmha_functional: Auto-selects inference or training mode
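The wrapper follows the standard torch.autograd.Function pattern of saving LSE in the forward context. A minimal eager-mode sketch of that pattern (hypothetical names; the real FlashAttentionFunction launches the cuTile kernels listed above instead of these plain tensor ops):

```python
import torch

class ToyAttnFn(torch.autograd.Function):
    """Hypothetical sketch of the save-LSE-for-backward pattern.
    The PR's FlashAttentionFunction dispatches to cuTile kernels instead."""

    @staticmethod
    def forward(ctx, q, k, v, scale):
        s = q @ k.transpose(-2, -1) * scale
        lse = torch.logsumexp(s, dim=-1)        # saved instead of the full P matrix
        p = torch.exp(s - lse.unsqueeze(-1))    # P is recoverable from LSE
        o = p @ v
        ctx.save_for_backward(q, k, v, o, lse)
        ctx.scale = scale
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, o, lse = ctx.saved_tensors
        scale = ctx.scale
        # Recompute P from the saved LSE rather than materializing it in forward.
        p = torch.exp(q @ k.transpose(-2, -1) * scale - lse.unsqueeze(-1))
        delta = (o * do).sum(-1, keepdim=True)  # the preprocess kernel's Delta
        ds = p * (do @ v.transpose(-2, -1) - delta)
        dq = ds @ k * scale
        dk = ds.transpose(-2, -1) @ q * scale
        dv = p.transpose(-2, -1) @ do
        return dq, dk, dv, None
```

Saving LSE instead of the full probability matrix keeps the saved state O(S) per row rather than O(S^2), which is the point of the FlashAttention backward design.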

Gradient Math

Forward:
  O = softmax(Q @ K^T / sqrt(d)) @ V
  LSE = logsumexp(Q @ K^T / sqrt(d))

Backward:
  Delta = rowsum(O * dO)
  dP = dO @ V^T
  dS = P * (dP - Delta)
  dQ = dS @ K / sqrt(d)
  dK = dS^T @ Q / sqrt(d)
  dV = P^T @ dO

where P = softmax(Q @ K^T / sqrt(d)) is recomputed from the saved LSE.
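These formulas can be checked numerically against autograd with plain PyTorch ops (a sketch, not the kernel code; note the softmax scale 1/sqrt(d) is applied to dQ and dK):

```python
import math
import torch

torch.manual_seed(0)
B, H, S, D = 1, 1, 4, 8
scale = 1.0 / math.sqrt(D)

q = torch.randn(B, H, S, D, dtype=torch.float64, requires_grad=True)
k = torch.randn(B, H, S, D, dtype=torch.float64, requires_grad=True)
v = torch.randn(B, H, S, D, dtype=torch.float64, requires_grad=True)

# Forward: O = softmax(Q K^T * scale) V
s_mat = q @ k.transpose(-2, -1) * scale
p = torch.softmax(s_mat, dim=-1)
o = p @ v

do = torch.randn_like(o)
o.backward(do)  # autograd reference gradients

# Manual backward following the formulas above.
with torch.no_grad():
    delta = (o * do).sum(dim=-1, keepdim=True)  # Delta = rowsum(O * dO)
    dp = do @ v.transpose(-2, -1)               # dP = dO V^T
    ds = p * (dp - delta)                       # dS = P * (dP - Delta)
    dq = ds @ k * scale                         # dQ = dS K / sqrt(d)
    dk = ds.transpose(-2, -1) @ q * scale       # dK = dS^T Q / sqrt(d)
    dv = p.transpose(-2, -1) @ do               # dV = P^T dO

assert torch.allclose(dq, q.grad)
assert torch.allclose(dk, k.grad)
assert torch.allclose(dv, v.grad)
```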

Testing

Correctness Verification

As in tests/ops/test_rmsnorm_backward.py, the test classes inherit from common.PyTestCase and methods are named test_op_*. A reference_backward() helper computes reference gradients via torch.nn.functional.scaled_dot_product_attention with autograd; each test_op_* method then uses torch.testing.assert_close() to compare the cuTile implementation's dQ/dK/dV against the PyTorch reference. The GQA tests additionally expand the K/V heads and sum the gradients over each group to match the grouped structure.
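The expand-and-sum GQA reference can be sketched as follows (a hypothetical helper, not the test file's actual code; group is the number of Q heads per KV head):

```python
import torch
import torch.nn.functional as F

def gqa_reference_grads(q, k, v, group):
    """Reference dQ/dK/dV for grouped-query attention: expand K/V heads to
    match Q, run autograd on the expanded tensors, then sum the expanded
    gradients back over each group of Q heads."""
    qx = q.detach().clone().requires_grad_(True)
    kx = k.detach().repeat_interleave(group, dim=1).requires_grad_(True)
    vx = v.detach().repeat_interleave(group, dim=1).requires_grad_(True)
    out = F.scaled_dot_product_attention(qx, kx, vx)
    out.sum().backward()
    B, Hkv, S, D = k.shape
    # repeat_interleave places each KV head's copies consecutively, so a
    # view over (Hkv, group) followed by a sum undoes the expansion.
    dk = kx.grad.view(B, Hkv, group, S, D).sum(dim=2)
    dv = vx.grad.view(B, Hkv, group, S, D).sum(dim=2)
    return qx.grad, dk, dv
```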

Coverage

  • Regular shapes (power of 2): various batch/head/seq/dim combinations
  • Irregular shapes (non-power of 2): odd/prime sequence lengths (127, 129, 131, 251, 1023, 1025, 2047, 2049)
  • Corner cases: single batch, single head, many heads, short sequences
  • GQA ratios: 2:1, 4:1, 8:1 (Llama-style), Multi-Query Attention
  • Numerical gradient check via torch.autograd.gradcheck
  • Test both float16 and bfloat16
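The gradcheck item can be exercised with a small double-precision sketch (plain softmax attention stands in for the op under test here; the actual tests would pass the cuTile function, and the shapes are illustrative):

```python
import math
import torch

def attn(q, k, v):
    # Plain softmax attention as the function under test.
    scale = 1.0 / math.sqrt(q.shape[-1])
    p = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
    return p @ v

# gradcheck requires double precision; small shapes keep it fast.
q = torch.randn(1, 1, 4, 8, dtype=torch.float64, requires_grad=True)
k = torch.randn(1, 1, 4, 8, dtype=torch.float64, requires_grad=True)
v = torch.randn(1, 1, 4, 8, dtype=torch.float64, requires_grad=True)
assert torch.autograd.gradcheck(attn, (q, k, v))
```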

Run tests:

pytest tests/ops/test_attention_backward.py -v

Performance

Benchmarks on RTX 5070 Ti (16GB, CUDA 13.1).
Performance is reported without aggressive autotuning.
Autotuning for cuTile covers only the tile sizes (TILE_M, TILE_N); num_ctas and occupancy are left to compiler defaults.

Comparison against PyTorch SDPA backends:

  • SDPA-Flash: SDPBackend.FLASH_ATTENTION - PyTorch's built-in FlashAttention backend
  • SDPA-MemEff: SDPBackend.EFFICIENT_ATTENTION - Memory-efficient attention (xFormers-based)
  • SDPA-Math: SDPBackend.MATH - Standard PyTorch math implementation (no fusion)

Forward (d=64, causal=True) - float16 (TFLOPS)

N_CTX CuTile SDPA-Flash SDPA-MemEff SDPA-Math
256 27.68 25.00 22.31 2.26
512 50.15 50.30 31.96 1.89
1024 71.00 69.72 38.57 1.98
2048 81.76 81.00 42.37 1.99
4096 86.33 86.30 43.90 NaN
8192 87.87 89.27 44.85 NaN

Forward (d=64, causal=False) - float16 (TFLOPS)

N_CTX CuTile SDPA-Flash SDPA-MemEff SDPA-Math
256 45.96 50.11 34.92 4.96
512 70.18 76.29 41.74 4.47
1024 80.46 88.02 44.05 4.72
2048 85.02 89.51 45.08 4.94
4096 85.56 91.31 45.42 NaN
8192 85.82 91.56 45.76 NaN

Backward (d=64, causal=True) - float16 (TFLOPS)

N_CTX CuTile SDPA-Flash SDPA-MemEff SDPA-Math
256 11.46 22.13 15.75 4.73
512 27.31 38.56 21.89 4.98
1024 41.16 52.33 26.90 5.34
2048 47.78 65.75 30.52 5.59
4096 51.68 76.10 32.49 NaN
8192 52.49 83.19 32.34 NaN

Backward (d=64, causal=False) - float16 (TFLOPS)

N_CTX CuTile SDPA-Flash SDPA-MemEff SDPA-Math
256 10.03 40.12 23.85 9.28
512 43.50 57.55 29.52 9.92
1024 49.24 69.77 32.55 10.61
2048 53.30 77.09 34.71 11.08
4096 55.02 83.15 35.43 NaN
8192 55.72 86.25 34.26 NaN

Forward + Backward (d=64, causal=True) - float16 (TFLOPS)

N_CTX CuTile SDPA-Flash SDPA-MemEff SDPA-Math
256 6.51 24.70 17.47 3.72
512 23.70 43.28 23.86 3.41
1024 46.17 56.64 29.18 3.60
2048 53.20 68.43 32.78 3.68
4096 57.45 78.12 34.71 NaN
8192 59.18 84.50 35.07 NaN

Forward + Backward (d=64, causal=False) - float16 (TFLOPS)

N_CTX CuTile SDPA-Flash SDPA-MemEff SDPA-Math
256 13.22 46.13 26.75 7.54
512 42.78 63.88 32.01 7.38
1024 54.94 74.18 34.96 7.82
2048 58.59 79.61 36.88 8.16
4096 60.72 84.83 37.67 NaN
8192 61.84 87.43 36.62 NaN

Forward (d=128, causal=True) - float16 (TFLOPS)

N_CTX CuTile SDPA-Flash SDPA-MemEff SDPA-Math
256 35.62 30.91 19.80 2.77
512 63.01 51.44 26.27 2.77
1024 79.09 66.63 30.93 2.97
2048 86.85 76.43 33.90 3.03
4096 90.06 81.41 35.51 NaN
8192 91.58 84.43 36.39 NaN

Forward (d=128, causal=False) - float16 (TFLOPS)

N_CTX CuTile SDPA-Flash SDPA-MemEff SDPA-Math
256 62.48 53.71 29.76 5.90
512 82.95 69.00 33.83 6.25
1024 89.24 79.89 35.31 6.75
2048 92.31 82.29 36.36 7.11
4096 92.19 84.61 36.78 NaN
8192 92.88 85.60 37.04 NaN

Backward (d=128, causal=True) - float16 (TFLOPS)

N_CTX CuTile SDPA-Flash SDPA-MemEff SDPA-Math
256 25.33 27.18 17.17 5.52
512 34.90 39.77 22.15 6.13
1024 44.22 54.43 25.98 6.72
2048 50.53 68.44 28.30 7.06
4096 54.73 78.74 29.41 NaN
8192 57.20 85.15 30.19 NaN

Backward (d=128, causal=False) - float16 (TFLOPS)

N_CTX CuTile SDPA-Flash SDPA-MemEff SDPA-Math
256 31.11 43.23 24.22 10.99
512 41.88 57.38 26.96 12.16
1024 50.02 69.04 28.74 13.37
2048 53.64 78.24 29.84 14.04
4096 57.95 83.68 30.31 NaN
8192 59.05 87.06 30.68 NaN

Forward + Backward (d=128, causal=True) - float16 (TFLOPS)

N_CTX CuTile SDPA-Flash SDPA-MemEff SDPA-Math
256 13.05 29.98 18.56 4.32
512 39.39 42.94 23.44 4.56
1024 49.56 57.03 27.18 4.93
2048 56.15 69.29 29.38 5.10
4096 61.08 78.85 30.85 NaN
8192 63.88 84.80 31.68 NaN

Forward + Backward (d=128, causal=False) - float16 (TFLOPS)

N_CTX CuTile SDPA-Flash SDPA-MemEff SDPA-Math
256 40.49 48.16 26.40 8.84
512 48.31 60.87 28.86 9.60
1024 56.00 70.81 30.34 10.43
2048 59.91 78.67 31.20 10.97
4096 64.59 83.62 31.87 NaN
8192 65.80 86.49 32.21 NaN

Run benchmarks:

python tests/benchmark/bench_attention_backward.py

Additional Comparison with Flash Attention in Triton (d=64, causal=True) - float16 (TFLOPS)

Reference: Fused Attention in Triton

Forward

N_CTX CuTile Triton
256 27.60 29.31
512 49.98 49.70
1024 70.84 66.19
2048 81.71 75.35
4096 86.01 79.68
8192 87.67 82.25

Backward

N_CTX CuTile Triton
256 4.78 18.43
512 31.92 28.03
1024 41.23 37.79
2048 47.80 44.76
4096 51.34 48.94
8192 52.50 50.60

Forward + Backward

N_CTX CuTile Triton
256 8.93 11.78
512 33.11 21.68
1024 46.28 43.16
2048 53.06 49.74
4096 57.43 54.24
8192 58.82 56.71

Additional Comparison with Flash Attention in Triton (d=64, causal=False) - float16 (TFLOPS)

Forward

N_CTX CuTile Triton
256 45.45 52.42
512 70.02 70.66
1024 80.44 83.22
2048 84.80 86.94
4096 85.45 90.14
8192 85.82 91.27

Backward

N_CTX CuTile Triton
256 23.77 31.18
512 43.37 37.45
1024 49.16 42.45
2048 53.34 47.66
4096 55.03 49.97
8192 55.73 50.99

Forward + Backward

N_CTX CuTile Triton
256 16.82 26.18
512 49.00 44.56
1024 54.83 49.44
2048 58.65 54.10
4096 60.71 56.77
8192 61.81 58.25

Usage

import math

import torch

from tilegym.ops.cutile.attention import tile_fmha_with_backward

B, H, S, D = 2, 8, 1024, 64  # example shapes

q = torch.randn(B, H, S, D, requires_grad=True, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, S, D, requires_grad=True, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, S, D, requires_grad=True, device="cuda", dtype=torch.float16)

out = tile_fmha_with_backward(q, k, v, scaling=1.0 / math.sqrt(D), is_causal=True)
loss = out.sum()
loss.backward()

CI Configuration

config:
  build: true
  # valid options are "ops" and "benchmark"
  test: ["ops", "benchmark"]

Checklist

  • Code formatted and imports sorted via repo specifications (./format.sh)
  • Documentation updated (if needed)
  • CI configuration reviewed


copy-pr-bot bot commented Feb 5, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.


qelk123 commented Feb 5, 2026

Hi @Weili-0234, thank you for this great contribution! The current performance data looks promising. Could you clarify whether these numbers are from the causal or non-causal kernel? To get a fuller picture, could you provide a more comprehensive comparison against the Triton tutorial's Fused Attention? Specifically, it would be helpful to see the benchmarks for: d={64, 128} and Causal={True, False}.

@hannahli-nv
Collaborator

/ok to test ac49a28

@Weili-0234
Contributor Author

> Hi @Weili-0234, thank you for this great contribution! The current performance data looks promising. Could you clarify whether these numbers are from the causal or non-causal kernel? To get a fuller picture, could you provide a more comprehensive comparison against the Triton tutorial's Fused Attention? Specifically, it would be helpful to see the benchmarks for: d={64, 128} and Causal={True, False}.

Thank you @qelk123 for the kind suggestion, I have updated the reported results accordingly, clearly denoting causality.

Also, I noticed that the test_ops CI failed because autotuning took too long (exceeding the 10-minute limit). Should I submit a version with a smaller tuning space? That would make the current implementation more under-tuned and less performant, though. What do you suggest?

As for the comparison with Triton, I can share a self-contained script below if you think it's helpful. The only reason I didn't include the official Triton tutorial version of FlashAttention is that I don't want to introduce an extra file dependency into the TileGym repo.


xjmxyt commented Feb 9, 2026

Hi @Weili-0234, thanks for the comparison data! We’d be happy to accept your contribution once the CI passes.

To resolve the test-op CI timeout, please implement the following:

Introduce DISABLE_AUTOTUNE: Add this environment variable to your implementation. When enabled, it should bypass the autotuning process and default to the 0th configuration.

CI Configuration: Enable DISABLE_AUTOTUNE for the test_op pytest suite to save time, but do not enable it for benchmark to ensure performance results remain accurate.
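The requested gate might look like the following sketch (illustrative names and config structure, not TileGym's actual autotuner API):

```python
import os

# Hypothetical tile-size configs; index 0 is the no-autotune default.
CONFIGS = [
    {"TILE_M": 64, "TILE_N": 64},
    {"TILE_M": 128, "TILE_N": 64},
    {"TILE_M": 128, "TILE_N": 128},
]

def select_config(benchmark_fn):
    """Return the 0th config when DISABLE_AUTOTUNE is set; otherwise pick
    the config that benchmark_fn scores lowest (i.e., fastest)."""
    if os.environ.get("DISABLE_AUTOTUNE", "0") not in ("0", "", "false"):
        return CONFIGS[0]
    return min(CONFIGS, key=benchmark_fn)
```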

Looking forward to the update!

@Weili-0234
Contributor Author

Hi @xjmxyt, thanks for the kind suggestion.

I just implemented the DISABLE_AUTOTUNE environment variable as requested. When set, it defaults to the zeroth config. I also updated the CI config so that DISABLE_AUTOTUNE is enabled for the test_op suite but not for the benchmark.

Would you be able to trigger the CI again? Happy to address anything else if needed. Thanks!


xjmxyt commented Feb 10, 2026

/ok to test 0668aae

@Weili-0234
Contributor Author

Hi, thanks for re-triggering the CI!

I'll fix the formatting issue in my next commit.

Regarding the test-ops timeout failure: I checked the logs and found that the CI timed out while still running existing tests like Test_BMM_FWD, so my flash-attention backward tests never got to run. Currently I only applied the DISABLE_AUTOTUNE environment variable to the flash-attention backward kernels, not to the existing kernels. Should I apply DISABLE_AUTOTUNE to the other kernels as well, or do you have another suggestion for fitting within the 10-minute limit?

Thanks for your patience in advance!


xjmxyt commented Feb 10, 2026

@Weili-0234 I see the logs; I think your tests have run
([gw4] PASSED tests/ops/test_attention_backward.py::Test_FMHA_Backward::test_op_backward_irregular_shapes[cutile-2-8-2047-128-True-torch.bfloat16]).
Could you please first reduce the number of test cases in your MR? For example, keep two cases for each test_op_xxx.

@Weili-0234
Contributor Author

Hi, thanks for the observation and suggestion; sorry I didn't read the logs carefully.

I've reduced the number of test cases as requested. Would you be able to trigger the CI again? Happy to address anything else if needed. Thanks!


xjmxyt commented Feb 11, 2026

/ok to test dc93e6e


xjmxyt commented Feb 11, 2026

/ok to test eecde06

@xjmxyt xjmxyt enabled auto-merge (squash) February 11, 2026 02:29
@xjmxyt xjmxyt merged commit c6a1acd into NVIDIA:main Feb 11, 2026
11 checks passed
Weili-0234 added a commit to Weili-0234/TileGym that referenced this pull request Feb 12, 2026
… benchmark (NVIDIA#49)

* feat: Add Flash Attention backward pass implementation

- Add fmha_fwd_kernel_with_lse: Forward kernel that saves log-sum-exp (LSE)
  to context for backward pass, compatible with torch.autograd
- Add fmha_bwd_preprocess_kernel: Computes Delta = rowsum(O * dO)
- Add fmha_bwd_dkdv_kernel: Computes dK and dV gradients with GQA support
- Add fmha_bwd_dq_kernel: Computes dQ gradient
- Add FlashAttentionFunction: torch.autograd.Function wrapper that saves
  LSE during forward and uses it in backward
- Add tile_fmha_with_backward: Public API for training with autograd support
- Add tile_fmha_functional: Auto-selects inference or training mode
- Add autotune configs for backward kernels

* test: Add comprehensive tests for Flash Attention backward pass

- Add Test_FMHA_Backward class with test cases for:
  - Regular shapes (power of 2): various batch/head/seq/dim combinations
  - Irregular shapes (non-power of 2): odd/prime sequence lengths
  - Corner cases: single batch, single head, many heads, short sequences
  - Forward with LSE matches reference
  - Functional API inference and training modes
- Add Test_FMHA_Backward_GQA class for Grouped Query Attention tests:
  - GQA ratios: 2:1, 4:1, 8:1 (Llama-style)
  - Multi-Query Attention (all Q heads share 1 KV head)
- Add Test_FMHA_Backward_Numerical for torch.autograd.gradcheck
- All test functions follow test_op_* naming convention

* bench: Add Flash Attention backward benchmarks

- Add bench_attention_backward.py: Benchmarks fwd/bwd/fwd+bwd modes
  - Compares CuTile against PyTorch SDPA backends (Flash, MemEff, Math)
  - Tests d=64 and d=128 head dimensions
  - Tests causal and non-causal modes
  - Plot names include -GBps suffix for CI recognition

* add formatting via ./format.sh

* additional formatting

* Introduce DISABLE_AUTOTUNE as requested by PR NVIDIA#49 reviewers

* fix flash attn bwd formatting

* reduce test cases to avoid CI timeout