@@ -814,11 +814,11 @@ def alloc_paged_token_slots_decode(
814
814
last_loc : torch .Tensor ,
815
815
backup_state : bool = False ,
816
816
):
817
- if (
818
- self . token_to_kv_pool_allocator . available_size ()
819
- < len ( seq_lens ) * self .token_to_kv_pool_allocator .page_size
820
- ):
821
- if self . tree_cache is not None :
817
+ if self . tree_cache is not None :
818
+ if (
819
+ self .token_to_kv_pool_allocator .available_size ()
820
+ < len ( seq_lens ) * self . token_to_kv_pool_allocator . page_size
821
+ ) :
822
822
self .tree_cache .evict (
823
823
len (seq_lens ) * self .token_to_kv_pool_allocator .page_size ,
824
824
)
@@ -1116,17 +1116,25 @@ def mix_with_running(self, running_batch: "ScheduleBatch"):
1116
1116
# TODO (lianmin): Revisit this. It should be seq_len - 1
1117
1117
self .extend_logprob_start_lens .extend ([0 ] * running_bs )
1118
1118
1119
- def check_decode_mem (self , buf_multiplier = 1 ):
1120
- bs = len (self .reqs ) * buf_multiplier
1121
- if self .token_to_kv_pool_allocator .available_size () >= bs :
1122
- return True
1119
def new_page_count_next_decode(self):
    """Count the requests whose ``seqlen`` is an exact multiple of the page size.

    Reads the page size from ``self.token_to_kv_pool_allocator.page_size``.
    When the page size is 1, every request is counted, so the result is
    simply ``len(self.reqs)``.

    NOTE(review): presumably a request whose length is a multiple of the
    page size has filled its last page, so its next decoded token starts a
    new page -- confirm against the allocator's paging scheme.

    Returns:
        The number of matching requests in ``self.reqs`` as an ``int``.
    """
    tokens_per_page = self.token_to_kv_pool_allocator.page_size

    # Every length is a multiple of 1, so skip the per-request scan.
    if tokens_per_page == 1:
        return len(self.reqs)

    full_page_reqs = 0
    for request in self.reqs:
        if request.seqlen % tokens_per_page == 0:
            full_page_reqs += 1
    return full_page_reqs
1123
1124
1124
- self .tree_cache .evict (bs )
1125
def check_decode_mem(self, buf_multiplier=1):
    """Report whether the KV pool has room for the next decode step.

    The requirement is ``new_page_count_next_decode() * buf_multiplier *
    page_size`` tokens. If the allocator does not currently have that many
    available, the tree cache (when one is attached) is asked to evict that
    many tokens and availability is checked once more.

    Args:
        buf_multiplier: Factor applied to the page count before the check.

    Returns:
        ``True`` when the allocator's available size covers the
        requirement, either immediately or after eviction; ``False``
        otherwise.
    """
    allocator = self.token_to_kv_pool_allocator
    tokens_required = (
        self.new_page_count_next_decode() * buf_multiplier * allocator.page_size
    )

    if allocator.available_size() >= tokens_required:
        return True

    # Eviction is only possible when a tree cache is attached; without one,
    # report the shortfall instead of raising AttributeError on None.
    if self.tree_cache is not None:
        self.tree_cache.evict(tokens_required)

    return allocator.available_size() >= tokens_required
1130
1138
1131
1139
def retract_decode (self , server_args : ServerArgs ):
1132
1140
"""Retract the decoding requests when there is not enough memory."""
0 commit comments