diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py index 5be0b7dc85fa..aa3c1a81790f 100644 --- a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py @@ -139,6 +139,7 @@ def __init__(self, model_runner: ModelRunner): super().__init__() self.pad_slot_id = PAD_SLOT_ID self.device = model_runner.device + self.topk = model_runner.server_args.speculative_eagle_topk or 0 self.req_to_token_pool: HybridReqToTokenPool = model_runner.req_to_token_pool self.forward_metadata: ForwardMetadata = None self.state_indices_list = [] @@ -185,7 +186,7 @@ def _forward_metadata(self, forward_batch: ForwardBatch): device=forward_batch.input_ids.device, ) - if forward_batch.spec_info.topk > 1: + if self.topk > 1: retrieve_next_token = forward_batch.spec_info.retrive_next_token retrieve_next_sibling = forward_batch.spec_info.retrive_next_sibling # retrieve_next_token is None during dummy run so skip tensor creation @@ -482,7 +483,7 @@ def _capture_metadata( self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices) # If topk > 1, we need to use retrieve_next_token and retrieve_next_sibling to handle the eagle tree custom attention mask - if forward_mode.is_target_verify() and spec_info.topk > 1: + if forward_mode.is_target_verify() and self.topk > 1: # They are None during cuda graph capture so skip the copy_... # self.retrieve_next_token_list[bs - 1].copy_(spec_info.retrive_next_token) # self.retrieve_next_sibling_list[bs - 1].copy_(spec_info.retrive_next_sibling) @@ -543,7 +544,7 @@ def _replay_metadata( raise ValueError(f"Invalid forward mode: {forward_mode=}") # If topk > 1, we need to use retrieve_next_token and retrieve_next_sibling to handle the eagle tree custom attention mask - if forward_mode.is_target_verify() and spec_info.topk > 1: + if forward_mode.is_target_verify() and self.topk > 1: bs_without_pad = spec_info.retrive_next_token.shape[0] self.retrieve_next_token_list[bs - 1][:bs_without_pad].copy_( spec_info.retrive_next_token