@@ -75,7 +75,7 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn
7575 self .output_scores = paddle .full (
7676 shape = [MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1 ), 1 ], fill_value = 0.0 , dtype = "float32"
7777 )
78- self .output_ranks = paddle .full (shape = [MAX_BSZ * MAX_DRAFT_TOKENS ], fill_value = 0 , dtype = "int64" )
78+ self .output_ranks = paddle .full (shape = [MAX_BSZ * MAX_DRAFT_TOKENS ], fill_value = 0 , dtype = "int64" )
7979 else :
8080 self .output_tokens = paddle .full (
8181 shape = [SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2 ],
@@ -85,7 +85,7 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn
8585 elif self .use_logprobs :
8686 self .output_tokens = paddle .full (shape = [MAX_BSZ * (K + 1 ) + 2 , 1 ], fill_value = 2 , dtype = "int64" )
8787 self .output_scores = paddle .full (shape = [MAX_BSZ * (K + 1 ), 1 ], fill_value = 0.0 , dtype = "float32" )
88- self .output_ranks = paddle .full (shape = [MAX_BSZ ], fill_value = 0 , dtype = "int64" )
88+ self .output_ranks = paddle .full (shape = [MAX_BSZ ], fill_value = 0 , dtype = "int64" )
8989 else :
9090 self .output_tokens = paddle .full (shape = [MAX_BSZ + 2 , 1 ], fill_value = 2 , dtype = "int64" )
9191 self .worker = None
@@ -323,29 +323,35 @@ def process_sampling_results(self):
323323 get_output_ep ,
324324 get_output_topk ,
325325 speculate_get_output ,
326+ speculate_get_output_topk ,
326327 )
327328 rank_id = self .cfg .parallel_config .local_data_parallel_id
328329
329330 while True :
330331 try :
331332 is_blocking = True
332333 if self .speculative_decoding :
333- if (
334- self .cfg .parallel_config .enable_expert_parallel
335- and self .cfg .parallel_config .data_parallel_size > 1
336- ):
337- if self .use_logprobs :
338- # TODO speculate_get_output_with_topk
339- pass
340- else :
341- speculate_get_output (self .output_tokens , rank_id , is_blocking , True )
342- elif self .use_logprobs :
343- # TODO speculate_get_output_with_topk
344- pass
334+ if self .use_logprobs :
335+ speculate_get_output_topk (
336+ self .output_tokens ,
337+ self .output_scores ,
338+ self .output_ranks ,
339+ K ,
340+ rank_id ,
341+ is_blocking ,
342+ )
343+ if self .output_tokens [0 , 0 ] == - 2 :
344+ continue
345345 else :
346- speculate_get_output (self .output_tokens , rank_id , is_blocking , False )
347- if self .output_tokens [0 ] == - 2 :
348- continue
346+ if (
347+ self .cfg .parallel_config .enable_expert_parallel
348+ and self .cfg .parallel_config .data_parallel_size > 1
349+ ):
350+ speculate_get_output (self .output_tokens , rank_id , is_blocking , True )
351+ else :
352+ speculate_get_output (self .output_tokens , rank_id , is_blocking , False )
353+ if self .output_tokens [0 ] == - 2 :
354+ continue
349355 else :
350356 if self .use_logprobs :
351357 get_output_topk (
@@ -400,18 +406,21 @@ def postprocess(self, batch_result: List[RequestOutput], mtype=3):
400406 try :
401407 if self .cfg .speculative_config .method and self .use_logprobs :
402408 if mtype == 3 : # target
403- has_finished = any (r .finished for r in batch_result )
404- if has_finished :
409+ finished_batch_result , unfinished_batch_result = [], []
410+ for r in batch_result :
411+ (finished_batch_result if r .finished else unfinished_batch_result ).append (r )
412+ if finished_batch_result :
405413 self .cached_generated_tokens .put_results (batch_result )
406414 else :
407- self ._batch_result_buffer = batch_result
415+ self ._batch_result_buffer = unfinished_batch_result
408416 elif mtype == 4 : # draft
409417 target_batch_result = []
410418 draft_batch_result = batch_result
411- for target , decode in zip (self ._batch_result_buffer , draft_batch_result ):
412- target .outputs .draft_top_logprobs = decode .outputs .draft_top_logprobs
413- target_batch_result .append (target )
414- self ._batch_result_buffer = None
419+ if self ._batch_result_buffer is not None :
420+ for target , decode in zip (self ._batch_result_buffer , draft_batch_result ):
421+ target .outputs .draft_top_logprobs = decode .outputs .draft_top_logprobs
422+ target_batch_result .append (target )
423+ self ._batch_result_buffer = None
415424 self .cached_generated_tokens .put_results (target_batch_result )
416425 else :
417426 self .cached_generated_tokens .put_results (batch_result )
@@ -671,12 +680,13 @@ def _process_batch_output(self):
671680 result .outputs .draft_top_logprobs .logprob_token_ids .extend ([topk_token_ids ])
672681 result .outputs .draft_top_logprobs .logprobs .extend ([topk_logprobs ])
673682 result .outputs .draft_top_logprobs .sampled_token_ranks .extend ([sampled_rank ])
674- if token_id in task .eos_token_ids or is_prefill or recovery_stop :
683+ if mtype == 3 and ( token_id in task .eos_token_ids or is_prefill or recovery_stop ) :
675684 result .finished = True
676685 if recovery_stop :
677686 result .error_msg = "Recover is not supported, the result is incomplete!"
678687 llm_logger .info (
679- f"Request: { task_id } finished, number of " f"generated tokens: { self .tokens_counter [task_id ]} ."
688+ f"Request: { task_id } finished, number of "
689+ f"generated tokens: { self .tokens_counter [task_id ]} , token_id:{ token_id } ,is_prefill:{ is_prefill } ,recovery_stop:{ recovery_stop } "
680690 )
681691 llm_logger .info (
682692 f"Request: { task_id } token ratio: { self .tokens_counter [task_id ] / (time .time () - task .inference_start_time )} "
0 commit comments