
Commit 701ebe0

micmelesse, azaidy, and Chi-Chu319 authored
[AMD] Triton Backend for ROCm #3 (#2178)
* Fused Bwd (#137) * Fused with Good perf and stride fixed Fix fused bugs isolate failing case fix bug bring back test cases rm split impl in fused use exp2 is global variable now try oom fix save make fused the default limit to reproduce failure return default to split fix head size bug use exp2 back to true * new grid * BLK_SLICE_FACTOR = 1 * add tflops * new commit * test in parrallel * strides added by jusson * disable alibi * fix bugs again * default to fused * add bwd options for varlen * backend filter * default to jingning and batch 4 * best fwd config * fix TRITON_PRINT_AUTOTUNING flag bug * tune * Tuning fwd prefill * add if else * use flag * Minor mask fix * FLIP GRID * use best config for default * print when autotuning * test bfloat16 * fix k and v stride bugs * skip bfloat16 * test kvpacked * disable internal tests * pick default config based on arch * Add alibi in the new bwd kernel (#139) * enable alibi for jinging kernel enable alibi for jinging kernel match * save bad configs * fix alibi and causal bug * disable autotune by default * auto tune when benching is good * set best config * remove env var * Update amd_tests.yml * upgrad to triton==3.3.0 * increase shm * use 64 x 64 for now * save * handle 1d alibi * Add fp8 to fused kernel (#140) * fp8 stuff find test case compute delta fp8 basic fp8 config passing non causal path works * isolate bad case * fix fp8 bug * didnot fix fp8 bug * back to failing test * fp8 tests passing * skip * skip ref tests --------- Co-authored-by: Aliasger Zaidy <aliasger.zaidy@amd.com> * head, seq, batch (#141) * Fix keys (#144) * save * rm keys * fix keys * use GHA_RENDER_DEVICES * normal docker * Pad LSE (#148) * add round multiple * fix fwd * backward fix * use rounded lse flag * passing ROUNDED_LSE * default is new rounded mode * rename to fused_atmoics and fused_no_atomics * add test for torch_compile * add varlen torch compile test * add old one kernel for ref * fix varlen mismatch bug * fix shape issue in varlen but mismatch * sync torch compile kernel launch * simple varlen test * add debug code * rm old * ignore old impls * DEBUG flag works in interface only * ref uses the righ shape for lse * rm oldest bwd kernel * fix typo * fix varlen bug * fix bug. 
Get info from q for now * simple shape and stride checkout * add more tests * test kvcache * kvcache safe * match case * fix segfault due to bad return_softmax * run bench * run seperate for the main functions * just output benchmark * default csv format and time stamp files * non verbsoe bench * Sliding Window Forward (#151) * Compress SWA work test case set up debug inputs add fwd ref one mask ref fwd first pass save ref doesnot work for bigger seqlens save new version some causal cases failing found bad cases working new attn new atten works new attn_fwd works reorg n_extra_tokens use seqlen_delta_qk ref fwd works add sliding window to bwd ref test kvcache decode ref work with everything except sliding window add debug code for 12 failing sliding window cases for decode attention_decode_forward_ref_impl mostly works except for alibi fix alibi in attention_decode_forward_ref_impl ref works with normal, varlen & kvcache move stuff around figure out masking old attn inner two inner functions remove load_fn do Lk - Lq like ref unify IS_CAUSAL code in epilogue clean up add args rm inference stuff simplify compute_masking simpler compute mask stub out returning front masking variables remove pointer pass compute ptrs inloop compute block min and max window stub inside inner mask loop trying to use attn_fwd_mask causes issues fix compiler bug when front masking gen specifc types add sliding window and debug statements use identity for v add more taste cases add comments save use k_max_token for clarity disable debug configs basic NON-CAUSAL SLIDING WINDOW non causal sliding window works on the all the shapes non sliding window working in fwd clean up fused bwd seperate old fwd_prefill move configs to utils.py * fix bwd ref bug * skip local cases so that fa output * no sliding window causal green * add backward test skip for sliding window * clean reduce in fwd_kvcache. no is_CASUAL branching * add kvcache masking * kvcache working * fix some bugs in test.py * clean up * Fix Device Segfault (#152) * Compress segfault work fix backward segfault rework offset ignore .profile ignore .analysis save * assert the kernel launch device and tensor devices are the same * fix failing asserts * add asserts to fwd * Fix SDMASK bug * Log triton, torch and fa version * Fix fp8 import issues * fix docs (#154) * Sliding Window block classification logic (#155) * add aiter code * remove aiter stuff * sliding window non causal masking works * causal and sliding window block masking * extract common * clean up typo * helper for swa * ignore .amd * fix last block bug * Enable FA V3 (#157) * Compress PA work narrow pa test ref works on most cases inplace ref with new_kv inplace paged attention add pa ref save pa basic paged works save fix swa + causal in pa. 
Also new_kv only on pa path passing build fa v3 import interface from fa v3 copy fa tests use v3 api clean up rename to match old test support different head sizes remove fp8 basisc passing v3 cases test_flash_attn_varlen_output v3 working isolate bad case for kvcache case passing save use decode is seqused/ cacheseql is given use decode if not varlen basci kvcache v3 working kvcache enable more cases detect kvcache case if seqused_q is non and sequese_k is not None skip failing test find fp8 failing case mha fp8 works fix fp8 MQA/GQA bug clean up more clean up clean up more don't need fp8 dead code remove train code with fp8 stuff fp8 working in kvcache paged + fp8 seems to be working new_kv allowed * clean up * skip hopper race test * clean up more * fix paged + alibi * similar inner paged api * unify _attn_fwd_inner * AITER integration (#159) * clean up v2 interface * assert fp8 scale shapes * rotary working * move rotary to impl layers * remove einops * enable rotarry in v3 * create interface * fix descale assert * unify bwd * lint from aiter * clean fp8 api * add api change * assert shapes for v2 * remove ref and bench.py * remove metadata class and clean up * bwd_prefill * one bwd.py * rename * lint * add bwd_change (#156) * Tune FP8 Perf (#160) * check cu count for gfx942 * create get_cu_count * update repo root * update forward tune * clean up load * use float8_e4m3fnuz * save * show bwd mode * recommend fp8 * use torch.float32 for fp8 kernel * add both best fp16 and fp8 config * tune fp8 backward * descale factors should be b, hk * fp8 bwd working on all primus configs * tune bwd configs * fa v3 tests passing * better warning * clean up bwd launcher * v3 passing * tune more * improve perf * clean up * lint * clean * start tuning gfx950 * tune non causal path * fix bug * save * Skip configs where BLOCK_M2 % BLOCK_N2 != 0 * skip more * stop tuning * fix varlen bug * fix dropout & causal/swa segfault * update the to machine new changes * save * fix more bugs * remove random seed * clean up * update readme * print tensor stats for debug * disable sliding window tests * add rdna configs * fix k partial bug * fix block_size_n bug * fix type check bug --------- Co-authored-by: Aliasger Zaidy <aliasger.zaidy@amd.com> Co-authored-by: Tianxing Wu <tianxing.wu@amd.com>
1 parent 99589e5 commit 701ebe0

28 files changed: +10871 −13189 lines changed

README.md

Lines changed: 20 additions & 47 deletions
@@ -129,74 +129,47 @@ FlashAttention-2 ROCm CK backend currently supports:
 3. Both forward's and backward's head dimensions up to 256.

 #### Triton Backend
-The Triton implementation of the [Flash Attention v2](https://tridao.me/publications/flash2/flash2.pdf) is currently a work in progress.
+The Triton implementation of [Flash Attention](https://tridao.me/publications/flash2/flash2.pdf) supports AMD's CDNA (MI200, MI300) and RDNA GPUs using fp16, bf16, and fp32 datatypes. It provides forward and backward passes with causal masking, variable sequence lengths, arbitrary Q/KV sequence lengths and head sizes, MQA/GQA, dropout, rotary embeddings, ALiBi, paged attention, and FP8 (via the Flash Attention v3 interface). Sliding window attention is currently a work in progress.

-It supports AMD's CDNA (MI200, MI300) and RDNA GPU's using fp16, bf16 and fp32 datatypes.
-
-These features are supported in Fwd and Bwd
-1) Fwd and Bwd with causal masking
-2) Variable sequence lengths
-3) Arbitrary Q and KV sequence lengths
-4) Arbitrary head sizes
-5) Multi and grouped query attention
-6) Dropout
-7) Rotary embeddings
-8) ALiBi
-
-We are working on the following things
-1) Paged Attention
-2) Sliding Window
-3) FP8
-4) Performance Improvements
-
-##### Getting Started
-To get started with the triton backend for AMD, follow the steps below.
-
-First install the torch for ROCm from https://pytorch.org/get-started/locally/ if it is not installed. The torch and triton will be installed.
-
-Then install Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`.
-
-```
+To install, first get PyTorch for ROCm from https://pytorch.org/get-started/locally/, then install Triton and Flash Attention:
+```sh
+pip install triton==3.5.1
 cd flash-attention
 FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install
 ```

-To test that things are working, you can run our tests. These tests take hours so you don't need to run the full thing.
-```
+To run the tests (note: full suite takes hours):
+```sh
 FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pytest tests/test_flash_attn_triton_amd.py
 ```

-You can use autotune for better performance by using this flag `FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE"`
-```
-FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE" python $PATH_TO_CODE
-```
+For better performance, enable autotune with `FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE"`.

-###### Docker
-You can also use the Dockerfile below which does the above steps on top of the latest rocm/pytorch image.
-```
+For a quick start with Docker:
+```dockerfile
 FROM rocm/pytorch:latest

 WORKDIR /workspace

-# install flash attention
-ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
+# install triton
+RUN pip install triton==3.5.1

-RUN git clone https://github.com/ROCm/flash-attention.git &&\
+# build flash attention with triton backend
+RUN git clone https://github.com/Dao-AILab/flash-attention &&\
 cd flash-attention &&\
-python setup.py install
+FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install

 # set working dir
 WORKDIR /workspace/flash-attention
-```

-To build the docker file
-```
-docker build -t fa_triton .
+# set env variable to use triton backend
+ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
 ```

-To run the docker image
-```
-docker run -it --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host --shm-size 16G --device=/dev/kfd --device=/dev/dri fa_triton
+Build and run:
+```sh
+docker build -t flash-attn-triton .
+docker run -it --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host --shm-size 16G --device=/dev/kfd --device=/dev/dri flash-attn-triton
 ```

 ## How to use FlashAttention
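
The installation and test steps above never show a call site, so here is a minimal smoke test (not part of this commit) for the Triton backend. It assumes a ROCm GPU, visible as `"cuda"` through HIP, and the standard `flash_attn_func` API; the environment variable must be set before `flash_attn` is imported because the backend is chosen at import time (see the `flash_attn_interface.py` diff below).

```python
# Sketch only: smoke-test the Triton backend after installing as described above.
# Assumes a ROCm GPU and the standard flash_attn_func API.
import os
os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"      # must be set before importing flash_attn
# os.environ["FLASH_ATTENTION_TRITON_AMD_AUTOTUNE"] = "TRUE"  # optional: autotune for better performance

import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 1024, 8, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

out = flash_attn_func(q, k, v, causal=True)   # output shape: (batch, seqlen, nheads, headdim)
out.sum().backward()                          # exercise the backward pass too
print(out.shape, q.grad.shape)
```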

flash_attn/flash_attn_interface.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
 # We need to import the CUDA kernels after importing torch
 USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"
 if USE_TRITON_ROCM:
-    from .flash_attn_triton_amd import interface_fa as flash_attn_gpu
+    from .flash_attn_triton_amd import flash_attn_2 as flash_attn_gpu
 else:
     import flash_attn_2_cuda as flash_attn_gpu

flash_attn/flash_attn_triton_amd/Dockerfile

Lines changed: 0 additions & 17 deletions
This file was deleted.

flash_attn/flash_attn_triton_amd/README.md

Lines changed: 0 additions & 113 deletions
This file was deleted.
flash_attn/flash_attn_triton_amd/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+from . import interface_v2 as flash_attn_2
+from . import interface_v3 as flash_attn_3
+
+__all__ = ["flash_attn_2", "flash_attn_3"]
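
For orientation (not part of the diff itself): this new `__init__.py` is what makes the one-line change in `flash_attn_interface.py` above resolve, since `flash_attn_2` and `flash_attn_3` become package-level names. The following sketch shows how they are reached; only the module names come from the diff, and the functions inside `interface_v2`/`interface_v3` are not shown here.

```python
# Sketch only: module names taken from the diff above; everything else is an assumption.
import os
os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"  # so flash_attn_interface.py picks the Triton path

from flash_attn.flash_attn_triton_amd import flash_attn_2, flash_attn_3

# flash_attn_2 aliases interface_v2 and is what flash_attn_interface.py binds as flash_attn_gpu;
# flash_attn_3 aliases interface_v3, backing the FA v3 entry points this PR adds.
print(flash_attn_2.__name__, flash_attn_3.__name__)
```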
