
[Excutor] Experiment Feature-Support Prefill in cudagraph #3459

Merged
Jiang-Jia-Jun merged 27 commits into PaddlePaddle:develop from littledgg:prefill_in_cudagraph
Sep 8, 2025

Conversation

@littledgg (Contributor) commented Aug 18, 2025

Prefill-only batches can now enter cudagraph. Until it is confirmed that the graphs can be shared, you must choose to capture either decode-only or prefill-only batches.

1. How to enable

To enable this feature, launch with the following arguments; the key point is that both use_cudagraph and cudagraph_only_prefill are set to True.

python -m fastdeploy.entrypoints.openai.api_server --model ${model_path} \
    --max-num-seqs 64 --max-model-len 32768 \
    --port 8988 --engine-worker-queue-port 7732 \
    --metrics-port 7733 --tensor-parallel-size 1 \
    --graph-optimization-config '{"use_cudagraph":true,"cudagraph_only_prefill":true}'

2. How to get multiple prefills into the graph

Under the current dynamic-insertion scheme, suppose four 80-token prompts are sent. Then seq_lens_this_time is [80] in the first round and [1, 80, 80, 80] in the second. Clearly only the first round is pure prefill and can enter cudagraph; the second round is mixed and cannot. This can be changed by modifying the _insert_task_to_worker function in fastdeploy/engine/engine.py, replacing

tasks = self.scheduler.get_requests(
    available_blocks=self.resource_manager.available_block_num(),
    block_size=self.cfg.cache_config.block_size,
    reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
    max_num_batched_tokens=self.cfg.max_num_batched_tokens,
    batch=num_prefill_batch,
)

with

tasks = list()
start_time = time.time()
batch_num = 4
while len(tasks) < batch_num:
    # Check for timeout: if some tasks have been collected and more than
    # 0.75 s has passed since the last fetch, exit the collection loop.
    if len(tasks) > 0 and (time.time() - start_time) > 0.75:
        print("===RyanDebug, Timeout (0.75s) reached. Exiting task collection loop. ===")
        break
    print("===RyanDebug, Begin to collect tasks ===")
    print("====The self.resource_manager.available_block_num is:", self.resource_manager.available_block_num())
    print("====The self.cfg.cache_config.block_size is:", self.cfg.cache_config.block_size)
    print("====The self.cfg.cache_config.enc_dec_block_num is:", self.cfg.cache_config.enc_dec_block_num)
    print("====The self.cfg.max_num_batched_tokens is:", self.cfg.max_num_batched_tokens)
    print("===RyanDebug, num_prefill_batch is: ", 8)

    tmp_task = self.scheduler.get_requests(
        available_blocks=5000,
        block_size=self.cfg.cache_config.block_size,
        reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
        max_num_batched_tokens=self.cfg.max_num_batched_tokens,
        batch=batch_num)
    print("===RyanDebug, the tmp_task is :", tmp_task)

    if isinstance(tmp_task, list):
        if tmp_task:  # check that the list is non-empty
            tasks.extend(tmp_task)
            start_time = time.time()  # reset the timer after each successful fetch
            print("####### Reset Time ##########")
    elif tmp_task is not None:
        tasks.append(tmp_task)
        start_time = time.time()  # reset the timer after each successful fetch
        print("####### Reset Time ##########")

print(f"===RyanDebug, Finish Fix task, the len of tasks is {len(tasks)} ===")
print(f"===RyanDebug, Finish Fix task, the tasks is {tasks} ===")

This disables the dynamic-insertion logic: the scheduler waits for batch_num prompts to arrive (4 in the snippet above; the number can be changed), and those prompts then enter prefill together (pure-prefill acceleration across multiple prompts) and enter decode together. When tasks is non-empty and nothing new arrives for more than 0.75 s, the batch is scheduled even if it is not full, which is convenient for debugging. The timeout can be shortened; otherwise starting the service takes a long time.

Known BUG 1: when the scheduling policy is changed to non-dynamic insertion as above and chunked prefill is also enabled, the variable self.cfg.max_num_partial_prefills in fastdeploy/engine/engine.py defaults to 1, so the list self.partial_chunked_tokens has length 2. The access to self.partial_chunked_tokens in update_requests_chunk_size then easily goes out of bounds (setting the batch number above to anything greater than 1 guarantees an out-of-bounds access; chunk_request_num is effectively that number, i.e., the number of requests in one scheduling round).

while chunk_request_num >= 1:
    chunk_size = min(
        current_request_size[idx],
        self.partial_chunked_tokens[chunk_request_num],  # out-of-bounds access
    )

Suggested fix: dynamic insertion should ideally fill the first round completely. In the long run, biasing the scheduler toward pure prefill would further benefit cudagraph acceleration of pure prefill, but it hurts decode speed. The chunked-prefill scenario (large prompts, over 8K) and cudagraph-accelerated prefill (small prompts, under 1K) do not overlap, so there is no need to make them compatible; for robustness, though, a bounds check should be added before the access.
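A minimal sketch of that bounds check follows. The helper name `safe_chunk_limit` is hypothetical (not a FastDeploy function); `partial_chunked_tokens` and `chunk_request_num` mirror the names in fastdeploy/engine/engine.py, and clamping the index is one possible policy, not the PR's actual fix:

```python
# Hypothetical sketch of the suggested bounds check; `safe_chunk_limit`
# is an illustrative helper, not part of FastDeploy.
def safe_chunk_limit(partial_chunked_tokens, chunk_request_num):
    # Clamp the index so a chunk_request_num beyond the configured
    # max_num_partial_prefills cannot index past the end of the list.
    idx = min(chunk_request_num, len(partial_chunked_tokens) - 1)
    return partial_chunked_tokens[idx]
```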
Known BUG 2: in the split_q_block function of the get_block_shape_and_split_kv_block operator.

__global__ void split_q_block(const int *__restrict__ seq_lens_q,
                              const int *__restrict__ seq_lens_encoder,
                              int *__restrict__ batch_ids,
                              int *__restrict__ tile_ids_per_batch,
                              int *__restrict__ num_blocks_x, const int bsz,
                              const int num_rows_per_block,
                              const int group_size) {
  if (threadIdx.x == 0) {
    int gridx = 0;
    int index = 0;
    for (uint32_t bid = 0; bid < bsz; bid++) {
      int seq_len = seq_lens_q[bid];
      if (seq_lens_encoder && seq_lens_encoder[bid] > 0) {
        seq_len = 0;
      }
      // Each sequence contributes div_up(seq_len * group_size, num_rows_per_block) tiles.
      const int loop_times = div_up(seq_len * group_size, num_rows_per_block);
      for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) {
        batch_ids[index] = bid;
        tile_ids_per_batch[index++] = tile_id;
      }
      gridx += loop_times;
    }
    // The total tile count becomes the attention kernel's griddim.x.
    *num_blocks_x = gridx;
  }
}

When seq_lens_q is [160, 0], the computed num_blocks_x is div_up(160 * 5, 64) + div_up(0 * 5, 64) = 13 + 0 = 13.
When seq_lens_q is [80, 80], the computed num_blocks_x is div_up(80 * 5, 64) + div_up(80 * 5, 64) = 7 + 7 = 14.
During dummy_run capture, the first case applies: for a token_num of 160, the captured cuda graph has num_blocks_x = 13.
When two 80-token prompts are sent, the second case applies: the same token_num of 160 should correspond to num_blocks_x = 14, but inference runs on the cuda graph captured with num_blocks_x = 13, which garbles the output of the last request. It has been verified that the 14-block graph can serve the 13-block case.
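The two cases can be reproduced with a small Python helper that mirrors the per-batch loop in split_q_block (group_size = 5 and num_rows_per_block = 64 are taken from the example above):

```python
def div_up(a, b):
    # Ceiling division, as in the CUDA kernel's div_up.
    return (a + b - 1) // b

def num_blocks_x(seq_lens_q, group_size=5, num_rows_per_block=64):
    # Each sequence contributes div_up(seq_len * group_size, num_rows_per_block)
    # tiles; the sum becomes the kernel's griddim.x.
    return sum(div_up(s * group_size, num_rows_per_block) for s in seq_lens_q)

print(num_blocks_x([160, 0]))   # 13
print(num_blocks_x([80, 80]))   # 14
```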
The essence of the problem: when the per-request token counts in a batch are not all 1, that batch and another batch with the same total token_num may have different encoder_num_blocks_x_cpu values and therefore correspond to different graphs; concretely, this is the griddim.x of multi_query_append_attention_kernel.
Other variables may have the same issue, and the MTP scenario may be affected as well. The fix is still to capture with the maximum value, plus an early exit. Abstractly: given N (max_num_seq) positive integers summing to M, the maximum over all splits of the sum of div_up(Ni * a, b) is the num_blocks_x that should be captured; M, N, a, and b are all fixed. The solution is simple: give every sequence exactly one token and let the last sequence take all remaining tokens, and construct the pure-prefill dummy_run that way.
This is now handled in _dummy_prefill_inputs, but the kernel then launches extra blocks; it seems this can only be avoided by passing a parameter in to allow early exit, although it currently causes no problems.
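That worst-case construction can be sketched as follows (the helper names are illustrative, not the actual _dummy_prefill_inputs code; the length logic matches the snippet quoted in the review thread below, and group_size = 5, num_rows_per_block = 64 are the values from the example):

```python
def div_up(a, b):
    return (a + b - 1) // b

def worst_case_lengths(num_tokens, batch_size):
    # As many 1-token sequences as possible; the last sequence takes the rest.
    if num_tokens < batch_size:
        return [1] * num_tokens
    return [1] * (batch_size - 1) + [num_tokens - batch_size + 1]

def num_blocks_x(lengths, group_size=5, num_rows_per_block=64):
    return sum(div_up(n * group_size, num_rows_per_block) for n in lengths)

# For 160 tokens over 2 sequences, the worst-case split [1, 159] yields
# 1 + 13 = 14 tiles, covering the [80, 80] case (14 tiles) that a
# [160, 0] capture (13 tiles) would miss.
print(worst_case_lengths(160, 2))   # [1, 159]
print(num_blocks_x([1, 159]))       # 14
```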

3. Changing the graph size

In init_with_cudagrpah_size in fastdeploy/config.py, 512 is the maximum capture size when capturing prefill; it can be changed manually.

if self.graph_opt_config.cudagraph_only_prefill:
    self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=512)

@paddle-bot

paddle-bot bot commented Aug 18, 2025

Thanks for your contribution!

@paddle-bot paddle-bot bot added the contributor External developers label Aug 18, 2025
@littledgg littledgg changed the title [Excutor] Experiment-Support Prefill in cudagraph [Excutor] Experiment Features-Support Prefill in cudagraph Aug 21, 2025
@littledgg littledgg changed the title [Excutor] Experiment Features-Support Prefill in cudagraph [Excutor] Experiment Feature-Support Prefill in cudagraph Aug 21, 2025
@codecov-commenter

codecov-commenter commented Aug 25, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
⚠️ Please upload report for BASE (develop@82e64b1). Learn more about missing BASE report.

Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #3459   +/-   ##
==========================================
  Coverage           ?   49.16%           
==========================================
  Files              ?        9           
  Lines              ?      120           
  Branches           ?        8           
==========================================
  Hits               ?       59           
  Misses             ?       57           
  Partials           ?        4           
Flag Coverage Δ
diff 49.16% <ø> (?)

Flags with carried forward coverage won't be shown.


for num_tokens in sorted(capture_sizes, reverse=True):
    self._dummy_run(
        num_tokens=num_tokens,
        batch_size=self.parallel_config.max_num_seqs,
Collaborator:
The batch_size argument can be removed now: prefill, decode, and spec decode can all uniformly map token_num to capture_size.

Contributor Author:
Considering that SOT also needs to use dummy_run and has requirements on the batch count, removing the batch field is probably not feasible.

self.proposer.insert_prefill_inputs(req_dicts, num_running_requests)

def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
def _dummy_prefill_inputs(
Collaborator:
num_tokens means different things in the prefill and decode stages and needs to be clarified.
Here it can be defined as the number of prefill tokens.

Comment on lines 56 to 63
input_length_list = []
if num_tokens < batch_size:
    input_length_list = [1] * num_tokens
else:
    input_length_list = [1] * (batch_size - 1)
    input_length_list.append(num_tokens - batch_size + 1)
return input_length_list

Collaborator:
For decode, input_length_list should be self.parallel_config.max_num_batched_tokens - input_length_list.

Collaborator:
Please run the deepseek model to confirm whether MLA attention works with prefill capture.

Contributor Author:
Let's support that in the next PR.

Comment on lines +83 to +93
# A common pattern for launching CUDA kernels is to set the kernel's grids.x dimension
# using a `num_blocks` variable, and then map each thread block to a specific batch and
# data tile using `batch_ids` and `tile_ids_per_batch`.
#
# The variable names below follow this pattern, using a common prefix (e.g., `encoder_`, `decoder_`, `kv_`)
# for variables that are logically grouped together. The mapping works as follows:
#
# Usage: `my_kernel<<<grids, ...>>>(..., batch_ids, tile_ids, ...)`
# `grids.x` = `num_blocks_cpu`
# `batch_id` = `batch_ids[blockIdx.x]`
# `tile_id` = `tile_ids[blockIdx.x]`
Collaborator:
This comment is excellently written.

Collaborator:
👍

@gongshaotian gongshaotian self-assigned this Aug 27, 2025
@littledgg
littledgg commented Sep 2, 2025

With short prompts (about 10 tokens):

Original: 21.259 ms
Prefill in cudagraph: 9.278 ms

With long prompts (about 480 tokens):

Original: 36.48 ms
Prefill in cudagraph: 21.374 ms

Result analysis

Short prompt (about 10 tokens): 9.728 / 21.259 = 45.7% of the original latency, a 118% speedup
Long prompt (481 tokens): 21.374 / 36.48 = 58.5% of the original latency, a 70% speedup
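As a quick arithmetic check of the ratios above (figures taken from the analysis; small rounding differences aside):

```python
# Latency ratio = new / old; speedup = old / new - 1,
# using the measurements reported above.
for label, old, new in [("short", 21.259, 9.728), ("long", 36.48, 21.374)]:
    print(f"{label}: ratio {new / old:.1%}, speedup {(old / new - 1):.0%}")
```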

@gongshaotian (Collaborator) left a comment:
LGTM

@Jiang-Jia-Jun Jiang-Jia-Jun merged commit 3d0aaa5 into PaddlePaddle:develop Sep 8, 2025
14 of 17 checks passed
@littledgg littledgg deleted the prefill_in_cudagraph branch November 28, 2025 07:38

Labels

contributor External developers
