Commit ab4a83b

Optimize schedule (#1339)
1 parent 62f15ee commit ab4a83b

File tree: 2 files changed, +123 −8 lines

python/sglang/srt/managers/policy_scheduler.py

Lines changed: 105 additions & 5 deletions
```diff
@@ -108,18 +108,24 @@ class PrefillAdder:
     def __init__(
         self,
         tree_cache: BasePrefixCache,
+        running_batch: ScheduleBatch,
+        new_token_ratio: float,
         rem_total_tokens: int,
         rem_input_tokens: int,
         rem_chunk_tokens: Optional[int],
         mixed_with_decode_tokens: int = 0,
     ):
         self.tree_cache = tree_cache
+        self.running_batch = running_batch
+        self.new_token_ratio = new_token_ratio
         self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
+        self.total_tokens = rem_total_tokens
         self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
         self.rem_chunk_tokens = rem_chunk_tokens
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= mixed_with_decode_tokens
 
+        self.req_states = None
         self.can_run_list = []
         self.new_inflight_req = None
         self.log_hit_tokens = 0
@@ -136,16 +142,14 @@ def no_remaining_tokens(self):
             )
         )
 
-    def remove_running_tokens(
-        self, running_batch: ScheduleBatch, new_token_ratio: float
-    ):
+    def remove_running_tokens(self, running_batch: ScheduleBatch):
         self.rem_total_tokens -= sum(
             [
                 min(
                     (r.sampling_params.max_new_tokens - len(r.output_ids)),
                     CLIP_MAX_NEW_TOKENS,
                 )
-                * new_token_ratio
+                * self.new_token_ratio
                 for r in running_batch.reqs
             ]
         )
```
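The reservation this method computes is easy to read in isolation: every running request holds back its remaining decode budget, clipped to `CLIP_MAX_NEW_TOKENS` and scaled by the scheduler's current `new_token_ratio`. A minimal sketch with stand-in data (the clip value here is an assumption for illustration, not taken from the source):

```python
# Stand-in sketch of the reservation in remove_running_tokens.
CLIP_MAX_NEW_TOKENS = 256  # assumed value, for illustration only

def reserved_decode_tokens(reqs, new_token_ratio: float) -> float:
    # Each running request reserves its clipped remaining decode budget,
    # discounted by new_token_ratio (the scheduler's estimate of how much
    # of that budget will actually be used).
    return sum(
        min(r["max_new_tokens"] - len(r["output_ids"]), CLIP_MAX_NEW_TOKENS)
        * new_token_ratio
        for r in reqs
    )

reqs = [{"max_new_tokens": 512, "output_ids": [0] * 100}]
print(reserved_decode_tokens(reqs, new_token_ratio=0.5))  # 128.0
```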
```diff
@@ -161,7 +165,29 @@ def _prefill_one_req(
         self.log_hit_tokens += prefix_len
         self.log_input_tokens += extend_input_len
 
+    def add_inflight_req_ignore_eos(self, req: Req):
+        truncated = req.extend_input_len > self.rem_chunk_tokens
+        req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
+        req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
+        self.can_run_list.append(req)
+
+        self._prefill_one_req(
+            0,
+            req.extend_input_len,
+            (
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
+                if not truncated
+                else 0
+            ),
+        )
+
+        # Return if chunked prefill not finished
+        return req if truncated else None
+
     def add_inflight_req(self, req: Req):
+        if req.sampling_params.ignore_eos:
+            return self.add_inflight_req_ignore_eos(req)
+
         truncated = req.extend_input_len > self.rem_chunk_tokens
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
```
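The truncation contract here is worth spelling out: the request's pending input is clipped to the remaining chunk budget, `fill_ids` is cut to the cached prefix plus the clipped extension, and the request object itself is returned only when prefill work remains. A minimal sketch with a hypothetical stripped-down `Req` stand-in (the real class lives elsewhere in sglang):

```python
from dataclasses import dataclass
from typing import List, Optional

# Hypothetical stand-in for sglang's Req; field names mirror the diff above.
@dataclass
class FakeReq:
    fill_ids: List[int]        # cached prefix tokens + new input tokens
    prefix_indices: List[int]  # tokens already present in the cache
    extend_input_len: int      # new tokens this prefill must process

def clip_to_chunk(req: FakeReq, rem_chunk_tokens: int) -> Optional[FakeReq]:
    """Mirror of the truncation logic: clip the request's prefill work to the
    remaining chunk budget; return the request if it must continue next round."""
    truncated = req.extend_input_len > rem_chunk_tokens
    req.extend_input_len = min(req.extend_input_len, rem_chunk_tokens)
    req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
    # A truncated request stays "inflight" and is handed back to the scheduler.
    return req if truncated else None

req = FakeReq(fill_ids=list(range(100)), prefix_indices=list(range(20)),
              extend_input_len=80)
leftover = clip_to_chunk(req, rem_chunk_tokens=50)
assert req.extend_input_len == 50 and leftover is req  # 30 tokens still pending
```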
```diff
@@ -190,7 +216,81 @@ def _lock_node(self, last_node: TreeNode):
         delta = self.tree_cache.dec_lock_ref(last_node)
         self.rem_total_tokens += delta
 
+    def add_one_req_ignore_eos(self, req: Req):
+        def get_req_state(r):
+            new_token_ratio = (
+                1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
+            )
+            tokens_left = r.sampling_params.max_new_tokens * new_token_ratio - len(
+                r.output_ids
+            )
+            tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)
+
+            if tokens_left > 0:
+                return (tokens_left, tokens_occupied)
+
+            return None
+
+        if self.req_states is None:
+            self.req_states = []
+            if self.running_batch is not None:
+                for r in self.running_batch.reqs:
+                    state = get_req_state(r)
+                    if state is not None:
+                        self.req_states.append(state)
+            for r in self.can_run_list:
+                state = get_req_state(r)
+                if state is not None:
+                    self.req_states.append(state)
+            state = get_req_state(req)
+            if state is not None:
+                self.req_states.append(state)
+
+            self.req_states.sort(key=lambda x: x[0])
+        else:
+            state = get_req_state(req)
+            if state is not None:
+                for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
+                    if tokens_left >= state[0]:
+                        self.req_states.insert(i, state)
+                        break
+                else:
+                    self.req_states.append(state)
+
+        tokens_freed = 0
+        for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
+            decode_steps = (
+                self.req_states[i + 1][0]
+                if i + 1 < len(self.req_states)
+                else tokens_left
+            )
+            bs = len(self.req_states) - i
+            if self.total_tokens + tokens_freed - decode_steps * bs <= 0:
+                return False
+            tokens_freed += tokens_occupied
+
+        if req.extend_input_len <= self.rem_chunk_tokens:
+            self.can_run_list.append(req)
+            self._prefill_one_req(
+                0,
+                req.extend_input_len,
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
+            )
+        else:
+            # Chunked prefill
+            trunc_len = self.rem_chunk_tokens
+            req.extend_input_len = trunc_len
+            req.fill_ids = req.fill_ids[:trunc_len]
+            self.can_run_list.append(req)
+            self.new_inflight_req = req
+            self._prefill_one_req(0, trunc_len, 0)
+
+        return True
+
     def add_one_req(self, req: Req):
+        if req.sampling_params.ignore_eos and self.tree_cache.disable:
+            return self.add_one_req_ignore_eos(req)
+
         total_tokens = req.extend_input_len + min(
             req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
         )
```
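The heart of `add_one_req_ignore_eos` is a worst-case memory simulation. With the radix cache disabled and `ignore_eos` set, each request is assumed to decode its full (ratio-scaled) `max_new_tokens`, so the finish order is known once `req_states` is sorted by tokens left. The sweep then checks, segment by segment, that the free pool plus memory already freed by finished requests covers the decoding still owed by every live request. A self-contained sketch of that sweep (the function name and example numbers are illustrative, not from the source):

```python
from typing import List, Tuple

def memory_feasible(total_tokens: int,
                    req_states: List[Tuple[int, int]]) -> bool:
    # req_states: (tokens_left, tokens_occupied) per request, sorted
    # ascending by tokens_left, mirroring self.req_states above.
    tokens_freed = 0
    for i, (tokens_left, tokens_occupied) in enumerate(req_states):
        # Conservative decode horizon for this segment: the next request's
        # tokens_left (or this request's own, for the last segment).
        decode_steps = (
            req_states[i + 1][0] if i + 1 < len(req_states) else tokens_left
        )
        bs = len(req_states) - i  # requests still running in this segment
        if total_tokens + tokens_freed - decode_steps * bs <= 0:
            return False  # KV pool would be exhausted before request i finishes
        tokens_freed += tokens_occupied  # request i finishes, freeing its slots
    return True

# Two running requests: one finishes after 10 more tokens, one after 100.
print(memory_feasible(total_tokens=250, req_states=[(10, 50), (100, 80)]))  # True
print(memory_feasible(total_tokens=150, req_states=[(10, 50), (100, 80)]))  # False
```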
```diff
@@ -233,4 +333,4 @@ def add_one_req(self, req: Req):
         self.tree_cache.inc_lock_ref(req.last_node)
         self._prefill_one_req(prefix_len, trunc_len, 0)
 
-        return True
+        return True and not self.no_remaining_tokens()
```

python/sglang/srt/managers/tp_worker.py

Lines changed: 18 additions & 3 deletions
```diff
@@ -221,6 +221,7 @@ def __init__(
         )
         self.new_token_ratio = self.min_new_token_ratio
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
+        self.do_not_get_new_batch = False
 
     def exposed_step(self, recv_reqs: List):
         try:
@@ -253,7 +254,13 @@ def exposed_step(self, recv_reqs: List):
 
     @torch.inference_mode()
     def forward_step(self):
-        new_batch = self.get_new_prefill_batch()
+        if self.current_inflight_req is not None:
+            self.do_not_get_new_batch = False
+
+        new_batch = (
+            self.get_new_prefill_batch() if not self.do_not_get_new_batch else None
+        )
+        self.do_not_get_new_batch = False
 
         if new_batch is not None:
             # Run a new prefill batch
@@ -409,14 +416,16 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
 
         adder = PrefillAdder(
             self.tree_cache,
+            self.running_batch,
+            self.new_token_ratio,
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
             self.max_prefill_tokens,
             self.chunked_prefill_size,
             num_mixed_running,
         )
 
         if self.running_batch is not None:
-            adder.remove_running_tokens(self.running_batch, self.new_token_ratio)
+            adder.remove_running_tokens(self.running_batch)
 
         has_inflight = self.current_inflight_req is not None
         if self.current_inflight_req is not None:
@@ -428,11 +437,12 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
         )
 
         for req in self.waiting_queue:
+            if adder.no_remaining_tokens():
+                break
             req.init_next_round_input(None if prefix_computed else self.tree_cache)
             res = adder.add_one_req(req)
             if (
                 not res
-                or adder.no_remaining_tokens()
                 or running_bs + len(adder.can_run_list) >= self.max_running_requests
             ):
                 break
```
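Together with `add_one_req` now returning `True and not self.no_remaining_tokens()` (see the policy_scheduler.py hunk above), the admission loop gains a clear contract: check the token budget before each attempt, and stop as soon as an admission reports the pool exhausted. A stripped-down sketch with stand-in classes:

```python
# Stripped-down sketch of the new admission contract (stand-in names).
class FakeAdder:
    def __init__(self, budget: int):
        self.budget = budget
        self.can_run_list = []

    def no_remaining_tokens(self) -> bool:
        return self.budget <= 0

    def add_one_req(self, req: int) -> bool:
        # Admit, then report False if the pool is now exhausted, mirroring
        # `return True and not self.no_remaining_tokens()` above.
        self.can_run_list.append(req)
        self.budget -= req
        return not self.no_remaining_tokens()

adder = FakeAdder(budget=100)
for req in [40, 40, 40, 40]:          # token costs of waiting requests
    if adder.no_remaining_tokens():   # now checked before each attempt
        break
    if not adder.add_one_req(req):    # third request is admitted, then we stop
        break
print(adder.can_run_list)             # [40, 40, 40]
```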
```diff
@@ -700,6 +710,7 @@ def forward_decode_batch(self, batch: ScheduleBatch):
         next_token_ids = next_token_ids.tolist()
 
         # Check finish condition
+        has_finished = False
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             req.completion_tokens_wo_jump_forward += 1
             req.output_ids.append(next_token_id)
@@ -712,6 +723,7 @@ def forward_decode_batch(self, batch: ScheduleBatch):
 
             if req.finished():
                 self.tree_cache.cache_finished_req(req)
+                has_finished = True
 
             if req.return_logprob:
                 req.output_token_logprobs.append(
@@ -720,6 +732,9 @@ def forward_decode_batch(self, batch: ScheduleBatch):
             if req.top_logprobs_num > 0:
                 req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
 
+        if not has_finished:
+            self.do_not_get_new_batch = True
+
         self.handle_finished_requests(batch)
 
     def handle_finished_requests(self, batch: ScheduleBatch):
```
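Taken together with the `forward_step` change above, the new flag implements a one-step heuristic: if a decode step finishes no request, no KV memory was freed, so retrying prefill admission on the very next step would be wasted work. A minimal sketch of the flag's lifecycle (a stand-in class, not the real worker):

```python
# Minimal sketch of the do_not_get_new_batch lifecycle (stand-in class).
class WorkerSketch:
    def __init__(self):
        self.do_not_get_new_batch = False
        self.current_inflight_req = None

    def forward_decode_batch(self, finished_flags):
        # Decode one step; if nothing finished, no KV memory was freed,
        # so skip the next prefill-admission attempt.
        if not any(finished_flags):
            self.do_not_get_new_batch = True

    def forward_step(self):
        # An unfinished chunked-prefill request must keep being scheduled.
        if self.current_inflight_req is not None:
            self.do_not_get_new_batch = False
        new_batch = None if self.do_not_get_new_batch else "try-prefill"
        self.do_not_get_new_batch = False  # the skip lasts exactly one step
        return new_batch

w = WorkerSketch()
w.forward_decode_batch([False, False])  # decode step, nothing finished
print(w.forward_step())  # None: prefill admission skipped once
print(w.forward_step())  # "try-prefill": back to normal
```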
