
Conversation

@micmelesse
Contributor

@micmelesse micmelesse commented Jan 14, 2026

This PR is a follow-up to #1203 and #1610. It adds the following features:

  • Flash Attention V3 initial support
  • Paged Attention support
  • FP8 Performance Tuning
  • Fused Backward Pass

We increase the number of passing tests from 30k to 60k. The v2 test results on MI350 are shown below:

[image: v2 test results on MI350]

We have some partial work on sliding windows that is not finished yet. We have also cleaned up the package, fixed several bugs, and added a test for v3 at hopper/test_flash_attn_triton_amd.py. The v3 test results are shown in the image below.

[image: v3 test results]
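For reference, the features listed above are reached through the standard flash_attn Python API. The snippet below is a minimal, illustrative sketch rather than code from this PR: it assumes the FLASH_ATTENTION_TRITON_AMD_ENABLE environment variable (the one used elsewhere in this thread) selects the Triton AMD backend, and uses arbitrary shapes and dtypes.

import os
os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"  # set before importing flash_attn

import torch
from flash_attn import flash_attn_func

# Arbitrary example shapes: (batch, seqlen, nheads, headdim), bf16 on a ROCm device.
batch, seqlen, nheads, headdim = 2, 1024, 8, 128
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

# Causal forward pass; calling backward() exercises the backward path
# (the fused backward pass is one of the features this PR adds).
out = flash_attn_func(q, k, v, causal=True)
out.sum().backward()
print(out.shape, q.grad.shape)

The paged-attention and KV-cache paths would presumably be exercised through flash_attn_with_kvcache (with a block_table argument) rather than flash_attn_func; that call is not shown here.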

micmelesse and others added 23 commits January 12, 2026 22:08
* Fused with Good perf and stride fixed

Fix fused bugs

isolate failing case

fix bug

bring back test cases

rm split impl in fused

use exp2 is global variable now

try oom fix

save

make fused the default

limit to reproduce failure

return default to split

fix head size bug

use exp2 back to true

* new grid

* BLK_SLICE_FACTOR = 1

* add tflops

* new commit

* test in parrallel

* strides added by jusson

* disable alibi

* fix bugs again

* default to fused

* add bwd options for varlen

* backend filter

* default to jingning and batch 4

* best fwd config

* fix TRITON_PRINT_AUTOTUNING flag bug

* tune

* Tuning fwd prefill

* add if else

* use flag

* Minor mask fix

* FLIP GRID

* use best config for default

* print when autotuning

* test bfloat16

* fix k and v stride bugs

* skip bfloat16

* test kvpacked

* disable internal tests

* pick default config based on arch

* Add alibi in the new bwd kernel (#139)

* enable alibi for jinging kernel

enable alibi for jinging kernel

match

* save bad configs

* fix alibi and causal bug

* disable autotune by default

* auto tune when benching is good

* set best config

* remove env var

* Update amd_tests.yml

* upgrad to triton==3.3.0

* increase shm

* use 64 x 64 for now

* save

* handle 1d alibi

* Add fp8 to fused kernel (#140)

* fp8 stuff

find test case

compute delta fp8

basic fp8 config passing

non causal path works

* isolate bad case

* fix fp8 bug

* didnot fix fp8 bug

* back to failing test

* fp8 tests passing

* skip

* skip ref tests

---------

Co-authored-by: Aliasger Zaidy <aliasger.zaidy@amd.com>
* save

* rm keys

* fix keys

* use GHA_RENDER_DEVICES

* normal docker
* add round multiple

* fix fwd

* backward fix

* use rounded lse flag

* passing ROUNDED_LSE

* default is new rounded mode

* rename to fused_atmoics and fused_no_atomics

* add test for torch_compile

* add varlen torch compile test

* add old one kernel for ref

* fix varlen mismatch bug

* fix shape issue in varlen but mismatch

* sync torch compile kernel launch

* simple varlen test

* add debug code

* rm old

* ignore old impls

* DEBUG flag works in interface only

* ref uses the righ shape for lse

* rm oldest bwd kernel

* fix typo

* fix varlen bug

* fix bug. Get info from q for now

* simple shape and stride checkout

* add more tests

* test kvcache

* kvcache safe

* match case

* fix segfault due to bad return_softmax

* run bench

* run seperate for the main functions

* just output benchmark

* default csv format and time stamp files

* non verbsoe bench
* Compress SWA work

test case

set up debug inputs

add fwd ref

one mask ref

fwd first pass

save

ref doesnot work for bigger seqlens

save new version

some causal cases failing

found bad cases

working new attn

new atten works

new attn_fwd works

reorg n_extra_tokens

use seqlen_delta_qk

ref fwd works

add sliding window to bwd ref

test kvcache

decode ref work with everything except sliding window

add debug code for 12 failing sliding window cases for decode

attention_decode_forward_ref_impl mostly works except for alibi

fix alibi in attention_decode_forward_ref_impl

ref works with normal, varlen & kvcache

move stuff around

figure out masking

old attn inner

two inner functions

remove load_fn

do Lk - Lq like ref

unify IS_CAUSAL code in epilogue

clean up

add args

rm inference stuff

simplify compute_masking

simpler compute mask

stub out returning front masking variables

remove pointer pass

compute ptrs inloop

compute block min and max

window stub inside inner mask loop

trying to use attn_fwd_mask causes issues

fix compiler bug when front masking

gen specifc types

add sliding window and debug statements

use identity for v

add more taste cases

add comments

save

use k_max_token for clarity

disable debug configs

basic NON-CAUSAL SLIDING WINDOW

non causal sliding window works on the all the shapes

non sliding window working in fwd

clean up fused bwd

seperate old fwd_prefill

move configs to utils.py

* fix bwd ref bug

* skip local cases so that fa output

* no sliding window causal green

* add backward test skip for sliding window

* clean reduce in fwd_kvcache. no is_CASUAL branching

* add kvcache masking

* kvcache working

* fix some bugs in test.py

* clean up
* Compress segfault work

fix backward segfault

rework offset

ignore .profile

ignore .analysis

save

* assert the kernel launch device and tensor devices are the same

* fix failing asserts

* add asserts to fwd
* add aiter code

* remove aiter stuff

* sliding window non causal masking works

* causal and sliding window block masking

* extract common

* clean up typo

* helper for swa

* ignore .amd

* fix last block bug
* Compress PA work

narrow pa test

ref works on most cases

inplace ref with new_kv

inplace paged attention

add pa ref

save pa

basic  paged works

save

fix swa + causal in pa. Also new_kv only on pa path

passing

build fa v3

import interface from fa v3

copy fa tests

use v3 api

clean up

rename to match old test

support different head sizes

remove fp8

basisc passing v3 cases

test_flash_attn_varlen_output v3 working

isolate bad case for kvcache

case passing

save

use decode is seqused/ cacheseql is given

use decode if not varlen

basci kvcache v3 working

kvcache enable more cases

detect kvcache case if seqused_q is non and sequese_k is not None

skip failing test

find fp8 failing case

mha fp8 works

fix fp8 MQA/GQA bug

clean up

more clean up

clean up more

don't need fp8 dead code

remove train code with fp8 stuff

fp8 working in kvcache

paged + fp8 seems to be working

new_kv allowed

* clean up

* skip hopper race test

* clean up more

* fix paged + alibi

* similar inner paged api

* unify _attn_fwd_inner
* clean up v2 interface

* assert fp8 scale shapes

* rotary working

* move rotary to impl layers

* remove einops

* enable rotarry in v3

* create interface

* fix descale assert

* unify bwd

* lint from aiter

* clean fp8 api

* add api change

* assert shapes for v2

* remove ref and bench.py

* remove metadata class and clean up

* bwd_prefill

* one bwd.py

* rename

* lint
* check cu count for gfx942

* create get_cu_count

* update repo root

* update forward tune

* clean up load

* use float8_e4m3fnuz

* save

* show bwd mode

* recommend fp8

* use torch.float32 for fp8 kernel

* add both best fp16 and fp8 config

* tune fp8 backward

* descale factors should be b, hk

* fp8 bwd working on all primus configs

* tune bwd configs

* fa v3 tests passing

* better warning

* clean up bwd launcher

* v3 passing

* tune more

* improve perf

* clean up

* lint

* clean

* start tuning gfx950

* tune non causal path

* fix bug

* save

* Skip configs where BLOCK_M2 % BLOCK_N2 != 0

* skip more

* stop tuning

* fix varlen bug

* fix dropout & causal/swa segfault
@micmelesse micmelesse marked this pull request as ready for review January 15, 2026 03:03
@micmelesse
Contributor Author

@tridao Can you take a look at this?

@RegiaYoung

I ran a fine-tuning program using #2178 on a 7900 XT (gfx1100). Its backward computation performance was slightly weaker than Dao-AILab:main (~5-8%), and NaN values appeared after a few training steps (this might be due to an issue with my offloading framework); this did not occur with the main branch or previously on NVIDIA GPUs. I'm not sure whether the problem is the unsupported GPU or the new fused backward pass code.

@tridao
Member

tridao commented Jan 16, 2026

The interface looks fine to me. Are there AMD folks who can review, @rocking5566?

@micmelesse
Contributor Author

I ran a fine-tuning program using #2178 on a 7900 XT (gfx1100). Its backward computation performance was slightly weaker than Dao-AILab:main (~5-8%), and NaN values appeared after a few training steps (this might be due to an issue with my offloading framework); this did not occur with the main branch or previously on NVIDIA GPUs. I'm not sure whether the problem is the unsupported GPU or the new fused backward pass code.

I am looking into this. I will post here when I have something.

@rocking5566
Contributor

rocking5566 commented Jan 18, 2026

The interface, tests, and setup.py look good to me, but I am not familiar with the details of the Triton kernels. @tianwyan, does it look good to you?

@tianwyan

The interface, tests, and setup.py look good to me, but I am not familiar with the details of the Triton kernels. @tianwyan, does it look good to you?

Working with the PR author. Thanks!

@0xDELUXA

0xDELUXA commented Jan 20, 2026

I was able to build Flash Attention V3 on RDNA4 (gfx1200) on Windows using TheRock ROCm/torch and triton-windows. I made hopper\setup.py Windows/RDNA4-compatible with Claude's help.
My steps were:

python -m venv venv
.\venv\Scripts\Activate.ps1
pip install --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/ --pre torch torchaudio torchvision rocm[devel]
rocm-sdk init
pip install triton-windows

I also needed the following environment variables in the venv:

$ROCM_ROOT = (rocm-sdk path --root).Trim()
$ROCM_BIN = (rocm-sdk path --bin).Trim()
$env:ROCM_HOME = $ROCM_ROOT
$env:PATH = "$ROCM_ROOT\lib\llvm\bin;$ROCM_BIN;$env:PATH"
$env:CC = "clang-cl"
$env:CXX = "clang-cl"
$env:DISTUTILS_USE_SDK = "1"
$env:FLASH_ATTENTION_TRITON_AMD_ENABLE = "TRUE"

Then:

cd hopper
pip install --no-build-isolation -v .

It built without any errors.

Here is the modified hopper\setup.py file:
https://gist.github.com/0xDELUXA/2341b59a0bb269a4af393422a76e4b09

I then benchmarked it with this script:
https://gist.github.com/0xDELUXA/c3f68ac1e493feafc53b0f5c637c744e
So:

$env:PYTHONPATH = $PWD
python FA-3_tests.py

The output is:

============================================================
  Summary
============================================================
Passed: 7/7
✓ PASS - Installation
✓ PASS - Device
✓ PASS - Basic Attention
✓ PASS - Causal Attention
✓ PASS - Data Types
✓ PASS - Head Dimensions
✓ PASS - Sequence Benchmark

All tests passed!

I've tried running test_flash_attn_triton_amd.py. It's a very extensive script, so I enabled verbose logging. A lot of tests pass, but some fail with errors like:

flash-attention\flash_attn\flash_attn_triton_amd\fwd_prefill.py:996:0: error: Failures have been detected while processing an MLIR pass pipeline
flash-attention\flash_attn\flash_attn_triton_amd\fwd_prefill.py:996:0: note: Pipeline failed while executing [`TritonAMDFoldTrueCmpI` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.`

I think these are Windows-specific issues with Triton.

Is there any possibility to make FA-3 (and this PR) compatible with Windows/RDNA4?

Edit:
I just found out that the original hopper\setup.py also works and avoids errors when forcing the build with:
$env:FLASH_ATTENTION_FORCE_BUILD="TRUE"; pip install --no-build-isolation -v .
Those triton-windows issues have also been resolved.
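As a quick sanity check after a build like the one above, a single forward call against the v3 interface can be run. This is a sketch under two assumptions: that the hopper build exposes the flash_attn_interface module as in upstream FA3, and that the return value may be either the output tensor or a tuple depending on the version, so both cases are handled.

import torch
from flash_attn_interface import flash_attn_func  # assumes the hopper build exposes this module

q = torch.randn(1, 256, 8, 128, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# flash_attn_func has returned either the output tensor or a (out, lse) tuple
# across FA3 versions, so unwrap defensively.
result = flash_attn_func(q, k, v, causal=True)
out = result[0] if isinstance(result, (tuple, list)) else result
print("forward ok:", out.shape, out.dtype)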

@micmelesse
Contributor Author

micmelesse commented Jan 21, 2026

@0xDELUXA

Thanks for trying the PR out. The errors you're hitting (TritonAMDFoldTrueCmpI pipeline failures) are in the triton-windows build of the Triton compiler, not in the Flash Attention kernels. I'd suggest filing an issue with triton-windows.

Windows support is out of scope for this PR since it would be a major change affecting the whole repo. If you'd like to pursue it, I'd recommend opening a separate issue to discuss with the flash attention maintainers.

Regarding RDNA4, we already have RDNA configs that should work on Linux once this PR is merged. More specific per-card tuning could come in future PRs.

@0xDELUXA

0xDELUXA commented Jan 22, 2026

Thanks for trying the PR out. The errors you're hitting (TritonAMDFoldTrueCmpI pipeline failures) are in the triton-windows build of the Triton compiler, not in the Flash Attention kernels. I'd suggest filing an issue with triton-windows.

Thanks for the suggestion - the errors have already been patched.

Windows support is out of scope for this PR since it would be a major change affecting the whole repo. If you'd like to pursue it, I'd recommend opening a separate issue to discuss with the flash attention maintainers.

From my testing, I can say it's already working on Windows when built with $env:FLASH_ATTENTION_FORCE_BUILD="TRUE"; pip install --no-build-isolation -v . (no extra patches needed), though some workarounds are required to avoid Windows-specific issues (e.g. 0xC00000FD stack overflows).

Regarding RDNA4, we already have RDNA configs that should work on Linux once this PR is merged. More specific per-card tuning could come in future PRs.

Got it. Looking forward to any RDNA4-related developments.

@tianwyan

The interface, tests, and setup.py look good to me, but I am not familiar with the details of the Triton kernels. @tianwyan, does it look good to you?

Working with the PR author. Thanks!

No consistent errors found on my side, even when merged with my PR #2147. I am okay with it.

@micmelesse
Contributor Author

micmelesse commented Jan 23, 2026

@RegiaYoung Can you check the latest commit? I ran the full test suite on it. All the tests pass on a CDNA device (MI350) and almost all of them pass on an RDNA device (W7800), but there is some flakiness that seems due to lower-level compiler issues. There is not much I can do about that at the kernel level here, but I will reach out to the people working on the Triton compiler at AMD for a fix. That will be a longer-term effort.

@micmelesse
Contributor Author

micmelesse commented Jan 23, 2026

The interface, tests, and setup.py look good to me, but I am not familiar with the details of the Triton kernels. @tianwyan, does it look good to you?

Working with the PR author. Thanks!

No consistent errors found on my side, even when merged with my PR #2147. I am okay with it.

@tridao Can we merge this? There is some flakiness on RDNA devices, but it is not at the kernel level. I will inform the compiler team of the issue.

@tridao tridao merged commit 701ebe0 into Dao-AILab:main Jan 28, 2026
@micmelesse micmelesse deleted the main_perf_rebase branch January 28, 2026 22:08
elewarr pushed a commit to elewarr/flash-attention that referenced this pull request Feb 4, 2026