[Feat] Add sparse attn to sgl-kernel #5327

Merged: 3 commits, Apr 12, 2025

44 changes: 30 additions & 14 deletions sgl-kernel/CMakeLists.txt
@@ -5,8 +5,6 @@ cmake_policy(SET CMP0169 OLD)

include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)

set(BUILD_FA3, OFF)

find_package(Python COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)

enable_language(CUDA)
@@ -80,7 +78,6 @@ include_directories(
${repo-cutlass_SOURCE_DIR}/examples/common
${repo-flashinfer_SOURCE_DIR}/include
${repo-flashinfer_SOURCE_DIR}/csrc
${repo-flash-attention_SOURCE_DIR}/hopper
)

set(CMAKE_CXX_STANDARD 17)
@@ -115,6 +112,9 @@ option(SGL_KERNEL_ENABLE_BF16 "Enable BF16" ON)
option(SGL_KERNEL_ENABLE_FP8 "Enable FP8" ON)
option(SGL_KERNEL_ENABLE_FP4 "Enable FP4" OFF)

option(SGL_KERNEL_ENABLE_FA3 "Enable FA3" OFF)


if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A)
list(APPEND SGL_KERNEL_CUDA_FLAGS
"-gencode=arch=compute_100,code=sm_100"
@@ -127,7 +127,7 @@ else()
endif()

if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4" OR SGL_KERNEL_ENABLE_SM90A)
set(BUILD_FA3 ON)
set(SGL_KERNEL_ENABLE_FA3 ON)
list(APPEND SGL_KERNEL_CUDA_FLAGS
"-gencode=arch=compute_90a,code=sm_90a"
)
@@ -187,11 +187,33 @@ set(SOURCES
"${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu"
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_sm80.cu"
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu"
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_sm80.cu"
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/flash_sparse_api.cpp"
)

Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})

target_compile_options(common_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
target_include_directories(common_ops PRIVATE
${TORCH_INCLUDE_DIRS}
${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src)
target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt)

target_compile_definitions(common_ops PRIVATE
FLASHATTENTION_DISABLE_BACKWARD
FLASHATTENTION_DISABLE_DROPOUT
FLASHATTENTION_DISABLE_UNEVEN_K
)

install(TARGETS common_ops LIBRARY DESTINATION "sgl_kernel")

# ============================ Optional Install ============================= #
# set flash-attention sources file
# BF16 source files
if (BUILD_FA3)
if (SGL_KERNEL_ENABLE_FA3)
set(SGL_FLASH_KERNEL_CUDA_FLAGS
"-DNDEBUG"
"-DOPERATOR_NAMESPACE=sgl-kernel"
@@ -246,7 +268,9 @@ if (BUILD_FA3)
Python_add_library(flash_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${FLASH_SOURCES})

target_compile_options(flash_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_FLASH_KERNEL_CUDA_FLAGS}>)
target_include_directories(flash_ops PRIVATE ${TORCH_INCLUDE_DIRS})
target_include_directories(flash_ops PRIVATE
${TORCH_INCLUDE_DIRS}
${repo-flash-attention_SOURCE_DIR}/hopper)
target_link_libraries(flash_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)

install(TARGETS flash_ops LIBRARY DESTINATION "sgl_kernel")
@@ -260,14 +284,6 @@ if (BUILD_FA3)
)
endif()

Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})

target_compile_options(common_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
target_include_directories(common_ops PRIVATE ${TORCH_INCLUDE_DIRS})
target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt)

install(TARGETS common_ops LIBRARY DESTINATION "sgl_kernel")

# JIT Logic
# DeepGEMM

22 changes: 22 additions & 0 deletions sgl-kernel/csrc/common_extension.cc
@@ -206,6 +206,28 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
"maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);

/*
* From Sparse Flash Attention
*/
m.def(
"fwd_sparse(Tensor! q, Tensor k, Tensor v, "
"Tensor block_count, Tensor block_offset, Tensor column_count, Tensor column_index, "
"Tensor!? out, Tensor? alibi_slopes, "
"float p_dropout, float softmax_scale, bool is_causal, "
"float softcap, bool return_softmax, Generator? gen)"
"-> Tensor[]");
m.impl("fwd_sparse", torch::kCUDA, &flash::mha_fwd_sparse);

m.def(
"varlen_fwd_sparse(Tensor! q, Tensor k, Tensor v, "
"Tensor block_count, Tensor block_offset, Tensor column_count, Tensor column_index, "
"Tensor!? out, Tensor cu_seqlens_q, "
"Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? alibi_slopes, "
"int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool zero_tensors, "
"bool is_causal, float softcap, bool return_softmax, "
"Generator? gen) -> Tensor[]");
m.impl("varlen_fwd_sparse", torch::kCUDA, &flash::mha_varlen_fwd_sparse);
}

REGISTER_EXTENSION(common_ops)
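
A minimal sketch of how the two schemas registered above surface on the Python side, assuming the built common_ops extension is loaded when the sgl_kernel package is imported (the wrapper module added below resolves these same overloads):

import torch
import sgl_kernel  # assumed to load the compiled common_ops extension and register the sgl_kernel ops

# Both sparse forward kernels are now reachable through the torch dispatcher.
fwd_sparse = torch.ops.sgl_kernel.fwd_sparse.default
varlen_fwd_sparse = torch.ops.sgl_kernel.varlen_fwd_sparse.default
print(fwd_sparse._schema)         # mirrors the "fwd_sparse(...)" schema string registered above
print(varlen_fwd_sparse._schema)  # mirrors the "varlen_fwd_sparse(...)" schema string
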
50 changes: 50 additions & 0 deletions sgl-kernel/include/sgl_kernel_ops.h
@@ -256,18 +256,21 @@ void min_p_sampling_from_probs(
double min_p_val,
bool deterministic,
int64_t cuda_stream);

void top_k_renorm_probs(
at::Tensor probs,
at::Tensor renorm_probs,
std::optional<at::Tensor> maybe_top_k_arr,
int64_t top_k_val,
int64_t cuda_stream);

void top_p_renorm_probs(
at::Tensor probs,
at::Tensor renorm_probs,
std::optional<at::Tensor> maybe_top_p_arr,
double top_p_val,
int64_t cuda_stream);

void top_k_top_p_sampling_from_probs(
at::Tensor probs,
at::Tensor uniform_samples,
@@ -279,6 +282,7 @@ void top_k_top_p_sampling_from_probs(
double top_p_val,
bool deterministic,
int64_t cuda_stream);

void top_p_sampling_from_probs(
at::Tensor probs,
at::Tensor uniform_samples,
@@ -288,3 +292,49 @@ void top_p_sampling_from_probs(
double top_p_val,
bool deterministic,
int64_t cuda_stream);

namespace flash {
/*
* From fa2 sparse
*/
std::vector<at::Tensor> mha_fwd_sparse(
at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor& block_count,
const at::Tensor& block_offset,
const at::Tensor& column_count,
const at::Tensor& column_index,
const std::optional<at::Tensor>& out_, // batch_size x seqlen_q x num_heads x head_size
const std::optional<at::Tensor>& alibi_slopes_, // num_heads or batch_size x num_heads
const double p_dropout,
const double softmax_scale,
bool is_causal,
const double softcap,
const bool return_softmax,
std::optional<at::Generator> gen_);

std::vector<at::Tensor> mha_varlen_fwd_sparse(
at::Tensor& q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor& k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i.
const at::Tensor& v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i.
const at::Tensor& block_count,
const at::Tensor& block_offset,
const at::Tensor& column_count,
const at::Tensor& column_index,
const c10::optional<at::Tensor>& out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor& cu_seqlens_q, // b+1
const at::Tensor& cu_seqlens_k, // b+1
const c10::optional<at::Tensor>&
seqused_k, // b. If given, only this many elements of each batch element's keys are used.
const c10::optional<at::Tensor>& alibi_slopes_, // num_heads or b x num_heads
int64_t max_seqlen_q,
const int64_t max_seqlen_k,
const double p_dropout,
const double softmax_scale,
const bool zero_tensors,
bool is_causal,
const double softcap,
const bool return_softmax,
c10::optional<at::Generator> gen_);
} // namespace flash
175 changes: 175 additions & 0 deletions sgl-kernel/python/sgl_kernel/sparse_flash_attn.py
@@ -0,0 +1,175 @@
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn


def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x


def sparse_attn_func(
q,
k,
v,
block_count,
block_offset,
column_count,
column_index,
dropout_p=0.0,
softmax_scale=None,
causal=False,
softcap=0.0, # 0.0 means deactivated
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
*,
return_softmax_lse=False,
out=None,
):
"""Compute attention with vertical and slash sparsity patterns.
Most arguments are the same as for the flash_attn_func interface, except for 4 extra args:
block_count and block_offset for slash sparsity patterns, and
column_count and column_index for vertical sparsity patterns.
For more details, please refer to Appendix C.4.2 of the paper https://arxiv.org/abs/2407.02490.

Arguments:
q: (batch_size, seqlen, nheads, headdim)
k: (batch_size, seqlen, nheads_k, headdim)
v: (batch_size, seqlen, nheads_k, headdim)
block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Defaults to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
"""
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)

q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, softmax_lse = torch.ops.sgl_kernel.fwd_sparse.default(
q,
k,
v,
block_count,
block_offset,
column_count,
column_index,
out,
alibi_slopes,
dropout_p,
softmax_scale,
causal,
softcap,
return_attn_probs and dropout_p > 0,
None,
)
return (out, softmax_lse) if return_softmax_lse else out
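
For orientation, a minimal usage sketch. Tensor shapes follow the docstring above; the BLOCK_M, NNZ_S and NNZ_V values and the int32 dtype of the metadata tensors are illustrative assumptions, and head_size 128 with fp16 matches the hdim128 sm80 sparse kernels compiled in CMakeLists.txt. Real calls need block/column indices produced by a vertical-and-slash pattern builder (MInference-style); the zeros/ones here only illustrate shapes:

import torch
from sgl_kernel.sparse_flash_attn import sparse_attn_func

batch, seqlen, nheads, headdim = 1, 256, 8, 128   # only hdim128 sparse kernels are built in this PR
BLOCK_M, NNZ_S, NNZ_V = 64, 1, 16                 # illustrative sizes, not mandated by the kernel
num_blocks = (seqlen + BLOCK_M - 1) // BLOCK_M    # cdiv(seqlen, BLOCK_M)

q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Per (batch, head, query-block) sparsity metadata; contents must encode the
# slash (block_*) and vertical (column_*) patterns described in the docstring.
block_count = torch.ones(batch, nheads, num_blocks, dtype=torch.int32, device="cuda")
block_offset = torch.zeros(batch, nheads, num_blocks, NNZ_S, dtype=torch.int32, device="cuda")
column_count = torch.zeros(batch, nheads, num_blocks, dtype=torch.int32, device="cuda")
column_index = torch.zeros(batch, nheads, num_blocks, NNZ_V, dtype=torch.int32, device="cuda")

out = sparse_attn_func(q, k, v, block_count, block_offset, column_count, column_index, causal=True)
print(out.shape)  # torch.Size([1, 256, 8, 128])
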


def sparse_attn_varlen_func(
q,
k,
v,
block_count,
block_offset,
column_count,
column_index,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p=0.0,
softmax_scale=None,
causal=False,
softcap=0.0, # 0.0 means deactivated
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
*,
return_softmax_lse=False,
out=None,
):
"""Compute attention with vertical and slash sparsity patterns.
Most arguments are the same as for the flash_attn_varlen_func interface, except for 4 extra args:
block_count and block_offset for slash sparsity patterns, and
column_count and column_index for vertical sparsity patterns.
For more details, please refer to Appendix C.4.2 of the paper https://arxiv.org/abs/2407.02490.

Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into kv.
max_seqlen_q: int. Maximum query sequence length in the batch.
max_seqlen_k: int. Maximum key sequence length in the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Defaults to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
softcap: float. Anything > 0 activates softcapping attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
"""
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)

q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, softmax_lse = torch.ops.sgl_kernel.varlen_fwd_sparse.default(
q,
k,
v,
block_count,
block_offset,
column_count,
column_index,
out,
cu_seqlens_q,
cu_seqlens_k,
None,
alibi_slopes,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
False,
causal,
softcap,
return_attn_probs and dropout_p > 0,
None,
)
return (out, softmax_lse) if return_softmax_lse else out
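
A matching sketch for the varlen entry point, again with docstring shapes and illustrative (assumed) metadata sizes and dtypes; cu_seqlens packs two sequences into a single call:

import torch
from sgl_kernel.sparse_flash_attn import sparse_attn_varlen_func

nheads, headdim = 8, 128                          # hdim128, matching the compiled sparse kernels
seqlens = [128, 256]                              # two sequences packed along the token dimension
total = sum(seqlens)
cu_seqlens = torch.tensor([0, 128, 384], dtype=torch.int32, device="cuda")  # (batch_size + 1,) cumulative offsets

q = torch.randn(total, nheads, headdim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

batch = len(seqlens)
BLOCK_M, NNZ_S, NNZ_V = 64, 1, 16                 # illustrative sizes, not mandated by the kernel
num_blocks = (max(seqlens) + BLOCK_M - 1) // BLOCK_M
block_count = torch.ones(batch, nheads, num_blocks, dtype=torch.int32, device="cuda")
block_offset = torch.zeros(batch, nheads, num_blocks, NNZ_S, dtype=torch.int32, device="cuda")
column_count = torch.zeros(batch, nheads, num_blocks, dtype=torch.int32, device="cuda")
column_index = torch.zeros(batch, nheads, num_blocks, NNZ_V, dtype=torch.int32, device="cuda")

out = sparse_attn_varlen_func(
    q, k, v,
    block_count, block_offset, column_count, column_index,
    cu_seqlens, cu_seqlens,
    max(seqlens), max(seqlens),
    causal=True,
)
print(out.shape)  # torch.Size([384, 8, 128])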