Commit 701ebe0
* Fused Bwd (#137)
* Fused with Good perf and stride fixed
Fix fused bugs
isolate failing case
fix bug
bring back test cases
rm split impl in fused
use_exp2 is a global variable now
try oom fix
save
make fused the default
limit to reproduce failure
return default to split
fix head size bug
use_exp2 back to true
* new grid
* BLK_SLICE_FACTOR = 1
* add tflops
* new commit
* test in parallel
* strides added by jusson
* disable alibi
* fix bugs again
* default to fused
* add bwd options for varlen
* backend filter
* default to jingning and batch 4
* best fwd config
* fix TRITON_PRINT_AUTOTUNING flag bug
* tune
* Tuning fwd prefill
* add if else
* use flag
* Minor mask fix
* FLIP GRID
* use best config for default
* print when autotuning
* test bfloat16
* fix k and v stride bugs
* skip bfloat16
* test kvpacked
* disable internal tests
* pick default config based on arch
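
The last item ("pick default config based on arch") presumably keys off the reported GPU arch string; a minimal sketch, with illustrative block sizes rather than the repo's tuned values:

```python
# Hypothetical sketch: pick a default Triton config from the GPU arch.
# `gcnArchName` is what ROCm PyTorch reports; block sizes are illustrative.
import torch
import triton

def pick_default_config():
    arch = torch.cuda.get_device_properties(0).gcnArchName  # e.g. "gfx942:sramecc+:xnack-"
    if arch.startswith("gfx942"):
        return triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=4, num_stages=1)
    # conservative fallback for other archs
    return triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1)
```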
* Add alibi in the new bwd kernel (#139)
* enable alibi for jingning kernel
match
* save bad configs
* fix alibi and causal bug
* disable autotune by default
* auto tune when benching is good
* set best config
* remove env var
* Update amd_tests.yml
* upgrade to triton==3.3.0
* increase shm
* use 64 x 64 for now
* save
* handle 1d alibi
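
To illustrate "handle 1d alibi": flash-attention accepts alibi slopes as (nheads,) or (batch, nheads), so a 1-D input has to be broadcast across the batch. A sketch, with the helper name assumed:

```python
# Sketch (helper name assumed): normalize alibi slopes to (batch, nheads).
import torch

def normalize_alibi_slopes(alibi_slopes: torch.Tensor, batch: int, nheads: int) -> torch.Tensor:
    if alibi_slopes.dim() == 1:
        assert alibi_slopes.shape == (nheads,)
        return alibi_slopes.unsqueeze(0).expand(batch, nheads)  # broadcast view
    assert alibi_slopes.shape == (batch, nheads)
    return alibi_slopes
```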
* Add fp8 to fused kernel (#140)
* fp8 stuff
find test case
compute delta fp8
basic fp8 config passing
non causal path works
* isolate bad case
* fix fp8 bug
* did not fix fp8 bug
* back to failing test
* fp8 tests passing
* skip
* skip ref tests
---------
Co-authored-by: Aliasger Zaidy <aliasger.zaidy@amd.com>
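
A rough sketch of the fp8 plumbing in #140, assuming per-tensor scales and the MI300-family float8_e4m3fnuz format (the later tuning PR moves descale factors to per-(batch, heads_k)):

```python
# Sketch of per-tensor fp8 quantization with a descale factor; dtype choice
# and names are assumptions.
import torch

FP8_DTYPE = torch.float8_e4m3fnuz
FP8_MAX = torch.finfo(FP8_DTYPE).max

def quantize_fp8(x: torch.Tensor):
    descale = (x.abs().amax() / FP8_MAX).clamp(min=1e-12)  # per-tensor factor
    x_fp8 = (x / descale).to(FP8_DTYPE)
    return x_fp8, descale.float()  # the kernel re-applies `descale` after each dot
```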
* head, seq, batch (#141)
* Fix keys (#144)
* save
* rm keys
* fix keys
* use GHA_RENDER_DEVICES
* normal docker
* Pad LSE (#148)
* add round multiple
* fix fwd
* backward fix
* use rounded lse flag
* passing ROUNDED_LSE
* default is new rounded mode
* rename to fused_atomics and fused_no_atomics
* add test for torch_compile
* add varlen torch compile test
* add old one kernel for ref
* fix varlen mismatch bug
* fix shape issue in varlen, but mismatch remains
* sync torch compile kernel launch
* simple varlen test
* add debug code
* rm old
* ignore old impls
* DEBUG flag works in interface only
* ref uses the right shape for lse
* rm oldest bwd kernel
* fix typo
* fix varlen bug
* fix bug. Get info from q for now
* simple shape and stride checks
* add more tests
* test kvcache
* kvcache safe
* match case
* fix segfault due to bad return_softmax
* run bench
* run separately for the main functions
* just output benchmark
* default csv format and timestamped files
* non-verbose bench
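
The "Pad LSE" change (#148) rounds the LSE buffer's seqlen up to a block multiple, presumably so the kernel can store without bounds checks; a minimal sketch, with the multiple (128) and names assumed:

```python
# Sketch of the rounded-LSE allocation; multiple and names are assumptions.
import torch

def round_multiple(x: int, m: int) -> int:
    return (x + m - 1) // m * m

def alloc_lse(batch, nheads, seqlen_q, device, rounded=True):
    n = round_multiple(seqlen_q, 128) if rounded else seqlen_q
    return torch.empty(batch, nheads, n, device=device, dtype=torch.float32)
```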
* Sliding Window Forward (#151)
* Compress SWA work
test case
set up debug inputs
add fwd ref
one mask ref
fwd first pass
save
ref does not work for bigger seqlens
save new version
some causal cases failing
found bad cases
working new attn
new attn works
new attn_fwd works
reorg n_extra_tokens
use seqlen_delta_qk
ref fwd works
add sliding window to bwd ref
test kvcache
decode ref works with everything except sliding window
add debug code for 12 failing sliding window cases for decode
attention_decode_forward_ref_impl mostly works except for alibi
fix alibi in attention_decode_forward_ref_impl
ref works with normal, varlen & kvcache
move stuff around
figure out masking
old attn inner
two inner functions
remove load_fn
do Lk - Lq like ref
unify IS_CAUSAL code in epilogue
clean up
add args
rm inference stuff
simplify compute_masking
simpler compute mask
stub out returning front masking variables
remove pointer pass
compute ptrs in loop
compute block min and max
window stub inside inner mask loop
trying to use attn_fwd_mask causes issues
fix compiler bug when front masking
gen specific types
add sliding window and debug statements
use identity for v
add more test cases
add comments
save
use k_max_token for clarity
disable debug configs
basic NON-CAUSAL SLIDING WINDOW
non causal sliding window works on all the shapes
non-sliding-window path working in fwd
clean up fused bwd
separate old fwd_prefill
move configs to utils.py
* fix bwd ref bug
* skip local cases so that fa output matches
* no sliding window causal green
* add backward test skip for sliding window
* clean reduce in fwd_kvcache. No IS_CAUSAL branching
* add kvcache masking
* kvcache working
* fix some bugs in test.py
* clean up
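
For reference, the sliding-window visibility rule used throughout #151 (bottom-right aligned, via the seqlen_delta_qk mentioned above) can be sketched as:

```python
# Sketch: query row i is shifted by seqlen_k - seqlen_q, then key j is visible
# when rows - window_left <= j <= rows + window_right.
import torch

def sliding_window_mask(seqlen_q, seqlen_k, window_left, window_right, device="cpu"):
    rows = torch.arange(seqlen_q, device=device)[:, None] + (seqlen_k - seqlen_q)
    cols = torch.arange(seqlen_k, device=device)[None, :]
    return (cols >= rows - window_left) & (cols <= rows + window_right)  # True = visible
```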
* Fix Device Segfault (#152)
* Compress segfault work
fix backward segfault
rework offset
ignore .profile
ignore .analysis
save
* assert the kernel launch device and tensor devices are the same
* fix failing asserts
* add asserts to fwd
* Fix SDMASK bug
* Log triton, torch and fa version
* Fix fp8 import issues
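
The assert added in #152 ("assert the kernel launch device and tensor devices are the same") might look like this sketch; the helper name is an assumption:

```python
# Sketch of the device-consistency assert.
import torch

def assert_same_device(*tensors):
    devices = {t.device for t in tensors if t is not None}
    assert len(devices) == 1, f"tensors span multiple devices: {devices}"
    (dev,) = devices
    assert dev.index == torch.cuda.current_device(), (
        f"kernel launches on cuda:{torch.cuda.current_device()} but tensors live on {dev}"
    )
```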
* fix docs (#154)
* Sliding Window block classification logic (#155)
* add aiter code
* remove aiter stuff
* sliding window non causal masking works
* causal and sliding window block masking
* extract common
* clean up typo
* helper for swa
* ignore .amd
* fix last block bug
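
The block classification in #155 splits key blocks into skipped, fully visible, and partially masked; a sketch of the window-only logic (the causal/seqlen-delta shift is omitted and names are assumptions):

```python
def classify_kv_block(q_start, q_end, k_start, k_end, window_left, window_right):
    # keys visible to any query in [q_start, q_end) span
    # [q_start - window_left, q_end - 1 + window_right]
    if k_end - 1 < q_start - window_left or k_start > q_end - 1 + window_right:
        return "skip"     # no (query, key) pair falls inside the window
    if k_start >= q_end - 1 - window_left and k_end - 1 <= q_start + window_right:
        return "full"     # every pair is visible; no per-element mask needed
    return "partial"      # block straddles a window edge; mask inside the kernel
```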
* Enable FA V3 (#157)
* Compress PA work
narrow pa test
ref works on most cases
inplace ref with new_kv
inplace paged attention
add pa ref
save pa
basic paged works
save
fix swa + causal in pa. Also new_kv only on pa path
passing
build fa v3
import interface from fa v3
copy fa tests
use v3 api
clean up
rename to match old test
support different head sizes
remove fp8
basic v3 cases passing
test_flash_attn_varlen_output v3 working
isolate bad case for kvcache
case passing
save
use decode if seqused/cache_seqlens is given
use decode if not varlen
basic kvcache v3 working
kvcache enable more cases
detect kvcache case if seqused_q is None and seqused_k is not None
skip failing test
find fp8 failing case
mha fp8 works
fix fp8 MQA/GQA bug
clean up
more clean up
clean up more
don't need fp8 dead code
remove train code with fp8 stuff
fp8 working in kvcache
paged + fp8 seems to be working
new_kv allowed
* clean up
* skip hopper race test
* clean up more
* fix paged + alibi
* similar inner paged api
* unify _attn_fwd_inner
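
The kvcache/decode detection described above ("use decode if seqused/cache_seqlens is given", "detect kvcache case if seqused_q is None and seqused_k is not None") amounts to a dispatch like this sketch; the path labels are assumptions:

```python
# Sketch of the implied path dispatch.
def pick_path(seqused_q, seqused_k, cu_seqlens_q):
    if seqused_q is None and seqused_k is not None:
        return "kvcache_decode"   # cache lengths given only on the key side
    if cu_seqlens_q is not None:
        return "varlen_prefill"
    return "prefill"
```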
* AITER integration (#159)
* clean up v2 interface
* assert fp8 scale shapes
* rotary working
* move rotary to impl layers
* remove einops
* enable rotary in v3
* create interface
* fix descale assert
* unify bwd
* lint from aiter
* clean fp8 api
* add api change
* assert shapes for v2
* remove ref and bench.py
* remove metadata class and clean up
* bwd_prefill
* one bwd.py
* rename
* lint
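
Removing einops from the rotary path (#159) likely means plain slicing and cat; a sketch of half-split rotary, with layout (batch, seqlen, nheads, headdim) and cos/sin of shape (seqlen, rotary_dim // 2) as assumptions:

```python
import torch

def apply_rotary(x, cos, sin):
    rd = cos.shape[-1] * 2                       # rotary_dim
    x_rot, x_pass = x[..., :rd], x[..., rd:]
    x1, x2 = x_rot.chunk(2, dim=-1)
    cos, sin = cos[:, None, :], sin[:, None, :]  # broadcast over heads
    out = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
    return torch.cat([out, x_pass], dim=-1)
```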
* add bwd_change (#156)
* Tune FP8 Perf (#160)
* check cu count for gfx942
* create get_cu_count
* update repo root
* update forward tune
* clean up load
* use float8_e4m3fnuz
* save
* show bwd mode
* recommend fp8
* use torch.float32 for fp8 kernel
* add both best fp16 and fp8 config
* tune fp8 backward
* descale factors should be b, hk
* fp8 bwd working on all primus configs
* tune bwd configs
* fa v3 tests passing
* better warning
* clean up bwd launcher
* v3 passing
* tune more
* improve perf
* clean up
* lint
* clean
* start tuning gfx950
* tune non causal path
* fix bug
* save
* Skip configs where BLOCK_M2 % BLOCK_N2 != 0
* skip more
* stop tuning
* fix varlen bug
* fix dropout & causal/swa segfault
* update to the new machine changes
* save
* fix more bugs
* remove random seed
* clean up
* update readme
* print tensor stats for debug
* disable sliding window tests
* add rdna configs
* fix k partial bug
* fix block_size_n bug
* fix type check bug
---------
Co-authored-by: Aliasger Zaidy <aliasger.zaidy@amd.com>
Co-authored-by: Tianxing Wu <tianxing.wu@amd.com>
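
The get_cu_count helper from #160 ("check cu count for gfx942") can be read off torch's device properties; a sketch:

```python
# Sketch of get_cu_count: ROCm reports CUs via multi_processor_count.
import torch

def get_cu_count(device_index: int = 0) -> int:
    return torch.cuda.get_device_properties(device_index).multi_processor_count
```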
1 parent 99589e5 · commit 701ebe0

File tree: flash_attn · flash_attn_triton_amd · hopper · tests
28 files changed: +10871 −13189 lines