Skip to content

Commit e300df4

Browse files
Alcanderian authored and xwu-intel committed
[misc] deep_gemm fallback to NVRTC when NVCC not found (sgl-project#6252)
1 parent 5b3f2bc commit e300df4

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

python/sglang/srt/layers/quantization/deep_gemm.py

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
if is_cuda():
1616
import deep_gemm
1717
from deep_gemm import get_num_sms
18+
from deep_gemm.jit.compiler import get_nvcc_compiler
1819
from deep_gemm.jit_kernels.gemm import get_best_configs
1920
from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
2021
from deep_gemm.jit_kernels.tuner import jit_tuner
@@ -48,7 +49,17 @@ def get_enable_jit_deepgemm():
4849
# Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
4950
# NVRTC may have performance loss with some cases.
5051
# And NVCC JIT speed is also 9x faster in the ref commit
51-
os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", "0")
52+
_USE_NVRTC_DEFAULT = "0"
53+
if _ENABLE_JIT_DEEPGEMM:
54+
try:
55+
get_nvcc_compiler()
56+
except:
57+
logger.warning(
58+
"NVCC Compiler not found, use NVRTC for DeepGEMM JIT "
59+
"and may have performance loss with some cases."
60+
)
61+
_USE_NVRTC_DEFAULT = "1"
62+
os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", _USE_NVRTC_DEFAULT)
5263

5364

5465
def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):

0 commit comments

Comments
 (0)