
Commit 2c56b6d

fix: postprocess for speculative decode

1 parent a65b22d

File tree: 1 file changed (+36, -26 lines)

fastdeploy/output/token_processor.py

Lines changed: 36 additions & 26 deletions
```diff
@@ -75,7 +75,7 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn
                 self.output_scores = paddle.full(
                     shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1), 1], fill_value=0.0, dtype="float32"
                 )
-                self.output_ranks = paddle.full(shape=[MAX_BSZ * MAX_DRAFT_TOKENS], fill_value=0, dtype="int64")
+                self.output_ranks = paddle.full(shape=[MAX_BSZ * MAX_DRAFT_TOKENS], fill_value=0, dtype="int64")
             else:
                 self.output_tokens = paddle.full(
                     shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2],
@@ -85,7 +85,7 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn
         elif self.use_logprobs:
             self.output_tokens = paddle.full(shape=[MAX_BSZ * (K + 1) + 2, 1], fill_value=2, dtype="int64")
             self.output_scores = paddle.full(shape=[MAX_BSZ * (K + 1), 1], fill_value=0.0, dtype="float32")
-            self.output_ranks = paddle.full(shape=[MAX_BSZ], fill_value=0, dtype="int64")
+            self.output_ranks = paddle.full(shape=[MAX_BSZ], fill_value=0, dtype="int64")
         else:
             self.output_tokens = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64")
         self.worker = None
```
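For orientation, the buffers touched here are flat paddle tensors pre-sized for the worst case. Below is a minimal sketch of the sizing arithmetic with illustrative stand-in values; the real MAX_BSZ, K, MAX_DRAFT_TOKENS, and SPECULATE_MAX_BSZ constants are defined elsewhere in FastDeploy and are not part of this diff, and the reading of the "+ 2" slots is an assumption.

```python
# Illustrative values only; the real constants live outside this diff.
MAX_BSZ = 256            # assumed max batch size
K = 20                   # assumed max top-k logprob entries per token
MAX_DRAFT_TOKENS = 6     # assumed max speculative tokens per step
SPECULATE_MAX_BSZ = 256  # assumed max batch size in speculative mode

# Speculative + logprobs: each request may emit up to MAX_DRAFT_TOKENS
# tokens per step, and each token carries the sampled entry plus K top-k
# entries, hence the (K + 1) factor in output_scores.
spec_scores_rows = MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1)
# One sampled-token rank per (request, draft position).
spec_ranks_rows = MAX_BSZ * MAX_DRAFT_TOKENS

# Plain logprobs mode: one token per request per step.
scores_rows = MAX_BSZ * (K + 1)
ranks_rows = MAX_BSZ

# Speculative, no logprobs: one slot per draft position, plus per-request
# slots and what appears to be a small header (the "+ 2"; an assumption).
spec_tokens_rows = SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2
```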
```diff
@@ -323,29 +323,35 @@ def process_sampling_results(self):
                 get_output_ep,
                 get_output_topk,
                 speculate_get_output,
+                speculate_get_output_topk,
             )
         rank_id = self.cfg.parallel_config.local_data_parallel_id

         while True:
             try:
                 is_blocking = True
                 if self.speculative_decoding:
-                    if (
-                        self.cfg.parallel_config.enable_expert_parallel
-                        and self.cfg.parallel_config.data_parallel_size > 1
-                    ):
-                        if self.use_logprobs:
-                            # TODO speculate_get_output_with_topk
-                            pass
-                        else:
-                            speculate_get_output(self.output_tokens, rank_id, is_blocking, True)
-                    elif self.use_logprobs:
-                        # TODO speculate_get_output_with_topk
-                        pass
+                    if self.use_logprobs:
+                        speculate_get_output_topk(
+                            self.output_tokens,
+                            self.output_scores,
+                            self.output_ranks,
+                            K,
+                            rank_id,
+                            is_blocking,
+                        )
+                        if self.output_tokens[0, 0] == -2:
+                            continue
                     else:
-                        speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
-                        if self.output_tokens[0] == -2:
-                            continue
+                        if (
+                            self.cfg.parallel_config.enable_expert_parallel
+                            and self.cfg.parallel_config.data_parallel_size > 1
+                        ):
+                            speculate_get_output(self.output_tokens, rank_id, is_blocking, True)
+                        else:
+                            speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
+                        if self.output_tokens[0] == -2:
+                            continue
                 else:
                     if self.use_logprobs:
                         get_output_topk(
```
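This hunk inverts the branch order: the logprobs check now comes first, so the expert-parallel case no longer dead-ends in a TODO, and expert parallelism only selects the final flag of `speculate_get_output`. A condensed, self-contained sketch of the resulting control flow; plain functions and numpy arrays stand in for the method, its config, and the real GPU ops:

```python
import numpy as np

K = 20  # assumed top-k width; illustrative, not FastDeploy's real constant

# Stubs for the GPU ops imported in the hunk above; they only write the
# -2 "nothing ready yet" sentinel so the sketch runs without paddle.
def speculate_get_output_topk(tokens, scores, ranks, k, rank_id, blocking):
    tokens[0, 0] = -2

def speculate_get_output(tokens, rank_id, blocking, use_expert_parallel):
    tokens[0] = -2

def poll_once(speculative_decoding, use_logprobs, enable_ep, dp_size,
              output_tokens, output_scores=None, output_ranks=None, rank_id=0):
    """Condensed control flow of the loop body after this change."""
    if speculative_decoding:
        if use_logprobs:
            # New path: one fused op fills tokens, scores, and ranks;
            # the old code dead-ended in a TODO here.
            speculate_get_output_topk(output_tokens, output_scores,
                                      output_ranks, K, rank_id, True)
            return output_tokens[0, 0] != -2  # 2-D buffer in logprobs mode
        # Non-logprobs path: expert parallelism now only picks the flag.
        use_ep = enable_ep and dp_size > 1
        speculate_get_output(output_tokens, rank_id, True, use_ep)
        return output_tokens[0] != -2  # 1-D buffer otherwise
    return True  # non-speculative branches unchanged, elided here

tokens_2d = np.zeros((4, 1), dtype=np.int64)
assert not poll_once(True, True, False, 1, tokens_2d)  # sentinel: skip batch
```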
```diff
@@ -400,18 +406,21 @@ def postprocess(self, batch_result: List[RequestOutput], mtype=3):
         try:
             if self.cfg.speculative_config.method and self.use_logprobs:
                 if mtype == 3:  # target
-                    has_finished = any(r.finished for r in batch_result)
-                    if has_finished:
+                    finished_batch_result, unfinished_batch_result = [], []
+                    for r in batch_result:
+                        (finished_batch_result if r.finished else unfinished_batch_result).append(r)
+                    if finished_batch_result:
                         self.cached_generated_tokens.put_results(batch_result)
                     else:
-                        self._batch_result_buffer = batch_result
+                        self._batch_result_buffer = unfinished_batch_result
                 elif mtype == 4:  # draft
                     target_batch_result = []
                     draft_batch_result = batch_result
-                    for target, decode in zip(self._batch_result_buffer, draft_batch_result):
-                        target.outputs.draft_top_logprobs = decode.outputs.draft_top_logprobs
-                        target_batch_result.append(target)
-                    self._batch_result_buffer = None
+                    if self._batch_result_buffer is not None:
+                        for target, decode in zip(self._batch_result_buffer, draft_batch_result):
+                            target.outputs.draft_top_logprobs = decode.outputs.draft_top_logprobs
+                            target_batch_result.append(target)
+                        self._batch_result_buffer = None
                     self.cached_generated_tokens.put_results(target_batch_result)
                 else:
                     self.cached_generated_tokens.put_results(batch_result)
```
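The change above pairs each buffered target batch (mtype 3) with the following draft batch (mtype 4), only buffers the unfinished targets, and guards the merge so a missing buffer no longer raises on `zip(None, ...)`. A self-contained sketch of the handshake; the simplified result classes below are stand-ins for RequestOutput, not the real type:

```python
class FakeOutputs:
    """Stand-in for RequestOutput.outputs; only the field this hunk touches."""
    def __init__(self):
        self.draft_top_logprobs = None

class FakeResult:
    """Stand-in for RequestOutput."""
    def __init__(self, finished=False):
        self.finished = finished
        self.outputs = FakeOutputs()

emitted = []   # plays the role of cached_generated_tokens
buffer = None  # plays the role of self._batch_result_buffer

def postprocess(batch_result, mtype):
    global buffer
    if mtype == 3:  # target step
        finished = [r for r in batch_result if r.finished]
        unfinished = [r for r in batch_result if not r.finished]
        if finished:
            emitted.append(batch_result)  # any finished request flushes all
        else:
            buffer = unfinished           # hold targets for the draft step
    elif mtype == 4:  # draft step
        target_batch = []
        if buffer is not None:            # the new guard: buffer may be gone
            for target, draft in zip(buffer, batch_result):
                target.outputs.draft_top_logprobs = draft.outputs.draft_top_logprobs
                target_batch.append(target)
            buffer = None
        emitted.append(target_batch)

# A target batch with no finished request is buffered; the draft batch then
# donates its draft_top_logprobs and the merged results are emitted.
postprocess([FakeResult(), FakeResult()], mtype=3)
postprocess([FakeResult(), FakeResult()], mtype=4)
assert buffer is None and len(emitted) == 1
```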
```diff
@@ -671,12 +680,13 @@ def _process_batch_output(self):
                     result.outputs.draft_top_logprobs.logprob_token_ids.extend([topk_token_ids])
                     result.outputs.draft_top_logprobs.logprobs.extend([topk_logprobs])
                     result.outputs.draft_top_logprobs.sampled_token_ranks.extend([sampled_rank])
-                if token_id in task.eos_token_ids or is_prefill or recovery_stop:
+                if mtype == 3 and (token_id in task.eos_token_ids or is_prefill or recovery_stop):
                     result.finished = True
                     if recovery_stop:
                         result.error_msg = "Recover is not supported, the result is incomplete!"
                     llm_logger.info(
-                        f"Request: {task_id} finished, number of " f"generated tokens: {self.tokens_counter[task_id]}."
+                        f"Request: {task_id} finished, number of "
+                        f"generated tokens: {self.tokens_counter[task_id]}, token_id:{token_id},is_prefill:{is_prefill},recovery_stop:{recovery_stop}"
                     )
                     llm_logger.info(
                         f"Request: {task_id} token ratio: {self.tokens_counter[task_id] / (time.time() - task.inference_start_time)}"
```
