Skip to content

[feature] enable pre compile jit deep_gemm #5580

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 21, 2025
136 changes: 136 additions & 0 deletions python/sglang/compile_deep_gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
"""
Compile DeepGEMM Kernels for a model with the specified server arguments

This script launches a server for capturing DeepGEMM calls and then compiles the kernels.
It accepts server arguments (the same as launch_server.py).

Usage:
python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code

"""

import argparse
import dataclasses
import multiprocessing
import os
import time

import requests

from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import kill_process_tree
from sglang.srt.warmup import warmup

# Use the "spawn" start method for processes created through multiprocessing;
# force=True replaces any start method that was already selected.
multiprocessing.set_start_method("spawn", force=True)

# Reduce warning
# Set at import time, before any server process is started by this script.
# NOTE(review): presumably read elsewhere in sglang to detect the DeepGEMM
# pre-compile stage -- confirm against the consumer of this variable.
os.environ["SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE"] = "1"


@dataclasses.dataclass
class CompileArgs:
    """Options that belong to the compile script itself rather than the server."""

    # Seconds to wait for the server to become ready; the same value is also
    # applied as the server watchdog timeout by refine_server_args.
    timeout: int = 3600

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register this script's own command-line flags on ``parser``."""
        parser.add_argument("--timeout", default=CompileArgs.timeout, type=int)

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance from a parsed argparse namespace.

        Every dataclass field is looked up on ``args`` by name and coerced
        with the type of that field's default value.
        """
        field_casters = [
            (field.name, type(field.default)) for field in dataclasses.fields(cls)
        ]
        kwargs = {}
        for name, caster in field_casters:
            kwargs[name] = caster(getattr(args, name))
        return cls(**kwargs)


@warmup("compile-deep-gemm")
async def warm_up_compile(tokenizer_manager: TokenizerManager):
    """Issue one short generation request through ``tokenizer_manager``.

    Registered with the ``warmup`` decorator under the name
    ``"compile-deep-gemm"``. The request uses four fixed token ids and asks
    for 8 new tokens with ``ignore_eos`` enabled.
    """
    print("\nGenerate warm up request for compiling DeepGEMM...\n")

    sampling_params = {
        "temperature": 0.0,
        "max_new_tokens": 8,
        "ignore_eos": True,
    }
    request = GenerateReqInput(
        input_ids=[0, 1, 2, 3],
        sampling_params=sampling_params,
    )

    # Only the first item produced by the request is awaited.
    response_stream = tokenizer_manager.generate_request(request, None)
    await response_stream.__anext__()


def launch_server_internal(server_args):
    """Run ``launch_server`` in the current process and clean up afterwards.

    Used as the target of the process started by
    ``launch_server_process_and_send_one_request``. Any exception raised by
    ``launch_server`` propagates to the caller unchanged; no ``except`` clause
    is needed for that, so the former catch-and-reraise has been removed.

    Args:
        server_args: Server configuration forwarded to ``launch_server``.
    """
    try:
        launch_server(server_args)
    finally:
        # Runs on both normal return and error. include_parent=False is passed
        # so the call targets the process tree below this pid rather than the
        # current process itself.
        kill_process_tree(os.getpid(), include_parent=False)


def launch_server_process_and_send_one_request(
    server_args: ServerArgs, compile_args: CompileArgs
):
    """Start the server in a child process and wait until it answers HTTP.

    The child runs ``launch_server_internal``. This function then polls
    ``GET /v1/models`` on the configured host and port until it returns
    status 200 or ``compile_args.timeout`` seconds have elapsed.

    Args:
        server_args: Server configuration; ``host`` and ``port`` are used to
            build the polling URL.
        compile_args: Supplies the overall ``timeout`` in seconds.

    Returns:
        The started ``multiprocessing.Process``; the caller is responsible
        for terminating it.

    Raises:
        RuntimeError: The server process exited before becoming ready.
        TimeoutError: The server did not become ready within the timeout.
            The server process tree is killed before this is raised.
    """
    proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,))
    proc.start()

    models_url = f"http://{server_args.host}:{server_args.port}/v1/models"
    headers = {
        "Content-Type": "application/json; charset=utf-8",
    }
    poll_interval = 10
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + compile_args.timeout

    while time.monotonic() < deadline:
        # Fail fast instead of polling a dead server until the deadline.
        if not proc.is_alive():
            raise RuntimeError(
                "Server process exited before becoming ready "
                f"(exit code {proc.exitcode})."
            )
        try:
            # Without a per-request timeout a stalled connection could block
            # indefinitely and defeat the overall deadline.
            response = requests.get(
                models_url, headers=headers, timeout=poll_interval
            )
            if response.status_code == 200:
                return proc
        except requests.RequestException:
            # Server not accepting connections yet (or request timed out);
            # retry after the sleep below.
            pass
        time.sleep(poll_interval)

    # Do not leave the server running in the background after giving up.
    kill_process_tree(proc.pid)
    raise TimeoutError(
        "DeepGEMM Kernels compilation timeout."
        "\n\nFeel free and please restart the command."
    )


def refine_server_args(server_args: ServerArgs, compile_args: CompileArgs):
    """Adjust ``server_args`` in place for a compile-only run.

    Turns off CUDA graph and torch compile, selects the
    ``"compile-deep-gemm"`` warmup, and sets the watchdog timeout to
    ``compile_args.timeout``. Returns ``None``.
    """
    # Disable cuda graph and torch compile to save time
    server_args.enable_torch_compile = False
    server_args.disable_cuda_graph = True
    print("Disable CUDA Graph and Torch Compile to save time...")

    server_args.warmups = "compile-deep-gemm"
    # Compilation will take a long time, so the watchdog gets the same budget
    # as the compile script itself.
    server_args.watchdog_timeout = compile_args.timeout


def run_compile(server_args: ServerArgs, compile_args: CompileArgs):
    """Start the server, wait for it to come up, then shut it down.

    Prints a banner before starting, delegates the launch and readiness wait
    to ``launch_server_process_and_send_one_request``, kills the returned
    process tree, and prints a success message.
    """
    banner = (
        "Begin DeepGEMM Kernels compilation...\n"
        "It may take a long time and timeout maybe raised while the compilation is still in progress.\n"
        "Just feel free to restart the command until the compilation is fully finished.\n"
    )
    print(banner)

    server_proc = launch_server_process_and_send_one_request(server_args, compile_args)
    # The server is only needed until it reports ready; tear it down now.
    kill_process_tree(server_proc.pid)

    print("\nDeepGEMM Kernels compilation finished successfully.")


if __name__ == "__main__":
    # A single parser carries both the server flags and this script's flags.
    cli_parser = argparse.ArgumentParser()
    for arg_group in (ServerArgs, CompileArgs):
        arg_group.add_cli_args(cli_parser)
    cli_namespace = cli_parser.parse_args()

    server_args = ServerArgs.from_cli_args(cli_namespace)
    compile_args = CompileArgs.from_cli_args(cli_namespace)

    # Apply the compile-specific overrides before the server is started.
    refine_server_args(server_args, compile_args)
    run_compile(server_args, compile_args)
Loading
Loading