Skip to content

Commit 0259f4e

Browse files
xiezhq-hermann, merrymercy
authored and committed
Fix oom error for large page size (sgl-project#4913)
Co-authored-by: Lianmin Zheng <[email protected]>
1 parent acc9ae6 commit 0259f4e

File tree

2 files changed

+21
-13
lines changed

2 files changed

+21
-13
lines changed

python/sglang/srt/managers/schedule_batch.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -814,11 +814,11 @@ def alloc_paged_token_slots_decode(
814814
last_loc: torch.Tensor,
815815
backup_state: bool = False,
816816
):
817-
if (
818-
self.token_to_kv_pool_allocator.available_size()
819-
< len(seq_lens) * self.token_to_kv_pool_allocator.page_size
820-
):
821-
if self.tree_cache is not None:
817+
if self.tree_cache is not None:
818+
if (
819+
self.token_to_kv_pool_allocator.available_size()
820+
< len(seq_lens) * self.token_to_kv_pool_allocator.page_size
821+
):
822822
self.tree_cache.evict(
823823
len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
824824
)
@@ -1116,17 +1116,25 @@ def mix_with_running(self, running_batch: "ScheduleBatch"):
11161116
# TODO (lianmin): Revisit this. It should be seq_len - 1
11171117
self.extend_logprob_start_lens.extend([0] * running_bs)
11181118

1119-
def check_decode_mem(self, buf_multiplier=1):
1120-
bs = len(self.reqs) * buf_multiplier
1121-
if self.token_to_kv_pool_allocator.available_size() >= bs:
1122-
return True
1119+
def new_page_count_next_decode(self):
1120+
page_size = self.token_to_kv_pool_allocator.page_size
1121+
if page_size == 1:
1122+
return len(self.reqs)
1123+
return sum(1 for req in self.reqs if req.seqlen % page_size == 0)
11231124

1124-
self.tree_cache.evict(bs)
1125+
def check_decode_mem(self, buf_multiplier=1):
1126+
tokens_required = (
1127+
self.new_page_count_next_decode()
1128+
* buf_multiplier
1129+
* self.token_to_kv_pool_allocator.page_size
1130+
)
11251131

1126-
if self.token_to_kv_pool_allocator.available_size() >= bs:
1132+
if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
11271133
return True
11281134

1129-
return False
1135+
self.tree_cache.evict(tokens_required)
1136+
1137+
return self.token_to_kv_pool_allocator.available_size() >= tokens_required
11301138

11311139
def retract_decode(self, server_args: ServerArgs):
11321140
"""Retract the decoding requests when there is not enough memory."""

test/srt/test_eagle_infer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def _test_acc_length(self, engine):
144144
if engine.server_args.model_path == DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST:
145145
self.assertGreater(acc_length, 3.6)
146146
else:
147-
self.assertGreater(acc_length, 2.6)
147+
self.assertGreater(acc_length, 2.5)
148148

149149

150150
class TestEAGLEEngineTokenMap(TestEAGLEEngine):

0 commit comments

Comments
 (0)