diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py
index 8cb237cf17..12d9249981 100644
--- a/python/sglang/srt/managers/cache_controller.py
+++ b/python/sglang/srt/managers/cache_controller.py
@@ -268,98 +268,97 @@ def write_thread_func_direct(self):
         """
         Directly write through KV caches to host memory without buffering.
         """
-        with torch.cuda.stream(self.write_stream):
-            while not self.stop_event.is_set():
-                try:
-                    operation = self.write_queue.get(block=True, timeout=1)
-                    self.mem_pool_host.write_page_all_layers(
-                        operation.host_indices,
-                        operation.device_indices,
-                        self.mem_pool_device,
-                    )
-                    self.write_stream.synchronize()
-                    self.mem_pool_host.complete_io(operation.host_indices)
-                    for node_id in operation.node_ids:
-                        if node_id != 0:
-                            self.ack_write_queue.put(node_id)
-                except Empty:
-                    continue
-                except Exception as e:
-                    logger.error(e)
+        torch.cuda.set_stream(self.write_stream)
+        while not self.stop_event.is_set():
+            try:
+                operation = self.write_queue.get(block=True, timeout=1)
+                self.mem_pool_host.write_page_all_layers(
+                    operation.host_indices,
+                    operation.device_indices,
+                    self.mem_pool_device,
+                )
+                self.write_stream.synchronize()
+                self.mem_pool_host.complete_io(operation.host_indices)
+                for node_id in operation.node_ids:
+                    if node_id != 0:
+                        self.ack_write_queue.put(node_id)
+            except Empty:
+                continue
+            except Exception as e:
+                logger.error(e)
 
     def load_thread_func_direct(self):
         """
         Directly load KV caches from host memory to device memory without buffering.
         """
-        with torch.cuda.stream(self.load_stream):
-            while not self.stop_event.is_set():
-                try:
-                    operation = self.load_queue.get(block=True, timeout=1)
-                    # time.sleep(18e-6 * len(operation.host_indices))
-                    operation.data = self.mem_pool_host.get_flat_data(
-                        operation.host_indices
-                    )
-                    self.mem_pool_device.transfer(
-                        operation.device_indices, operation.data
-                    )
-                    self.mem_pool_host.complete_io(operation.host_indices)
-                    for node_id in operation.node_ids:
-                        if node_id != 0:
-                            self.ack_load_queue.put(node_id)
-                except Empty:
-                    continue
-                except Exception as e:
-                    logger.error(e)
+        torch.cuda.set_stream(self.load_stream)
+        while not self.stop_event.is_set():
+            try:
+                operation = self.load_queue.get(block=True, timeout=1)
+                # time.sleep(18e-6 * len(operation.host_indices))
+                operation.data = self.mem_pool_host.get_flat_data(
+                    operation.host_indices
+                )
+                self.mem_pool_device.transfer(operation.device_indices, operation.data)
+                self.mem_pool_host.complete_io(operation.host_indices)
+                for node_id in operation.node_ids:
+                    if node_id != 0:
+                        self.ack_load_queue.put(node_id)
+            except Empty:
+                continue
+            except Exception as e:
+                logger.error(e)
 
     def load_thread_func_layer_by_layer(self):
         """
         Load KV caches from host memory to device memory layer by layer.
         """
-        with torch.cuda.stream(self.load_stream):
-            while not self.stop_event.is_set():
-                self.load_cache_event.wait(timeout=1)
-                if not self.load_cache_event.is_set():
-                    continue
-                self.load_cache_event.clear()
+        torch.cuda.set_stream(self.load_stream)
+        while not self.stop_event.is_set():
+            self.load_cache_event.wait(timeout=1)
+            if not self.load_cache_event.is_set():
+                continue
+            self.load_cache_event.clear()
 
-                batch_operation = None
-                while self.load_queue.qsize() > 0:
-                    op = self.load_queue.get(block=True)
-                    if batch_operation is None:
-                        batch_operation = op
-                    else:
-                        batch_operation.merge(op)
+            batch_operation = None
+            while self.load_queue.qsize() > 0:
+                op = self.load_queue.get(block=True)
                 if batch_operation is None:
-                    continue
+                    batch_operation = op
+                else:
+                    batch_operation.merge(op)
+            if batch_operation is None:
+                continue
 
