Bump Flashinfer to 0.2.5 #5870

Merged (12 commits) on Apr 30, 2025
2 changes: 0 additions & 2 deletions .github/workflows/pr-test.yml
@@ -96,8 +96,6 @@ jobs:
uses: actions/checkout@v4

- name: Install dependencies
env:
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
run: |
bash scripts/ci_install_dependency.sh

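The removed `FLASHINFER_REPO` environment variable picked between the nightly and stable FlashInfer wheel indexes depending on the workflow input. For illustration only, here is a hedged Python sketch (not part of the workflow, the helper name is invented) of the same selection logic that the change drops:

```python
# Illustration of the wheel-index selection removed from pr-test.yml.
def flashinfer_index_url(version: str) -> str:
    base = "https://flashinfer.ai/whl"
    # "nightly" selects the nightly channel; anything else uses the stable index.
    channel = "nightly/cu124/torch2.5" if version == "nightly" else "cu124/torch2.5"
    return f"{base}/{channel}/flashinfer-python"

assert flashinfer_index_url("nightly") == (
    "https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python"
)
assert flashinfer_index_url("release") == (
    "https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python"
)
```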
2 changes: 1 addition & 1 deletion docs/start/install.md
@@ -164,4 +164,4 @@ sky status --endpoint 30000 sglang
- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is the default attention kernel backend. It only supports sm75 and above. If you encounter any FlashInfer-related issues on sm75+ devices (e.g., T4, A10, A100, L4, L40S, H100), please switch to other kernels by adding `--attention-backend triton --sampling-backend pytorch` and open an issue on GitHub.
- If you only need to use OpenAI models with the frontend language, you can avoid installing other dependencies by using `pip install "sglang[openai]"`.
- The language frontend operates independently of the backend runtime. You can install the frontend locally without needing a GPU, while the backend can be set up on a GPU-enabled machine. To install the frontend, run `pip install sglang`, and for the backend, use `pip install sglang[srt]`. `srt` is the abbreviation of SGLang runtime.
- To reinstall flashinfer locally, use the following command: `pip install "flashinfer-python==0.2.3" -i https://flashinfer.ai/whl/cu124/torch2.6 --force-reinstall --no-deps` and then delete the cache with `rm -rf ~/.cache/flashinfer`.
- To reinstall flashinfer locally, use the following command: `pip install "flashinfer-python==0.2.5" -i https://flashinfer.ai/whl/cu124/torch2.6 --force-reinstall --no-deps` and then delete the cache with `rm -rf ~/.cache/flashinfer`.
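As a quick sanity check after the reinstall tip above, here is a minimal Python sketch (not part of the SGLang docs, standard library only) that reports which flashinfer-python version is installed and whether the `~/.cache/flashinfer` JIT cache has been cleared:

```python
# Verify the two reinstall steps: package version and cleared JIT cache.
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path

try:
    print("flashinfer-python:", version("flashinfer-python"))
except PackageNotFoundError:
    print("flashinfer-python is not installed")

cache_dir = Path.home() / ".cache" / "flashinfer"
print("JIT cache cleared:", not cache_dir.exists())
```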
4 changes: 2 additions & 2 deletions python/pyproject.toml
@@ -37,7 +37,7 @@ runtime_common = [
"python-multipart",
"pyzmq>=25.1.2",
"soundfile==0.13.1",
"torchao>=0.7.0",
"torchao>=0.9.0",
"transformers==4.51.1",
"uvicorn",
"uvloop",
@@ -47,7 +47,7 @@ runtime_common = [
srt = [
"sglang[runtime_common]",
"sgl-kernel==0.1.0",
"flashinfer_python==0.2.3",
"flashinfer_python==0.2.5",
"torch==2.6.0",
"torchvision==0.21.0",
"cuda-python",
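To see how the updated pins relate to a local environment, here is a hedged sketch (not from the PR) that reads the `srt` extra from `python/pyproject.toml` and prints the installed version of each requirement. It assumes the extras live under `[project.optional-dependencies]`, Python 3.11+ for `tomllib`, and that it is run from the repository root:

```python
# Cross-check the srt extra's pins against the current environment.
import re
import tomllib
from importlib.metadata import PackageNotFoundError, version

with open("python/pyproject.toml", "rb") as f:
    pyproject = tomllib.load(f)

for requirement in pyproject["project"]["optional-dependencies"]["srt"]:
    # Split e.g. "flashinfer_python==0.2.5" into a distribution name and specifier.
    name = re.split(r"[=<>\[]", requirement, maxsplit=1)[0]
    try:
        installed = version(name.replace("_", "-"))
    except PackageNotFoundError:
        installed = "MISSING"
    print(f"{requirement:40s} installed: {installed}")
```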
2 changes: 1 addition & 1 deletion python/sglang/srt/entrypoints/engine.py
@@ -453,7 +453,7 @@ def _set_envs_and_config(server_args: ServerArgs):
if server_args.attention_backend == "flashinfer":
assert_pkg_version(
"flashinfer_python",
"0.2.3",
"0.2.5",
"Please uninstall the old version and "
"reinstall the latest version by following the instructions "
"at https://docs.flashinfer.ai/installation.html.",
189 changes: 107 additions & 82 deletions python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -15,6 +15,11 @@

import torch

if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
import torch._dynamo

torch._dynamo.config.suppress_errors = True
Member: QQ @Fridge003 @AkazaAkane Why do we need to set torch._dynamo.config.suppress_errors = True here?

Fridge003 (Collaborator, Author), Apr 30, 2025: Otherwise there is an error when setting --enable-torch-compile, but I don't really know the exact reason.

Member: @yzh119 Do you know the reason?

Contributor: This is the suggestion from FlashInfer. When I tried to integrate it with an older version of torch, there was an issue allocating the FlashInfer CUDA graph, and it suggested adding this.
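For context on the thread above, a small self-contained sketch (illustration only, not from the PR) of what `torch._dynamo.config.suppress_errors = True` changes: if a `torch.compile`d function fails to compile, Dynamo logs the error and falls back to eager execution instead of raising.

```python
# Illustration of torch._dynamo.config.suppress_errors (assumes torch 2.x).
import torch
import torch._dynamo

# With suppress_errors enabled, compile-time failures are logged and the
# function silently runs in eager mode instead of raising an exception.
torch._dynamo.config.suppress_errors = True

@torch.compile
def scale_and_shift(x: torch.Tensor) -> torch.Tensor:
    return x * 2 + 1

print(scale_and_shift(torch.arange(4)))  # tensor([1, 3, 5, 7])
```

This mirrors the guard added at the top of this file, which only flips the flag when SGLANG_ENABLE_TORCH_COMPILE is set to "1".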


from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
@@ -82,8 +87,6 @@ def __init__(
self.max_context_len = model_runner.model_config.context_len
self.skip_prefill = skip_prefill
self.is_multimodal = model_runner.model_config.is_multimodal
self.kv_cache_dtype = model_runner.kv_cache_dtype
self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype

assert not (
model_runner.sliding_window_size is not None
@@ -268,6 +271,12 @@ def init_cuda_graph_state(
cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
]

# Ensure tensors are properly allocated
for i in range(self.num_wrappers):
# Force allocation by performing a small operation
if len(self.cuda_graph_kv_indices[i]) > 0:
self.cuda_graph_kv_indices[i][0] = 0

if not self.skip_prefill:
self.cuda_graph_custom_mask = torch.zeros(
(max_bs * self.max_context_len),
@@ -396,8 +405,6 @@ def forward_extend(
forward_batch: ForwardBatch,
save_kv_cache=True,
):
k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
self._get_wrapper_idx(layer)
]
@@ -414,7 +421,7 @@
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, k_scale, v_scale
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)

o = prefill_wrapper_paged.forward(
@@ -424,8 +431,8 @@
sm_scale=layer.scaling,
window_left=layer.sliding_window_size,
logits_soft_cap=logits_soft_cap,
k_scale=k_scale,
v_scale=v_scale,
k_scale=layer.k_scale,
v_scale=layer.v_scale,
)
else:
o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -452,7 +459,7 @@

if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, k_scale, v_scale
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)

return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -466,8 +473,6 @@ def forward_decode(
forward_batch: ForwardBatch,
save_kv_cache=True,
):
k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
decode_wrapper = self.forward_metadata.decode_wrappers[
self._get_wrapper_idx(layer)
]
@@ -481,16 +486,17 @@
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, k_scale, v_scale
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)

# Call the wrapped function
o = decode_wrapper.forward(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
sm_scale=layer.scaling,
logits_soft_cap=layer.logit_cap,
k_scale=k_scale,
v_scale=v_scale,
k_scale=layer.k_scale,
v_scale=layer.v_scale,
)

return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -1146,8 +1152,9 @@ def fast_decode_plan(
pos_encoding_mode: str = "NONE",
window_left: int = -1,
logits_soft_cap: Optional[float] = None,
data_type: Union[str, torch.dtype] = "float16",
q_data_type: Optional[Union[str, torch.dtype]] = None,
kv_data_type: Optional[Union[str, torch.dtype]] = None,
data_type: Optional[Union[str, torch.dtype]] = None,
sm_scale: Optional[float] = None,
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
@@ -1163,6 +1170,18 @@
if logits_soft_cap is None:
logits_soft_cap = 0.0

# Handle data types consistently
if data_type is not None:
if q_data_type is None:
q_data_type = data_type
if kv_data_type is None:
kv_data_type = data_type
elif q_data_type is None:
q_data_type = "float16"

if kv_data_type is None:
kv_data_type = q_data_type

if self.use_tensor_cores:
qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")

@@ -1178,85 +1197,91 @@
raise ValueError(
"The size of indices should be less than or equal to the allocated buffer"
)
# Skip these copies because we directly write to them during preparation
# self._paged_kv_indptr_buf.copy_(indptr)
# self._paged_kv_indices_buf[: len(indices)] = indices
# self._paged_kv_last_page_len_buf.copy_(last_page_len)
else:
self._paged_kv_indptr_buf = indptr
self._paged_kv_indices_buf = indices
self._paged_kv_last_page_len_buf = last_page_len
self._qo_indptr_buf = qo_indptr_host.to(self.device, non_blocking=non_blocking)

# NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
if not q_data_type:
q_data_type = data_type

if not hasattr(self, "empty_q_data"):
self.empty_q_data = torch.empty(
0,
dtype=(
getattr(torch, q_data_type)
if isinstance(q_data_type, str)
else q_data_type
),
)
self.empty_kv_cache = torch.empty(
0,
dtype=(
getattr(torch, data_type) if isinstance(data_type, str) else data_type
),
)
self.last_page_len = torch.ones(32768, dtype=torch.int32)
if self.use_tensor_cores:
self._qo_indptr_buf = qo_indptr_host.to(
self.device, non_blocking=non_blocking
)

# Create empty tensors for dtype info if needed
empty_q_data = torch.empty(
0,
dtype=(
getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
),
device=self.device,
)

empty_kv_cache = torch.empty(
0,
dtype=(
getattr(torch, kv_data_type)
if isinstance(kv_data_type, str)
else kv_data_type
),
device=self.device,
)

indptr_host = (
global_override_indptr_cpu
if global_override_indptr_cpu is not None
else indptr.cpu()
)

if self.use_tensor_cores:
kv_lens_arr_host = get_seq_lens(
indptr_host, self.last_page_len[:batch_size], page_size
)

self._plan_info = self._cached_module.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
qo_indptr_host,
indptr_host,
kv_lens_arr_host,
batch_size, # total_num_rows
batch_size,
num_qo_heads,
num_kv_heads,
page_size,
self.is_cuda_graph_enabled,
head_dim,
head_dim,
False, # causal
torch.cuda.current_stream().cuda_stream,
)
else:
self._plan_info = self._cached_module.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
indptr_host,
batch_size,
num_qo_heads,
num_kv_heads,
page_size,
self.is_cuda_graph_enabled,
window_left,
logits_soft_cap,
head_dim,
head_dim,
self.empty_q_data,
self.empty_kv_cache,
torch.cuda.current_stream().cuda_stream,
)
with torch.cuda.device(self.device):

if self.use_tensor_cores:
# ALSO convert last_page_len to CPU
last_page_len_host = last_page_len.cpu()

kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)

try:
# Make sure we pass exactly 15 arguments for tensor core version
self._plan_info = self._cached_module.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
qo_indptr_host,
indptr_host,
kv_lens_arr_host,
batch_size, # total_num_rows
batch_size,
num_qo_heads,
num_kv_heads,
page_size,
self.is_cuda_graph_enabled,
head_dim,
head_dim,
False, # causal
)
except Exception as e:
raise RuntimeError(f"Error in standard plan: {e}")
else:
try:
# Make sure we pass exactly 15 arguments for standard version
self._plan_info = self._cached_module.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
indptr_host,
batch_size,
num_qo_heads,
num_kv_heads,
page_size,
self.is_cuda_graph_enabled,
window_left,
logits_soft_cap,
head_dim,
head_dim,
empty_q_data,
empty_kv_cache,
)
except Exception as e:
raise RuntimeError(f"Error in standard plan: {e}")

self._pos_encoding_mode = pos_encoding_mode
self._window_left = window_left
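The dtype handling added to `fast_decode_plan` can be summarized with a standalone sketch (illustration only, the helper name is not from the PR): `data_type` is kept as a legacy fallback for both the query and KV dtypes, the query dtype defaults to `float16`, and the KV dtype defaults to the query dtype.

```python
# Standalone illustration of the dtype-resolution rules added above.
from typing import Optional, Union

import torch

DTypeLike = Optional[Union[str, torch.dtype]]


def resolve_dtypes(
    data_type: DTypeLike = None,
    q_data_type: DTypeLike = None,
    kv_data_type: DTypeLike = None,
) -> tuple[torch.dtype, torch.dtype]:
    if data_type is not None:
        # Legacy argument: fills in whichever of the new arguments is unset.
        q_data_type = q_data_type if q_data_type is not None else data_type
        kv_data_type = kv_data_type if kv_data_type is not None else data_type
    elif q_data_type is None:
        q_data_type = "float16"
    if kv_data_type is None:
        kv_data_type = q_data_type

    # Normalize string names such as "float16" to torch dtypes.
    def to_dtype(d: Union[str, torch.dtype]) -> torch.dtype:
        return getattr(torch, d) if isinstance(d, str) else d

    return to_dtype(q_data_type), to_dtype(kv_data_type)


assert resolve_dtypes() == (torch.float16, torch.float16)
assert resolve_dtypes(data_type="bfloat16") == (torch.bfloat16, torch.bfloat16)
assert resolve_dtypes(q_data_type=torch.float16, kv_data_type=torch.float8_e4m3fn) == (
    torch.float16,
    torch.float8_e4m3fn,
)
```

The updated `fast_decode_plan` then uses the resolved dtypes to build the `empty_q_data` and `empty_kv_cache` placeholder tensors passed to the non-tensor-core plan call.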