diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index 33aab232fe..bf7371fff6 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -278,6 +278,10 @@ def __exit__(self, exc_type, exc_value, traceback):
         self.shutdown()
         return False
 
+    def flush_cache(self):
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(self.tokenizer_manager.flush_cache())
+
     def start_profile(self):
         loop = asyncio.get_event_loop()
         loop.run_until_complete(self.tokenizer_manager.start_profile())
diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py
index 1f93b475c2..2e5944a70b 100644
--- a/python/sglang/srt/entrypoints/http_server.py
+++ b/python/sglang/srt/entrypoints/http_server.py
@@ -310,11 +310,11 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
 @app.api_route("/flush_cache", methods=["GET", "POST"])
 async def flush_cache():
     """Flush the radix cache."""
-    _global_state.tokenizer_manager.flush_cache()
+    ret = await _global_state.tokenizer_manager.flush_cache()
     return Response(
         content="Cache flushed.\nPlease check backend logs for more details. "
         "(When there are running or waiting requests, the operation will not be performed.)\n",
-        status_code=200,
+        status_code=200 if ret.success else HTTPStatus.BAD_REQUEST,
     )
 
 
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 00affa0a4e..d2d5c591a2 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -665,10 +665,15 @@ class BatchEmbeddingOut:
 
 
 @dataclass
-class FlushCacheReq:
+class FlushCacheReqInput:
     pass
 
 
+@dataclass
+class FlushCacheReqOutput:
+    success: bool
+
+
 @dataclass
 class UpdateWeightFromDiskReqInput:
     # The model path with the new weights
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 383cd68094..24f5d4aa1a 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -59,7 +59,8 @@
     CloseSessionReqInput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
-    FlushCacheReq,
+    FlushCacheReqInput,
+    FlushCacheReqOutput,
     GetInternalStateReq,
     GetInternalStateReqOutput,
     GetWeightsByNameReqInput,
@@ -400,7 +401,7 @@ def __init__(
             [
                 (TokenizedGenerateReqInput, self.handle_generate_request),
                 (TokenizedEmbeddingReqInput, self.handle_embedding_request),
-                (FlushCacheReq, self.flush_cache_wrapped),
+                (FlushCacheReqInput, self.flush_cache_wrapped),
                 (AbortReq, self.abort_request),
                 (OpenSessionReqInput, self.open_session),
                 (CloseSessionReqInput, self.close_session),
@@ -1652,8 +1653,9 @@ def watchdog_thread(self):
             time.sleep(5)
         self.parent_process.send_signal(signal.SIGQUIT)
 
-    def flush_cache_wrapped(self, recv_req: FlushCacheReq):
-        self.flush_cache()
+    def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
+        success = self.flush_cache()
+        return FlushCacheReqOutput(success=success)
 
     def flush_cache(self):
         """Flush the memory pool and cache."""
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 33afffbd6d..93e74ad27f 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -62,7 +62,8 @@
     EmbeddingReqInput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
-    FlushCacheReq,
+    FlushCacheReqInput,
+    FlushCacheReqOutput,
     GenerateReqInput,
     GetInternalStateReq,
     GetInternalStateReqOutput,
@@ -258,6 +259,9 @@ def __init__(
         self.resume_memory_occupation_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.flush_cache_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
         self.start_profile_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -308,6 +312,10 @@ def __init__(
                     ResumeMemoryOccupationReqOutput,
                     self.resume_memory_occupation_communicator.handle_recv,
                 ),
+                (
+                    FlushCacheReqOutput,
+                    self.flush_cache_communicator.handle_recv,
+                ),
                 (
                     ProfileReqOutput,
                     self.start_profile_communicator.handle_recv,
@@ -616,9 +624,8 @@ async def _handle_batch_request(
         except StopAsyncIteration:
             pass
 
-    def flush_cache(self):
-        req = FlushCacheReq()
-        self.send_to_scheduler.send_pyobj(req)
+    async def flush_cache(self) -> FlushCacheReqOutput:
+        return await self.flush_cache_communicator(FlushCacheReqInput())
 
     def abort_request(self, rid: str):
         if rid not in self.rid_to_state: