@@ -108,18 +108,24 @@ class PrefillAdder:
108108 def __init__ (
109109 self ,
110110 tree_cache : BasePrefixCache ,
111+ running_batch : ScheduleBatch ,
112+ new_token_ratio : float ,
111113 rem_total_tokens : int ,
112114 rem_input_tokens : int ,
113115 rem_chunk_tokens : Optional [int ],
114116 mixed_with_decode_tokens : int = 0 ,
115117 ):
116118 self .tree_cache = tree_cache
119+ self .running_batch = running_batch
120+ self .new_token_ratio = new_token_ratio
117121 self .rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
122+ self .total_tokens = rem_total_tokens
118123 self .rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
119124 self .rem_chunk_tokens = rem_chunk_tokens
120125 if self .rem_chunk_tokens is not None :
121126 self .rem_chunk_tokens -= mixed_with_decode_tokens
122127
128+ self .req_states = None
123129 self .can_run_list = []
124130 self .new_inflight_req = None
125131 self .log_hit_tokens = 0
@@ -136,16 +142,14 @@ def no_remaining_tokens(self):
136142 )
137143 )
138144
def remove_running_tokens(self, running_batch: ScheduleBatch):
    """Deduct the decode-side token reservation from the prefill budget.

    For every request in the running batch, reserve up to
    ``max_new_tokens - len(output_ids)`` future tokens (clipped to
    ``CLIP_MAX_NEW_TOKENS``), scaled by ``self.new_token_ratio``, and
    subtract the total from ``self.rem_total_tokens``.

    Args:
        running_batch: The batch currently in the decode phase.
            NOTE(review): the same batch is also stored as
            ``self.running_batch``; the parameter is kept so existing
            callers keep working.
    """
    # Generator expression instead of sum([...]): no throwaway list.
    self.rem_total_tokens -= sum(
        min(
            r.sampling_params.max_new_tokens - len(r.output_ids),
            CLIP_MAX_NEW_TOKENS,
        )
        * self.new_token_ratio
        for r in running_batch.reqs
    )
