34
34
_ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var (
35
35
"SGL_JIT_DEEPGEMM_PRECOMPILE" , "true"
36
36
)
37
- _DO_COMPILE = get_bool_env_var ("SGL_IS_FIRST_RANK_ON_NODE" , "true" )
37
+ _DO_COMPILE_ALL = True
38
+ _IS_FIRST_RANK_ON_NODE = get_bool_env_var ("SGL_IS_FIRST_RANK_ON_NODE" , "true" )
38
39
_COMPILE_WORKERS = get_int_env_var ("SGL_JIT_DEEPGEMM_COMPILE_WORKERS" , 4 )
39
- _IN_PRE_COMPILE_STAGE = get_bool_env_var ("SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE " , "false" )
40
+ _IN_PRECOMPILE_STAGE = get_bool_env_var ("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE " , "false" )
40
41
41
42
# Force redirect deep_gemm cache_dir
42
43
os .environ ["DG_CACHE_DIR" ] = os .getenv (
46
47
47
48
def update_deep_gemm_config (gpu_id : int , server_args : ServerArgs ):
48
49
global _BUILTIN_M_LIST
49
- global _DO_COMPILE
50
+ global _DO_COMPILE_ALL
51
+ global _IS_FIRST_RANK_ON_NODE
50
52
51
53
# Generate m_max
52
54
m_max = 1024 * 16
@@ -57,8 +59,13 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
57
59
m_max = min (1024 * 128 , m_max )
58
60
_BUILTIN_M_LIST = list (range (1 , m_max + 1 ))
59
61
60
- # Check if is the first rank on node
61
- _DO_COMPILE = ServerArgs .base_gpu_id == gpu_id
62
+ _IS_FIRST_RANK_ON_NODE = ServerArgs .base_gpu_id == gpu_id
63
+
64
+ # Check if is the first rank on node.
65
+ # Default each rank will try compile all Ms to
66
+ # load all symbols at the launch stages.
67
+ # Avoid loading symbols at the serving stages.
68
+ _DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE or not _IN_PRECOMPILE_STAGE
62
69
63
70
64
71
class DeepGemmKernelType (IntEnum ):
@@ -89,7 +96,7 @@ class DeepGemmKernelHelper:
89
96
90
97
91
98
def _compile_warning_1 ():
92
- if not _IN_PRE_COMPILE_STAGE :
99
+ if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE :
93
100
logger .warning (
94
101
"Entering DeepGEMM JIT Pre-Complie session. "
95
102
"And it may takes a long time(Typically 10-20 mins) "
@@ -276,7 +283,7 @@ def _maybe_compile_deep_gemm_one_type_all(
276
283
query_key = (kernel_type , n , k , num_groups )
277
284
if (
278
285
_ENABLE_JIT_DEEPGEMM_PRECOMPILE
279
- and _DO_COMPILE
286
+ and _DO_COMPILE_ALL
280
287
and _INITIALIZATION_DICT .get (query_key ) is None
281
288
):
282
289
_INITIALIZATION_DICT [query_key ] = True
@@ -286,7 +293,7 @@ def _maybe_compile_deep_gemm_one_type_all(
286
293
logger .info (
287
294
f"Try DeepGEMM JIT Compiling for "
288
295
f"<{ kernel_helper .name } > N={ n } , K={ k } , num_groups={ num_groups } with all Ms."
289
- f"{ ' It only takes a litte time(Typically 1 sec) if you have run `sglang.compile_deep_gemm`. ' if not _IN_PRE_COMPILE_STAGE else '' } "
296
+ f"{ ' It only takes a litte time(Typically 1 sec) if you have run `sglang.compile_deep_gemm`. ' if not _IN_PRECOMPILE_STAGE else '' } "
290
297
)
291
298
292
299
# NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
@@ -355,7 +362,7 @@ def gemm_nt_f8f8bf16(
355
362
356
363
@contextmanager
357
364
def _log_jit_build (M : int , N : int , K : int , kernel_type : DeepGemmKernelType ):
358
- if _IN_PRE_COMPILE_STAGE :
365
+ if _IN_PRECOMPILE_STAGE :
359
366
yield
360
367
return
361
368
0 commit comments