
Commit 4879e50

[Feat] Add sparse attn to sgl-kernel (#5327)

1 parent bc92107

File tree: 5 files changed (+625, -14 lines)

sgl-kernel/CMakeLists.txt

Lines changed: 30 additions & 14 deletions

@@ -5,8 +5,6 @@ cmake_policy(SET CMP0169 OLD)
 
 include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
 
-set(BUILD_FA3, OFF)
-
 find_package(Python COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)
 
 enable_language(CUDA)
@@ -80,7 +78,6 @@ include_directories(
   ${repo-cutlass_SOURCE_DIR}/examples/common
   ${repo-flashinfer_SOURCE_DIR}/include
   ${repo-flashinfer_SOURCE_DIR}/csrc
-  ${repo-flash-attention_SOURCE_DIR}/hopper
 )
 
 set(CMAKE_CXX_STANDARD 17)
@@ -115,6 +112,9 @@ option(SGL_KERNEL_ENABLE_BF16 "Enable BF16" ON)
 option(SGL_KERNEL_ENABLE_FP8 "Enable FP8" ON)
 option(SGL_KERNEL_ENABLE_FP4 "Enable FP4" OFF)
 
+option(SGL_KERNEL_ENABLE_FA3 "Enable FA3" OFF)
+
+
 if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A)
   list(APPEND SGL_KERNEL_CUDA_FLAGS
     "-gencode=arch=compute_100,code=sm_100"
@@ -127,7 +127,7 @@ else()
 endif()
 
 if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4" OR SGL_KERNEL_ENABLE_SM90A)
-  set(BUILD_FA3 ON)
+  set(SGL_KERNEL_ENABLE_FA3 ON)
   list(APPEND SGL_KERNEL_CUDA_FLAGS
     "-gencode=arch=compute_90a,code=sm_90a"
   )
@@ -187,11 +187,33 @@ set(SOURCES
   "${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
   "${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
   "${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
+  "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu"
+  "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_sm80.cu"
+  "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu"
+  "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_sm80.cu"
+  "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/flash_sparse_api.cpp"
 )
 
+Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
+
+target_compile_options(common_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
+target_include_directories(common_ops PRIVATE
+  ${TORCH_INCLUDE_DIRS}
+  ${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src)
+target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt)
+
+target_compile_definitions(common_ops PRIVATE
+  FLASHATTENTION_DISABLE_BACKWARD
+  FLASHATTENTION_DISABLE_DROPOUT
+  FLASHATTENTION_DISABLE_UNEVEN_K
+)
+
+install(TARGETS common_ops LIBRARY DESTINATION "sgl_kernel")
+
+# ============================ Optional Install ============================= #
 # set flash-attention sources file
 # BF16 source files
-if (BUILD_FA3)
+if (SGL_KERNEL_ENABLE_FA3)
   set(SGL_FLASH_KERNEL_CUDA_FLAGS
     "-DNDEBUG"
     "-DOPERATOR_NAMESPACE=sgl-kernel"
@@ -246,7 +268,9 @@ if (BUILD_FA3)
   Python_add_library(flash_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${FLASH_SOURCES})
 
   target_compile_options(flash_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_FLASH_KERNEL_CUDA_FLAGS}>)
-  target_include_directories(flash_ops PRIVATE ${TORCH_INCLUDE_DIRS})
+  target_include_directories(flash_ops PRIVATE
+    ${TORCH_INCLUDE_DIRS}
+    ${repo-flash-attention_SOURCE_DIR}/hopper)
   target_link_libraries(flash_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)
 
   install(TARGETS flash_ops LIBRARY DESTINATION "sgl_kernel")
@@ -260,14 +284,6 @@ if (BUILD_FA3)
   )
 endif()
 
-Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
-
-target_compile_options(common_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
-target_include_directories(common_ops PRIVATE ${TORCH_INCLUDE_DIRS})
-target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt)
-
-install(TARGETS common_ops LIBRARY DESTINATION "sgl_kernel")
-
 # JIT Logic
 # DeepGEMM

sgl-kernel/csrc/common_extension.cc

Lines changed: 22 additions & 0 deletions

@@ -206,6 +206,28 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
       "maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
   m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
+
+  /*
+   * From Sparse Flash Attention
+   */
+  m.def(
+      "fwd_sparse(Tensor! q, Tensor k, Tensor v, "
+      "Tensor block_count, Tensor block_offset, Tensor column_count, Tensor column_index, "
+      "Tensor!? out, Tensor? alibi_slopes, "
+      "float p_dropout, float softmax_scale, bool is_causal, "
+      "float softcap, bool return_softmax, Generator? gen)"
+      "-> Tensor[]");
+  m.impl("fwd_sparse", torch::kCUDA, &flash::mha_fwd_sparse);
+
+  m.def(
+      "varlen_fwd_sparse(Tensor! q, Tensor k, Tensor v, "
+      "Tensor block_count, Tensor block_offset, Tensor column_count, Tensor column_index, "
+      "Tensor!? out, Tensor cu_seqlens_q, "
+      "Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? alibi_slopes, "
+      "int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool zero_tensors, "
+      "bool is_causal, float softcap, bool return_softmax, "
+      "Generator? gen) -> Tensor[]");
+  m.impl("varlen_fwd_sparse", torch::kCUDA, &flash::mha_varlen_fwd_sparse);
 }
 
 REGISTER_EXTENSION(common_ops)
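Once the common_ops extension is loaded, the two schemas registered above are reachable through the PyTorch dispatcher under the sgl_kernel namespace. A minimal sketch (not part of the commit), assuming the wheel built from this commit is installed and that importing sgl_kernel loads the extension:

import torch
import sgl_kernel  # assumed: importing the package loads common_ops and registers the ops

# Raw dispatcher handles for the two schemas defined above; the Python
# wrappers added later in this commit call them via .default(...).
fwd = torch.ops.sgl_kernel.fwd_sparse
varlen_fwd = torch.ops.sgl_kernel.varlen_fwd_sparse
print(fwd, varlen_fwd)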

sgl-kernel/include/sgl_kernel_ops.h

Lines changed: 50 additions & 0 deletions

@@ -256,18 +256,21 @@ void min_p_sampling_from_probs(
     double min_p_val,
     bool deterministic,
     int64_t cuda_stream);
+
 void top_k_renorm_probs(
     at::Tensor probs,
     at::Tensor renorm_probs,
     std::optional<at::Tensor> maybe_top_k_arr,
     int64_t top_k_val,
     int64_t cuda_stream);
+
 void top_p_renorm_probs(
     at::Tensor probs,
     at::Tensor renorm_probs,
     std::optional<at::Tensor> maybe_top_p_arr,
     double top_p_val,
     int64_t cuda_stream);
+
 void top_k_top_p_sampling_from_probs(
     at::Tensor probs,
     at::Tensor uniform_samples,
@@ -279,6 +282,7 @@ void top_k_top_p_sampling_from_probs(
     double top_p_val,
     bool deterministic,
     int64_t cuda_stream);
+
 void top_p_sampling_from_probs(
     at::Tensor probs,
     at::Tensor uniform_samples,
@@ -288,3 +292,49 @@ void top_p_sampling_from_probs(
     double top_p_val,
     bool deterministic,
     int64_t cuda_stream);
+
+namespace flash {
+/*
+ * From fa2 sparse
+ */
+std::vector<at::Tensor> mha_fwd_sparse(
+    at::Tensor& q,        // batch_size x seqlen_q x num_heads x head_size
+    const at::Tensor& k,  // batch_size x seqlen_k x num_heads_k x head_size
+    const at::Tensor& v,  // batch_size x seqlen_k x num_heads_k x head_size
+    const at::Tensor& block_count,
+    const at::Tensor& block_offset,
+    const at::Tensor& column_count,
+    const at::Tensor& column_index,
+    const std::optional<at::Tensor>& out_,           // batch_size x seqlen_q x num_heads x head_size
+    const std::optional<at::Tensor>& alibi_slopes_,  // num_heads or batch_size x num_heads
+    const double p_dropout,
+    const double softmax_scale,
+    bool is_causal,
+    const double softcap,
+    const bool return_softmax,
+    std::optional<at::Generator> gen_);
+
+std::vector<at::Tensor> mha_varlen_fwd_sparse(
+    at::Tensor& q,        // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+    const at::Tensor& k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i.
+    const at::Tensor& v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i.
+    const at::Tensor& block_count,
+    const at::Tensor& block_offset,
+    const at::Tensor& column_count,
+    const at::Tensor& column_index,
+    const c10::optional<at::Tensor>& out_,  // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
+    const at::Tensor& cu_seqlens_q,         // b+1
+    const at::Tensor& cu_seqlens_k,         // b+1
+    const c10::optional<at::Tensor>&
+        seqused_k,  // b. If given, only this many elements of each batch element's keys are used.
+    const c10::optional<at::Tensor>& alibi_slopes_,  // num_heads or b x num_heads
+    int64_t max_seqlen_q,
+    const int64_t max_seqlen_k,
+    const double p_dropout,
+    const double softmax_scale,
+    const bool zero_tensors,
+    bool is_causal,
+    const double softcap,
+    const bool return_softmax,
+    c10::optional<at::Generator> gen_);
+}  // namespace flash
Lines changed: 175 additions & 0 deletions

@@ -0,0 +1,175 @@
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+
+def maybe_contiguous(x):
+    return x.contiguous() if x is not None and x.stride(-1) != 1 else x
+
+
+def sparse_attn_func(
+    q,
+    k,
+    v,
+    block_count,
+    block_offset,
+    column_count,
+    column_index,
+    dropout_p=0.0,
+    softmax_scale=None,
+    causal=False,
+    softcap=0.0,  # 0.0 means deactivated
+    alibi_slopes=None,
+    deterministic=False,
+    return_attn_probs=False,
+    *,
+    return_softmax_lse=False,
+    out=None,
+):
+    """Compute attention with vertical and slash sparsity patterns.
+    Most arguments are the same as in the flash_attn_func interface, except for 4 extra args:
+    block_count and block_offset for slash sparsity patterns, and
+    column_count and column_index for vertical sparsity patterns.
+    For more details please refer to Appendix C.4.2 of paper https://arxiv.org/abs/2407.02490.
+
+    Arguments:
+        q: (batch_size, seqlen, nheads, headdim)
+        k: (batch_size, seqlen, nheads_k, headdim)
+        v: (batch_size, seqlen, nheads_k, headdim)
+        block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
+        block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
+        column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
+        column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
+        dropout_p: float. Dropout probability.
+        softmax_scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
+        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+            is added to the attention score of query i and key j.
+        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+            which is slightly slower and uses more memory. The forward pass is always deterministic.
+        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
+            testing only. The returned probabilities are not guaranteed to be correct
+            (they might not have the right scaling).
+    Return:
+        out: (batch_size, seqlen, nheads, headdim).
+        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
+            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
+            normalization factor).
+    """
+    if softmax_scale is None:
+        softmax_scale = q.shape[-1] ** (-0.5)
+
+    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
+    out, softmax_lse = torch.ops.sgl_kernel.fwd_sparse.default(
+        q,
+        k,
+        v,
+        block_count,
+        block_offset,
+        column_count,
+        column_index,
+        out,
+        alibi_slopes,
+        dropout_p,
+        softmax_scale,
+        causal,
+        softcap,
+        return_attn_probs and dropout_p > 0,
+        None,
+    )
+    return (out, softmax_lse) if return_softmax_lse else out
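
A minimal usage sketch (not part of the commit) for sparse_attn_func above. The module path, BLOCK_M=64, the int32 dtype of the index tensors, and the toy sparsity pattern are assumptions; headdim is 128 because the CMake changes compile only the hdim128 fp16/bf16 sparse kernels:

import torch
from sgl_kernel.flash_attn import sparse_attn_func  # module path assumed; the new file's name is not shown in this view

# Illustrative sizes; BLOCK_M=64 and int32 index tensors are assumptions.
batch, seqlen, nheads, headdim = 1, 128, 4, 128
BLOCK_M, NNZ_S, NNZ_V = 64, 1, 1
device, dtype = "cuda", torch.float16

q = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype)
k = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype)
v = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype)

n_blocks = (seqlen + BLOCK_M - 1) // BLOCK_M  # cdiv(seqlen, BLOCK_M)
# Toy pattern: every query block attends to the key block at offset 0, with no
# extra vertical columns. Real indices come from a sparsity-estimation step such
# as the vertical-slash index builder described in arXiv:2407.02490.
block_count = torch.ones(batch, nheads, n_blocks, device=device, dtype=torch.int32)
block_offset = torch.zeros(batch, nheads, n_blocks, NNZ_S, device=device, dtype=torch.int32)
column_count = torch.zeros(batch, nheads, n_blocks, device=device, dtype=torch.int32)
column_index = torch.zeros(batch, nheads, n_blocks, NNZ_V, device=device, dtype=torch.int32)

out = sparse_attn_func(q, k, v, block_count, block_offset, column_count, column_index)
print(out.shape)  # (batch, seqlen, nheads, headdim)

The diff of the new file continues below with the varlen entry point.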
+
+
+def sparse_attn_varlen_func(
+    q,
+    k,
+    v,
+    block_count,
+    block_offset,
+    column_count,
+    column_index,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    max_seqlen_q,
+    max_seqlen_k,
+    dropout_p=0.0,
+    softmax_scale=None,
+    causal=False,
+    softcap=0.0,  # 0.0 means deactivated
+    alibi_slopes=None,
+    deterministic=False,
+    return_attn_probs=False,
+    *,
+    return_softmax_lse=False,
+    out=None,
+):
+    """Compute attention with vertical and slash sparsity patterns.
+    Most arguments are the same as in the flash_attn_varlen_func interface, except for 4 extra args:
+    block_count and block_offset for slash sparsity patterns, and
+    column_count and column_index for vertical sparsity patterns.
+    For more details please refer to Appendix C.4.2 of paper https://arxiv.org/abs/2407.02490.
+
+    Arguments:
+        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
+        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
+        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
+        block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
+        block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
+        column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
+        column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
+        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+            of the sequences in the batch, used to index into q.
+        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+            of the sequences in the batch, used to index into kv.
+        max_seqlen_q: int. Maximum query sequence length in the batch.
+        max_seqlen_k: int. Maximum key sequence length in the batch.
+        dropout_p: float. Dropout probability.
+        softmax_scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
+        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        softcap: float. Anything > 0 activates softcapping attention.
+        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+            is added to the attention score of query i and key j.
+        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+            which is slightly slower and uses more memory. The forward pass is always deterministic.
+        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
+            testing only. The returned probabilities are not guaranteed to be correct
+            (they might not have the right scaling).
+    Return:
+        out: (total, nheads, headdim).
+        softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
+            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
+            normalization factor).
+    """
+    if softmax_scale is None:
+        softmax_scale = q.shape[-1] ** (-0.5)
+
+    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
+    out, softmax_lse = torch.ops.sgl_kernel.varlen_fwd_sparse.default(
+        q,
+        k,
+        v,
+        block_count,
+        block_offset,
+        column_count,
+        column_index,
+        out,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        None,
+        alibi_slopes,
+        max_seqlen_q,
+        max_seqlen_k,
+        dropout_p,
+        softmax_scale,
+        False,
+        causal,
+        softcap,
+        return_attn_probs and dropout_p > 0,
+        None,
+    )
+    return (out, softmax_lse) if return_softmax_lse else out
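
A corresponding sketch (again not part of the commit) for the varlen entry point, with the same assumptions about module path, index dtypes, and BLOCK_M; two sequences are packed along the token dimension and described by cu_seqlens tensors:

import torch
from sgl_kernel.flash_attn import sparse_attn_varlen_func  # module path assumed, as above

nheads, headdim = 4, 128
BLOCK_M, NNZ_S, NNZ_V = 64, 1, 1
seqlens = [96, 160]               # two sequences packed into one token dimension
total, max_seqlen = sum(seqlens), max(seqlens)
device, dtype = "cuda", torch.float16

q = torch.randn(total, nheads, headdim, device=device, dtype=dtype)
k = torch.randn(total, nheads, headdim, device=device, dtype=dtype)
v = torch.randn(total, nheads, headdim, device=device, dtype=dtype)
cu_seqlens = torch.tensor([0, 96, 256], device=device, dtype=torch.int32)  # cumulative lengths, b+1 entries

batch = len(seqlens)
n_blocks = (max_seqlen + BLOCK_M - 1) // BLOCK_M  # cdiv(max_seqlen, BLOCK_M), assumed per docstring
# Same toy pattern as the dense-batch sketch: first key block only, no vertical columns.
block_count = torch.ones(batch, nheads, n_blocks, device=device, dtype=torch.int32)
block_offset = torch.zeros(batch, nheads, n_blocks, NNZ_S, device=device, dtype=torch.int32)
column_count = torch.zeros(batch, nheads, n_blocks, device=device, dtype=torch.int32)
column_index = torch.zeros(batch, nheads, n_blocks, NNZ_V, device=device, dtype=torch.int32)

out = sparse_attn_varlen_func(
    q, k, v,
    block_count, block_offset, column_count, column_index,
    cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
)
print(out.shape)  # (total, nheads, headdim)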