@@ -161,7 +165,29 @@ def _prefill_one_req(
161165 self .log_hit_tokens += prefix_len
162166 self .log_input_tokens += extend_input_len
163167
def add_inflight_req_ignore_eos(self, req: Req):
    """Admit the in-flight chunked-prefill request on the ignore-EOS path.

    Caps the request's extend length at the remaining chunk budget,
    queues it on ``can_run_list``, and accounts for it via
    ``_prefill_one_req`` with a zero prefix length. New-token budget is
    reserved only on the final chunk.

    Returns:
        The request itself when another chunk is still pending
        (so the caller keeps it in flight), otherwise ``None``.
    """
    chunk_budget = self.rem_chunk_tokens
    needs_more_chunks = req.extend_input_len > chunk_budget

    if needs_more_chunks:
        # Clip this chunk to the remaining budget; the rest comes later.
        req.extend_input_len = chunk_budget
    req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
    self.can_run_list.append(req)

    if needs_more_chunks:
        reserved_new_tokens = 0
    else:
        # Last chunk: reserve the (clipped) decode budget now.
        reserved_new_tokens = min(
            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
        )
    self._prefill_one_req(0, req.extend_input_len, reserved_new_tokens)

    # Return if chunked prefill not finished
    return req if needs_more_chunks else None
186+
164187 def add_inflight_req (self , req : Req ):
188+ if req .sampling_params .ignore_eos :
189+ return self .add_inflight_req_ignore_eos (req )
190+
165191 truncated = req .extend_input_len > self .rem_chunk_tokens
166192 req .extend_input_len = min (req .extend_input_len , self .rem_chunk_tokens )
167193 req .fill_ids = req .fill_ids [: len (req .prefix_indices ) + req .extend_input_len ]
@@ -190,7 +216,81 @@ def _lock_node(self, last_node: TreeNode):
190216 delta = self .tree_cache .dec_lock_ref (last_node )
191217 self .rem_total_tokens += delta
192218
def add_one_req_ignore_eos(self, req: Req):
    """Try to admit one new prefill request on the ignore-EOS fast path.

    Maintains ``self.req_states``, a list of ``(tokens_left,
    tokens_occupied)`` pairs sorted ascending by ``tokens_left``, covering
    the running batch, the already-admitted requests, and the candidate.
    A sweep over that list checks whether the memory pool can absorb the
    worst-case decode growth; if not, the request is rejected.

    Returns:
        False if admitting ``req`` could exhaust the token pool;
        True if the request was added to ``can_run_list`` (possibly as a
        truncated chunked prefill).
    """

    def get_req_state(r):
        # Remaining decode budget for r, discounted by new_token_ratio
        # unless r itself ignores EOS (then it is expected to run to its
        # full max_new_tokens).
        new_token_ratio = (
            1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
        )
        tokens_left = r.sampling_params.max_new_tokens * new_token_ratio - len(
            r.output_ids
        )
        tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)

        if tokens_left > 0:
            return (tokens_left, tokens_occupied)

        # Requests with no remaining budget don't constrain admission.
        return None

    if self.req_states is None:
        # First call: build the state list from every request we know of.
        self.req_states = []
        if self.running_batch is not None:
            for r in self.running_batch.reqs:
                state = get_req_state(r)
                if state is not None:
                    self.req_states.append(state)
        for r in self.can_run_list:
            state = get_req_state(r)
            if state is not None:
                self.req_states.append(state)
        state = get_req_state(req)
        if state is not None:
            self.req_states.append(state)

        self.req_states.sort(key=lambda x: x[0])
    else:
        # Subsequent calls: insert the candidate while keeping the list
        # sorted ascending by tokens_left.
        state = get_req_state(req)
        if state is not None:
            for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
                if tokens_left >= state[0]:
                    self.req_states.insert(i, state)
                    break
            else:
                self.req_states.append(state)

    # Feasibility sweep: requests finish in tokens_left order; each
    # finished request frees its occupied tokens for the ones still
    # decoding.
    tokens_freed = 0
    for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
        # NOTE(review): decode_steps appears to be the decode horizon
        # until the next request finishes (next entry's tokens_left, or
        # this request's own budget for the last entry) — confirm against
        # the scheduler's admission model.
        decode_steps = (
            self.req_states[i + 1][0]
            if i + 1 < len(self.req_states)
            else tokens_left
        )
        bs = len(self.req_states) - i
        if self.total_tokens + tokens_freed - decode_steps * bs <= 0:
            # Pool would run dry before this request completes: reject.
            return False
        tokens_freed += tokens_occupied

    if req.extend_input_len <= self.rem_chunk_tokens:
        # Whole prefill fits in this chunk: admit with full decode budget.
        self.can_run_list.append(req)
        self._prefill_one_req(
            0,
            req.extend_input_len,
            min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
        )
    else:
        # Chunked prefill
        trunc_len = self.rem_chunk_tokens
        req.extend_input_len = trunc_len
        # prefix_indices is presumably empty here (tree cache disabled on
        # this path), so slicing fill_ids to trunc_len alone is safe —
        # TODO confirm.
        req.fill_ids = req.fill_ids[:trunc_len]
        self.can_run_list.append(req)
        self.new_inflight_req = req
        self._prefill_one_req(0, trunc_len, 0)

    return True
289+
193290 def add_one_req (self , req : Req ):
291+ if req .sampling_params .ignore_eos and self .tree_cache .disable :
292+ return self .add_one_req_ignore_eos (req )
293+
194294 total_tokens = req .extend_input_len + min (
195295 req .sampling_params .max_new_tokens , CLIP_MAX_NEW_TOKENS
196296 )
@@ -233,4 +333,4 @@ def add_one_req(self, req: Req):
233333 self .tree_cache .inc_lock_ref (req .last_node )
234334 self ._prefill_one_req (prefix_len , trunc_len , 0 )
235335
236- return True
336+ return True and not self . no_remaining_tokens ()
0 commit comments