Skip to content

Commit eebfdb9

Browse files
authored
[fix] fix potential bumpy throughtput with deepgemm (#5722)
1 parent dfb3226 commit eebfdb9

File tree

2 files changed

+17
-10
lines changed

2 files changed

+17
-10
lines changed

python/sglang/compile_deep_gemm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
multiprocessing.set_start_method("spawn", force=True)
2828

2929
# Reduce warning
30-
os.environ["SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE"] = "1"
30+
os.environ["SGL_IN_DEEPGEMM_PRECOMPILE_STAGE"] = "1"
3131
# Force enable deep gemm
3232
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "1"
3333
# Force enable mha chunked kv for DeepSeek V3 to avoid missing kv_b_proj DeepGEMM case

python/sglang/srt/layers/quantization/deep_gemm.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,10 @@
3434
_ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var(
3535
"SGL_JIT_DEEPGEMM_PRECOMPILE", "true"
3636
)
37-
_DO_COMPILE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
37+
_DO_COMPILE_ALL = True
38+
_IS_FIRST_RANK_ON_NODE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
3839
_COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
39-
_IN_PRE_COMPILE_STAGE = get_bool_env_var("SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE", "false")
40+
_IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")
4041

4142
# Force redirect deep_gemm cache_dir
4243
os.environ["DG_CACHE_DIR"] = os.getenv(
@@ -46,7 +47,8 @@
4647

4748
def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
4849
global _BUILTIN_M_LIST
49-
global _DO_COMPILE
50+
global _DO_COMPILE_ALL
51+
global _IS_FIRST_RANK_ON_NODE
5052

5153
# Generate m_max
5254
m_max = 1024 * 16
@@ -57,8 +59,13 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
5759
m_max = min(1024 * 128, m_max)
5860
_BUILTIN_M_LIST = list(range(1, m_max + 1))
5961

60-
# Check if is the first rank on node
61-
_DO_COMPILE = ServerArgs.base_gpu_id == gpu_id
62+
_IS_FIRST_RANK_ON_NODE = ServerArgs.base_gpu_id == gpu_id
63+
64+
# Check if is the first rank on node.
65+
# Default each rank will try compile all Ms to
66+
# load all symbols at the launch stages.
67+
# Avoid loading symbols at the serving stages.
68+
_DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE or not _IN_PRECOMPILE_STAGE
6269

6370

6471
class DeepGemmKernelType(IntEnum):
@@ -89,7 +96,7 @@ class DeepGemmKernelHelper:
8996

9097

9198
def _compile_warning_1():
92-
if not _IN_PRE_COMPILE_STAGE:
99+
if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
93100
logger.warning(
94101
"Entering DeepGEMM JIT Pre-Complie session. "
95102
"And it may takes a long time(Typically 10-20 mins) "
@@ -276,7 +283,7 @@ def _maybe_compile_deep_gemm_one_type_all(
276283
query_key = (kernel_type, n, k, num_groups)
277284
if (
278285
_ENABLE_JIT_DEEPGEMM_PRECOMPILE
279-
and _DO_COMPILE
286+
and _DO_COMPILE_ALL
280287
and _INITIALIZATION_DICT.get(query_key) is None
281288
):
282289
_INITIALIZATION_DICT[query_key] = True
@@ -286,7 +293,7 @@ def _maybe_compile_deep_gemm_one_type_all(
286293
logger.info(
287294
f"Try DeepGEMM JIT Compiling for "
288295
f"<{kernel_helper.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
289-
f"{' It only takes a litte time(Typically 1 sec) if you have run `sglang.compile_deep_gemm`. ' if not _IN_PRE_COMPILE_STAGE else ''}"
296+
f"{' It only takes a litte time(Typically 1 sec) if you have run `sglang.compile_deep_gemm`. ' if not _IN_PRECOMPILE_STAGE else ''}"
290297
)
291298

292299
# NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
@@ -355,7 +362,7 @@ def gemm_nt_f8f8bf16(
355362

356363
@contextmanager
357364
def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
358-
if _IN_PRE_COMPILE_STAGE:
365+
if _IN_PRECOMPILE_STAGE:
359366
yield
360367
return
361368

0 commit comments

Comments
 (0)