
Commit 0ab3f43

Cutlass MLA: Disable split kv due to NVIDIA/cutlass#2274 (#6101)
1 parent: cec98f1


2 files changed (+5 −2 lines)


sgl-kernel/csrc/attention/cutlass_mla_kernel.cu

Lines changed: 4 additions & 1 deletion
@@ -151,7 +151,10 @@ typename T::Fmha::Arguments args_from_options(
       page_size},
      {static_cast<ElementOut*>(out.data_ptr()), stride_O, static_cast<ElementAcc*>(nullptr), stride_LSE},
      hw_info,
-     -1, // split_kv
+     // TODO(trevor-m): Change split_kv back to -1 when
+     // https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
+     // perform worse with larger context length and smaller batch sizes.
+     1, // split_kv
      nullptr, // is_var_split_kv
  };
  // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
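For context on what the pinned argument disables: with split_kv greater than 1, a decode kernel partitions the KV sequence into chunks that are attended to in parallel, and the per-chunk partial outputs are merged exactly using their log-sum-exp (LSE) values; split_kv = -1 lets the kernel choose the split itself. Below is a minimal PyTorch sketch of that general split-KV merge, not the CUTLASS implementation; the shapes and helper names are illustrative.

import torch

def attend_chunk(q, k, v):
    # q: [h, d]; k, v: [chunk_len, d] -> partial output [h, d] and LSE [h]
    scores = (q @ k.T) / k.shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ v, torch.logsumexp(scores, dim=-1)

def split_kv_attention(q, k, v, split_kv):
    # Attend to each KV chunk independently, then merge the partial outputs
    # with softmax weights over the per-chunk LSEs; the merge is exact.
    outs, lses = zip(*(attend_chunk(q, kc, vc)
                       for kc, vc in zip(k.chunk(split_kv), v.chunk(split_kv))))
    weights = torch.softmax(torch.stack(lses), dim=0)        # [split_kv, h]
    return (torch.stack(outs) * weights[..., None]).sum(0)   # [h, d]

q, k, v = torch.randn(16, 64), torch.randn(8192, 64), torch.randn(8192, 64)
ref = split_kv_attention(q, k, v, split_kv=1)  # what this commit pins
par = split_kv_attention(q, k, v, split_kv=8)  # parallelism a split enables
assert torch.allclose(ref, par, atol=1e-4)     # splitting changes speed, not math

With split_kv pinned to 1, the entire KV sequence for each head/batch entry is handled by a single work unit, so long contexts at small batch sizes leave most of the GPU idle, which is the slowdown the TODO comment warns about.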

sgl-kernel/tests/test_cutlass_mla.py

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ def test_cutlass_mla_decode(
     pack_factor = 128 // block_size
     block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor

-    q = torch.randn(bs, h_q, d)
+    q = torch.randn(bs, h_q, d) * 100.0
     block_table = torch.randint(0, bs * block_num, (bs, block_num), dtype=torch.int32)

     kv_cache = torch.randn(block_table.numel(), block_size, d)
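The commit does not say why the test scales q by 100.0, but a plausible reading is that it amplifies numerical differences: low-precision arithmetic has roughly constant relative error, so inflating input magnitudes inflates the absolute error that the test's tolerance must catch. A rough, self-contained illustration of that effect (assumed shapes; bfloat16 stands in for whatever reduced-precision math the kernel uses):

import torch

def max_abs_err(scale):
    # Compare a low-precision matmul against an fp64 reference.
    q = torch.randn(16, 64) * scale
    k = torch.randn(1024, 64)
    ref = q.double() @ k.double().T
    low = (q.bfloat16() @ k.bfloat16().T).double()
    return (ref - low).abs().max().item()

print(max_abs_err(1.0), max_abs_err(100.0))  # absolute error grows ~100x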
