Commit 5912af5

Fix conflicts and lints
1 parent e3bec63 commit 5912af5

5 files changed: 30 additions, 51 deletions

.github/workflows/pr-test.yml

Lines changed: 0 additions & 20 deletions
@@ -105,24 +105,6 @@ jobs:
           cd test/srt
           python3 run_suite.py --suite per-commit-8-gpu
 
-  performance-test-1-gpu-part-1:
-    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
-      github.event.pull_request.draft == false
-    runs-on: 8-gpu-runner
-    steps:
-      - name: Checkout code
-        uses: actions/checkout@v4
-
-      - name: Install dependencies
-        run: |
-          bash scripts/ci_install_dependency.sh
-
-      - name: Run test
-        timeout-minutes: 40
-        run: |
-          cd test/srt
-          python3 run_suite.py --suite per-commit-8-gpu
-
   performance-test-1-gpu-part-1:
     if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
       github.event.pull_request.draft == false
@@ -132,8 +114,6 @@ jobs:
         uses: actions/checkout@v4
 
       - name: Install dependencies
-        env:
-          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
         run: |
           bash scripts/ci_install_dependency.sh
 
python/sglang/srt/layers/attention/flashinfer_backend.py

Lines changed: 18 additions & 17 deletions
@@ -15,6 +15,7 @@
 
 import torch
 import torch._dynamo
+
 torch._dynamo.config.suppress_errors = True
 
 
@@ -56,6 +57,7 @@ class PrefillMetadata:
     use_ragged: bool
     extend_no_prefix: bool
 
+
 # Reuse this workspace buffer across all flashinfer wrappers
 global_workspace_buffer = None
 
@@ -282,7 +284,7 @@ def init_cuda_graph_state(
         )
         self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr]
         self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr]
-
+
         # Force allocation
         self.cuda_graph_custom_mask[0] = 0
         for i in range(len(self.cuda_graph_qk_indptr)):
@@ -291,7 +293,7 @@ def init_cuda_graph_state(
         for i in range(len(self.cuda_graph_qo_indptr)):
             if len(self.cuda_graph_qo_indptr[i]) > 0:
                 self.cuda_graph_qo_indptr[i][0] = 0
-
+
         # Force synchronization to ensure all tensors are allocated
         torch.cuda.synchronize()
 
@@ -508,11 +510,11 @@ def safe_forward_call(q, kv_cache):
                 k_scale=layer.k_scale,
                 v_scale=layer.v_scale,
             )
-
+
         # Call the wrapped function
         o = safe_forward_call(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-            forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+            forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -1185,7 +1187,7 @@ def fast_decode_plan(
     batch_size = len(last_page_len)
     if logits_soft_cap is None:
         logits_soft_cap = 0.0
-
+
     # Handle data types consistently
     if data_type is not None:
         if q_data_type is None:
@@ -1194,7 +1196,7 @@ def fast_decode_plan(
             kv_data_type = data_type
     elif q_data_type is None:
         q_data_type = "float16"
-
+
     if kv_data_type is None:
         kv_data_type = q_data_type
 
@@ -1218,19 +1220,19 @@ def fast_decode_plan(
     self._paged_kv_indices_buf = indices
     self._paged_kv_last_page_len_buf = last_page_len
     if self.use_tensor_cores:
-        self._qo_indptr_buf = qo_indptr_host.to(self.device, non_blocking=non_blocking)
+        self._qo_indptr_buf = qo_indptr_host.to(
+            self.device, non_blocking=non_blocking
+        )
 
     # Create empty tensors for dtype info if needed
     empty_q_data = torch.empty(
         0,
         dtype=(
-            getattr(torch, q_data_type)
-            if isinstance(q_data_type, str)
-            else q_data_type
+            getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
         ),
         device=self.device,
     )
-
+
     empty_kv_cache = torch.empty(
         0,
         dtype=(
@@ -1248,7 +1250,7 @@ def fast_decode_plan(
     )
 
     with torch.cuda.device(self.device):
-
+
         if self.use_tensor_cores:
             # Convert indptr to CPU, as the authors intended
             if global_override_indptr_cpu is not None:
@@ -1259,10 +1261,8 @@ def fast_decode_plan(
             # ALSO convert last_page_len to CPU
             last_page_len_host = last_page_len.cpu()
 
-            kv_lens_arr_host = get_seq_lens(
-                indptr_host, last_page_len_host, page_size
-            )
-
+            kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)
+
             try:
                 # Make sure we pass exactly 15 arguments for tensor core version
                 self._plan_info = self._cached_module.plan(
@@ -1285,6 +1285,7 @@ def fast_decode_plan(
             except Exception as e:
                 # Log the error for debugging
                 import logging
+
                 logging.error(f"Error in tensor core plan: {e}")
                 raise
         else:
@@ -1310,6 +1311,7 @@ def fast_decode_plan(
             except Exception as e:
                 # Log the error for debugging
                 import logging
+
                 logging.error(f"Error in standard plan: {e}")
                 raise
 
@@ -1319,4 +1321,3 @@ def fast_decode_plan(
     self._sm_scale = sm_scale
     self._rope_scale = rope_scale
     self._rope_theta = rope_theta
-
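
Aside on the dtype handling touched in the fast_decode_plan hunks above: q_data_type and kv_data_type may arrive either as a string name or as a torch.dtype, and the code normalizes them with getattr(torch, ...) before allocating the empty placeholder tensors. A minimal standalone sketch of that pattern, using an illustrative helper name (resolve_dtype is not part of the diff):

import torch

def resolve_dtype(dtype, default="float16"):
    # Accept either a torch.dtype (torch.float16) or a string name ("float16"),
    # falling back to a default when nothing was provided.
    if dtype is None:
        dtype = default
    return getattr(torch, dtype) if isinstance(dtype, str) else dtype

# Both calls resolve to torch.float16.
assert resolve_dtype("float16") is torch.float16
assert resolve_dtype(torch.float16) is torch.float16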

python/sglang/srt/layers/attention/flashinfer_mla_backend.py

Lines changed: 9 additions & 7 deletions
@@ -14,8 +14,9 @@
 from typing import TYPE_CHECKING, Callable, Optional, Union
 
 import torch
-import triton
 import torch._dynamo
+import triton
+
 torch._dynamo.config.suppress_errors = True
 
 from sglang.global_config import global_config
@@ -209,15 +210,15 @@ def init_cuda_graph_state(
         self.cuda_graph_kv_lens = torch.ones(
             (max_bs,), dtype=torch.int32, device=self.device
         )
-
+
         # Force allocation by performing a small operation and synchronizing
         # This ensures all tensors are properly allocated in GPU memory
         self.cuda_graph_kv_indices[0] = 0
         self.cuda_graph_qo_indptr[0] = 0
         self.cuda_graph_kv_indptr[0] = 0
         self.cuda_graph_kv_lens[0] = 1
         torch.cuda.synchronize()
-
+
         # For fast decode plan in graph replaying
         self.cuda_graph_qo_indptr_cpu = self.cuda_graph_qo_indptr.to("cpu")
         self.cuda_graph_kv_indptr_cpu = self.cuda_graph_kv_indptr.to("cpu")
@@ -399,11 +400,11 @@ def forward_decode(
                     k,
                     v,
                 )
-
+
         # Reshape inputs
         reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
         k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-
+
         # Direct call to run without the wrapper
         o = decode_wrapper.run(
             reshaped_q[:, :, : layer.v_head_dim],
@@ -855,8 +856,9 @@ def fast_mla_decode_plan(
         except Exception as e:
             # Log error for debugging
             import logging
+
             logging.error(f"Error in MLA plan: {e}")
-
+
             # Try alternate version with more arguments if needed
             try:
                 self._cached_module.plan(
@@ -865,7 +867,7 @@ def fast_mla_decode_plan(
                     self._pin_memory_int_workspace_buffer,
                     qo_indptr_cpu,
                     kv_indptr_cpu,
-                    kv_indices, # Include kv_indices which was missing
+                    kv_indices,  # Include kv_indices which was missing
                     kv_len_arr_cpu,
                     num_heads,
                     head_dim_ckv,
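
Both attention backends set torch._dynamo.config.suppress_errors = True right after the imports (visible as context in the hunks above). This flag makes torch.compile log a graph-compilation failure and fall back to eager execution instead of raising. A minimal, self-contained illustration (scaled_sum is an invented example function, not from the diff):

import torch
import torch._dynamo

# When graph compilation fails, suppress_errors makes torch.compile log the
# error and silently run the original eager function instead of raising.
torch._dynamo.config.suppress_errors = True

@torch.compile
def scaled_sum(x):
    return (x * 2.0).sum()

print(scaled_sum(torch.ones(4)))  # prints tensor(8.), even on eager fallback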

scripts/ci_install_dependency.sh

Lines changed: 2 additions & 6 deletions
@@ -2,9 +2,6 @@
 # Install the dependency in CI.
 set -euxo pipefail
 
-# Use repo from environment variables, passed from GitHub Actions
-FLASHINFER_REPO="${FLASHINFER_REPO:-https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python}"
-
 SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
 bash "${SCRIPT_DIR}/killall_sglang.sh"
 
@@ -18,9 +15,8 @@ rm -rf /usr/local/lib/python3.10/dist-packages/sgl_kernel*
 # Update pip
 pip install --upgrade pip
 
-# Install flashinfer and sgl-kernel
-pip install flashinfer_python==0.2.5 --find-links ${FLASHINFER_REPO} --no-cache-dir
-pip install sgl-kernel==0.0.9.post1 --no-cache-dir
+# Install sgl-kernel
+pip install sgl-kernel==0.1.0 --no-cache-dir
 
 # Install the main package
 pip install -e "python[all]"
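
The script now pins sgl-kernel==0.1.0 and no longer installs flashinfer_python or reads FLASHINFER_REPO. If it helps to verify what actually landed in the CI environment, here is a small hedged check (package names taken from the script; name-normalization behavior can vary slightly across Python versions):

from importlib.metadata import PackageNotFoundError, version

# Print the installed version of each package the CI script used to manage.
for pkg in ("sgl-kernel", "flashinfer_python"):
    try:
        print(f"{pkg}: {version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")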

sgl-kernel/build.sh

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ if [ ${CUDA_VERSION} = "12.8" ]; then
     TORCH_INSTALL="pip install --no-cache-dir --pre torch --index-url https://download.pytorch.org/whl/nightly/cu${CUDA_VERSION//.}"
 else
     DOCKER_IMAGE="pytorch/manylinux-builder:cuda${CUDA_VERSION}"
-    TORCH_INSTALL="pip install --no-cache-dir torch==2.6.0 --index-url https://download.pytorch.org/whl/cu${CUDA_VERSION//.}"
+    TORCH_INSTALL="pip install --no-cache-dir torch==2.5.1 --index-url https://download.pytorch.org/whl/cu${CUDA_VERSION//.}"
 fi
 
 docker run --rm \
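
With this pin, the non-12.8 build path compiles against torch 2.5.1. A quick way to confirm the torch build and CUDA toolkit inside the build container, purely as an illustrative check:

import torch

# Report the torch build and the CUDA toolkit it was compiled against.
print(torch.__version__)   # expected to start with "2.5.1"
print(torch.version.cuda)  # e.g. "12.4", depending on ${CUDA_VERSION}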
