Commit 9f7adfb
authored
[Fmha] Add head_dim=512 support for trtllm attention kernels (#2959)
Add support for `head_dim=512` in the trtllm FMHA kernel selection.
### Changes
- Add SDPA-based reference implementation for `head_dim > 256` in tests
(FlashInfer FA2/FA3 kernels don't support `head_dim > 256`)
- Add `test_trtllm_batch_prefill_head_dim_512` and
`test_trtllm_batch_decode_head_dim_512` covering BF16, FP16, and FP8
dtypes
### Follow-up
- NVFP4 coverage at `head_dim=512` is deferred to a follow-up PR.
Signed-off-by: Duncan Moss <djm.moss@gmail.com>1 parent 6ddbdb0 commit 9f7adfb
5 files changed
Lines changed: 295 additions & 12 deletions
File tree
- csrc
- flashinfer
- include/flashinfer/trtllm/fmha
- tests/attention
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
81 | 81 | | |
82 | 82 | | |
83 | 83 | | |
84 | | - | |
| 84 | + | |
85 | 85 | | |
86 | 86 | | |
87 | 87 | | |
| |||
361 | 361 | | |
362 | 362 | | |
363 | 363 | | |
364 | | - | |
| 364 | + | |
365 | 365 | | |
366 | 366 | | |
367 | 367 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
135 | 135 | | |
136 | 136 | | |
137 | 137 | | |
138 | | - | |
| 138 | + | |
139 | 139 | | |
140 | 140 | | |
141 | 141 | | |
| |||
155 | 155 | | |
156 | 156 | | |
157 | 157 | | |
158 | | - | |
| 158 | + | |
159 | 159 | | |
160 | 160 | | |
161 | 161 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
191 | 191 | | |
192 | 192 | | |
193 | 193 | | |
194 | | - | |
| 194 | + | |
195 | 195 | | |
196 | 196 | | |
197 | 197 | | |
| |||
789 | 789 | | |
790 | 790 | | |
791 | 791 | | |
| 792 | + | |
| 793 | + | |
| 794 | + | |
| 795 | + | |
| 796 | + | |
| 797 | + | |
| 798 | + | |
| 799 | + | |
792 | 800 | | |
793 | 801 | | |
794 | 802 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
169 | 169 | | |
170 | 170 | | |
171 | 171 | | |
| 172 | + | |
| 173 | + | |
| 174 | + | |
172 | 175 | | |
173 | 176 | | |
174 | 177 | | |
175 | 178 | | |
176 | 179 | | |
177 | | - | |
178 | | - | |
179 | 180 | | |
180 | 181 | | |
181 | 182 | | |
| |||
854 | 855 | | |
855 | 856 | | |
856 | 857 | | |
857 | | - | |
| 858 | + | |
858 | 859 | | |
859 | 860 | | |
860 | 861 | | |
| |||
0 commit comments