-                self.layer_done_counter.reset()
-                for i in range(self.mem_pool_host.layer_num):
-                    if self.page_size == 1:
-                        flat_data = self.mem_pool_host.get_flat_data_by_layer(
-                            batch_operation.host_indices, i
-                        )
-                        self.mem_pool_device.transfer_per_layer(
-                            batch_operation.device_indices, flat_data, i
-                        )
-                    else:
-                        self.mem_pool_host.load_page_per_layer(
-                            batch_operation.host_indices,
-                            batch_operation.device_indices,
-                            self.mem_pool_device,
-                            i,
-                        )
-                    self.load_stream.synchronize()
-                    self.layer_done_counter.increment()
-
-                self.mem_pool_host.complete_io(batch_operation.host_indices)
-                for node_id in batch_operation.node_ids:
-                    if node_id != 0:
-                        self.ack_load_queue.put(node_id)
+            self.layer_done_counter.reset()
+            for i in range(self.mem_pool_host.layer_num):
+                if self.page_size == 1:
+                    flat_data = self.mem_pool_host.get_flat_data_by_layer(
+                        batch_operation.host_indices, i
+                    )
+                    self.mem_pool_device.transfer_per_layer(
+                        batch_operation.device_indices, flat_data, i
+                    )
+                else:
+                    self.mem_pool_host.load_page_per_layer(
+                        batch_operation.host_indices,
+                        batch_operation.device_indices,
+                        self.mem_pool_device,
+                        i,
+                    )
+                self.load_stream.synchronize()
+                self.layer_done_counter.increment()
+
+            self.mem_pool_host.complete_io(batch_operation.host_indices)
+            for node_id in batch_operation.node_ids:
+                if node_id != 0:
+                    self.ack_load_queue.put(node_id)
 
     def write_aux_func(self, no_wait=False):
         """
         Auxiliary function to prepare the buffer for write operations.
         """
+        torch.cuda.set_stream(self.write_stream)
 
         def _to_op(op_):
             assert op_.device_indices.is_cuda, "Device indices should be on GPU"
@@ -370,44 +369,42 @@ def _to_op(op_):
             return op_
 
         buffer = None
-        with torch.cuda.stream(self.write_stream):
-            while not self.stop_event.is_set():
-                try:
-                    operation = self.write_queue.get(block=True, timeout=1)
-                    factor = (
-                        len(operation.device_indices)
-                        // self.write_buffer.max_buffer_size
-                    )
+        while not self.stop_event.is_set():
+            try:
+                operation = self.write_queue.get(block=True, timeout=1)
+                factor = (
+                    len(operation.device_indices) // self.write_buffer.max_buffer_size
+                )
 
-                    if factor >= 1:
-                        if buffer is not None:
-                            _to_op(buffer)
-                            buffer = None
-
-                        if factor < 2:
-                            _to_op(operation)
-                        else:
-                            split_ops = operation.split(factor)
-                            for op_ in split_ops:
-                                _to_op(op_)
-                        continue
-
-                    if buffer is None:
-                        buffer = operation
-                    else:
-                        buffer.merge(operation)
-                    if (
-                        no_wait
-                        or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
-                        or self.write_queue.empty()
-                        or self.write_buffer.empty()
-                    ):
+                if factor >= 1:
+                    if buffer is not None:
                         _to_op(buffer)
                         buffer = None
-                except Empty:
+
+                    if factor < 2:
+                        _to_op(operation)
+                    else:
+                        split_ops = operation.split(factor)
+                        for op_ in split_ops:
+                            _to_op(op_)
                     continue
-                except Exception as e:
-                    logger.error(e)
+
+                if buffer is None:
+                    buffer = operation
+                else:
+                    buffer.merge(operation)
+                if (
+                    no_wait
+                    or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
+                    or self.write_queue.empty()
+                    or self.write_buffer.empty()
+                ):
+                    _to_op(buffer)
+                    buffer = None
+            except Empty:
+                continue
+            except Exception as e:
+                logger.error(e)
 
     def load_aux_func(self):
         """
@@ -484,19 +481,18 @@ def write_thread_func_buffer(self):
         aux_thread.join()
 
     def load_thread_func_buffer(self):
+        torch.cuda.set_stream(self.load_stream)
         aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
         aux_thread.start()
-
-        with torch.cuda.stream(self.load_stream):
-            while not self.stop_event.is_set():
-                operation = self.load_buffer.get()
-                if operation is None:
-                    continue
-                self.mem_pool_device.transfer(operation.device_indices, operation.data)
-                self.mem_pool_host.complete_io(operation.host_indices)
-                for node_id in operation.node_ids:
-                    if node_id != 0:
-                        self.ack_load_queue.put(node_id)
+        while not self.stop_event.is_set():
+            operation = self.load_buffer.get()
+            if operation is None:
+                continue
+            self.mem_pool_device.transfer(operation.device_indices, operation.data)
+            self.mem_pool_host.complete_io(operation.host_indices)
+            for node_id in operation.node_ids:
+                if node_id != 0:
+                    self.ack_load_queue.put(node_id)
         aux_thread.join()
 
     def evict_device(