Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
d587f21
Support prefill in Cudagraph
littledgg Aug 18, 2025
0e5e305
Merge branch 'develop' of https://github.com/PaddlePaddle/FastDeploy …
littledgg Aug 18, 2025
63f1256
Refactor GetBlockShapeAndSplitKVBlock Kernel V2
littledgg Aug 19, 2025
109d4f2
Refactor GetBlockShapeAndSplitKVBlock Kernel V2.1
littledgg Aug 20, 2025
7ba1c17
Refactor GetBlockShapeAndSplitKVBlock Kernel V2.2
littledgg Aug 20, 2025
1e60bfd
Refactor GetBlockShapeAndSplitKVBlock Kernel V2.3
littledgg Aug 20, 2025
fd41ba3
Refactor GetBlockShapeAndSplitKVBlock Kernel V2.4
littledgg Aug 20, 2025
5d52620
Refactor GetBlockShapeAndSplitKVBlock Kernel V2.5
littledgg Aug 20, 2025
39dd695
Solve problem about encoder_num_blocks_x_cpu
littledgg Aug 22, 2025
4bc60cc
Add early-exit mechanism for attention kernel
littledgg Aug 25, 2025
d267feb
fix test case about append-attention
littledgg Aug 25, 2025
2a40480
Merge branch 'develop' of https://github.com/PaddlePaddle/FastDeploy …
littledgg Aug 25, 2025
edbbbf0
Update testcode, Add annotations to related tensors
littledgg Aug 26, 2025
86d1a94
solve conflict
littledgg Aug 26, 2025
cbca27d
move get_input_length_list
littledgg Aug 26, 2025
3f40024
solve conflict
littledgg Aug 26, 2025
62a3f05
solve test_code
littledgg Aug 27, 2025
0f92c63
Add annotations about early-exit for attention kernel
littledgg Aug 27, 2025
f57a1ca
Add annotations about early-exit for attention kernel2
littledgg Aug 27, 2025
684f6c9
solve comment
littledgg Aug 27, 2025
edb923c
solve conflict
littledgg Sep 1, 2025
960c6c9
Merge branch 'develop' into prefill_in_cudagraph
gongshaotian Sep 2, 2025
5df6078
solve conflict and test case
littledgg Sep 4, 2025
c7b3d5b
solve conflict
littledgg Sep 4, 2025
8f8011d
solve mtp
littledgg Sep 5, 2025
e49ff4a
Merge branch 'develop' into prefill_in_cudagraph
littledgg Sep 5, 2025
d15e4fd
Merge branch 'develop' into prefill_in_cudagraph
littledgg Sep 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions fastdeploy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,10 @@ def __init__(
""" Whether to use a full cuda graph for the entire forward pass rather than
splitting certain operations such as attention into subgraphs.
Thus this flag cannot be used together with splitting_ops."""
self.cudagraph_only_prefill: bool = False
"""When cudagraph_only_prefill is False, only capture decode-only.
When cudagraph_only_prefill is True, only capture prefill-only.
Now don't support capture both decode-only and prefill-only"""
self.full_cuda_graph: bool = True

self.max_capture_size: int = None
Expand All @@ -496,13 +500,13 @@ def __init__(

self.check_legality_parameters()

def init_with_cudagrpah_size(self, max_num_seqs: int = 0) -> None:
def init_with_cudagrpah_size(self, max_capture_size: int = 0) -> None:
"""
Initialize cuda graph capture sizes and
pre-compute the mapping from batch size to padded graph size
"""
# Regular capture sizes
self.cudagraph_capture_sizes = [size for size in self.cudagraph_capture_sizes if size <= max_num_seqs]
self.cudagraph_capture_sizes = [size for size in self.cudagraph_capture_sizes if size <= max_capture_size]
dedup_sizes = list(set(self.cudagraph_capture_sizes))
if len(dedup_sizes) < len(self.cudagraph_capture_sizes):
logger.info(
Expand Down Expand Up @@ -950,7 +954,11 @@ def __post_init__(self):
# Initialize cuda graph capture list
if self.graph_opt_config.cudagraph_capture_sizes is None:
self.graph_opt_config._set_cudagraph_sizes(max_num_seqs=self.parallel_config.max_num_seqs)
self.graph_opt_config.init_with_cudagrpah_size(max_num_seqs=self.parallel_config.max_num_seqs)

if self.graph_opt_config.cudagraph_only_prefill:
self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=512)
else:
self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=self.parallel_config.max_num_seqs)

# TODO(wangmingkai02): change graph_opt_level=2 when using static mode with cinn
if self.graph_opt_config.graph_opt_level == 2:
Expand Down
57 changes: 50 additions & 7 deletions fastdeploy/model_executor/layers/attention/append_attn_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,21 @@ def __init__(

self.rank, self.device_id = init_rank_and_device_id(fd_config)

self.share_inputs = {}
self.share_inputs["encoder_batch_ids"] = paddle.full(
shape=[self.max_seq_len], fill_value=0, dtype="int32"
) # gpu
self.share_inputs["encoder_tile_ids_per_batch"] = paddle.full(
shape=[self.max_seq_len], fill_value=0, dtype="int32"
) # gpu
self.share_inputs["encoder_num_blocks"] = paddle.full(shape=[1], fill_value=0, dtype="int32").cpu() # cpu
self.share_inputs["kv_batch_ids"] = paddle.full(shape=[self.max_seq_len], fill_value=0, dtype="int32") # gpu
self.share_inputs["kv_tile_ids_per_batch"] = paddle.full(
shape=[self.max_seq_len], fill_value=0, dtype="int32"
) # gpu
self.share_inputs["kv_num_blocks"] = paddle.full(shape=[1], fill_value=0, dtype="int32").cpu() # cpu
self.share_inputs["max_len_kv"] = paddle.full(shape=[1], fill_value=0, dtype="int32").cpu() # cpu

def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attntion metadata hence all layers in the forward pass can reuse it."""
metadata = AppendAttentionMetadata()
Expand All @@ -140,13 +155,20 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
metadata.attn_mask = forward_meta.attn_mask
metadata.pre_caches_length = forward_meta.pre_caches_length
(
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
metadata.max_len_kv,
temp_encoder_batch_ids,
temp_encoder_tile_ids_per_batch,
temp_encoder_num_blocks,
temp_kv_batch_ids,
temp_kv_tile_ids_per_batch,
temp_kv_num_blocks,
temp_max_len_kv,
# metadata.encoder_batch_ids,
# metadata.encoder_tile_ids_per_batch,
# metadata.encoder_num_blocks,
# metadata.kv_batch_ids,
# metadata.kv_tile_ids_per_batch,
# metadata.kv_num_blocks,
# metadata.max_len_kv,
) = get_block_shape_and_split_kv_block(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
Expand All @@ -162,6 +184,27 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
self.speculate_max_draft_token_num + 1,
)

self.share_inputs["encoder_batch_ids"].copy_(temp_encoder_batch_ids, False)
metadata.encoder_batch_ids = self.share_inputs["encoder_batch_ids"]

self.share_inputs["encoder_tile_ids_per_batch"].copy_(temp_encoder_tile_ids_per_batch, False)
metadata.encoder_tile_ids_per_batch = self.share_inputs["encoder_tile_ids_per_batch"]

self.share_inputs["encoder_num_blocks"].copy_(temp_encoder_num_blocks, False)
metadata.encoder_num_blocks = self.share_inputs["encoder_num_blocks"]

self.share_inputs["kv_batch_ids"].copy_(temp_kv_batch_ids, False)
metadata.kv_batch_ids = self.share_inputs["kv_batch_ids"]

self.share_inputs["kv_tile_ids_per_batch"].copy_(temp_kv_tile_ids_per_batch, False)
metadata.kv_tile_ids_per_batch = self.share_inputs["kv_tile_ids_per_batch"]

self.share_inputs["kv_num_blocks"].copy_(temp_kv_num_blocks, False)
metadata.kv_num_blocks = self.share_inputs["kv_num_blocks"]

self.share_inputs["max_len_kv"].copy_(temp_max_len_kv, False)
metadata.max_len_kv = self.share_inputs["max_len_kv"]

# pd_disaggregation
metadata.kv_signal_data_list = [None] * self.num_layers
if self.pd_disaggregation_mode == "per_chunk":
Expand Down
73 changes: 61 additions & 12 deletions fastdeploy/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def __init__(
self.use_cudagraph = self.graph_opt_config.use_cudagraph
self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes
self.cudagraph_only_prefill = self.graph_opt_config.cudagraph_only_prefill

# Initialize share inputs
self._init_share_inputs(self.parallel_config.max_num_seqs)
Expand Down Expand Up @@ -166,6 +167,15 @@ def exist_prefill(self):
else:
return 0

def exist_decode(self):
    """Return 1 if any request in the batch is in the decode stage, else 0.

    A request is decoding when its entry in ``seq_lens_decoder`` is positive,
    so the batch contains a decode request iff the max over that tensor is > 0.
    """
    has_decode = int(paddle.max(self.share_inputs["seq_lens_decoder"])) > 0
    return 1 if has_decode else 0

def _init_speculative_proposer(self):
"""
Init speculative proposer
Expand Down Expand Up @@ -561,7 +571,7 @@ def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decod
if self.fd_config.parallel_config.enable_expert_parallel:
full_length = min(full_length, 32)

input_length = int(full_length * self.cache_config.kv_cache_ratio)
input_length = int(full_length)
block_num = (
input_length + self.cache_config.block_size - 1
) // self.cache_config.block_size + self.cache_config.enc_dec_block_num
Expand Down Expand Up @@ -902,7 +912,7 @@ def initialize_forward_meta(self):
caches=self.share_inputs["caches"],
)

# Update Batch type for cuda graph
# Update Batch type for cuda graph for only_decode_batch
only_decode_batch = True
prefill_exists = None
# mix ep in single node
Expand All @@ -913,12 +923,34 @@ def initialize_forward_meta(self):
only_decode_batch = all(only_decode_batch_list)
self.fd_config.parallel_config.moe_phase.phase = "decode" if only_decode_batch else "prefill"

self.forward_meta.step_use_cudagraph = (
only_decode_use_cudagraph = (
self.use_cudagraph
and only_decode_batch
and not (prefill_exists if prefill_exists is not None else self.exist_prefill())
)

# Update Batch type for cuda graph for only_prefill_batch
only_prefill_batch = True
decode_exists = None
if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
# Collect the batch status from all workers
only_prefill_batch_list = []
decode_exists = self.exist_decode()
paddle.distributed.all_gather_object(only_prefill_batch_list, not decode_exists)
only_prefill_batch = all(only_prefill_batch_list)

only_prefill_use_cudagraph = (
self.use_cudagraph
and self.cudagraph_only_prefill
and only_prefill_batch
and not (decode_exists if decode_exists is not None else self.exist_decode())
)

# Once capturing both prefill-only and decode-only batches is supported, this
# will become [only_prefill_use_cudagraph or only_decode_use_cudagraph]
self.forward_meta.step_use_cudagraph = (
only_prefill_use_cudagraph if self.cudagraph_only_prefill else only_decode_use_cudagraph
)

# Initialzie attention meta data
for attn_backend in self.attn_backends:
attn_backend.init_attention_metadata(self.forward_meta)
Expand Down Expand Up @@ -1230,7 +1262,7 @@ def _update_chunked_prefill(self, tasks):
self.proposer.update_task_chunk_prefill(task)
task.chunk_idx += 1

def capture_model(self) -> None:
def capture_model(self, capture_prefill: bool = False) -> None:
"""
Trigger CUDA Graph capture for all shapes in cuda graph capture list
"""
Expand All @@ -1240,14 +1272,28 @@ def capture_model(self) -> None:
time_before_capture = time.perf_counter()
expected_decode_len = 1
capture_sizes = self.cudagraph_capture_sizes.copy()
for batch_size in sorted(capture_sizes, reverse=True):
self._dummy_run(
num_tokens=self.parallel_config.max_num_batched_tokens,
batch_size=batch_size,
in_capturing=True,
expected_decode_len=expected_decode_len,
)
logger.info(f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}")
if capture_prefill:
for num_tokens in sorted(capture_sizes, reverse=True):
self._dummy_run(
num_tokens=num_tokens,
batch_size=1,
in_capturing=True,
expected_decode_len=expected_decode_len,
)
logger.info(
f"Warm up the model with the num_tokens:{num_tokens}, expected_decode_len:{expected_decode_len}"
)
else:
for batch_size in sorted(capture_sizes, reverse=True):
self._dummy_run(
num_tokens=self.parallel_config.max_num_batched_tokens,
batch_size=batch_size,
in_capturing=True,
expected_decode_len=expected_decode_len,
)
logger.info(
f"Warm up the model with the batch size:{batch_size}, expected_decode_len:{expected_decode_len}"
)

time_after_capture = time.perf_counter()
logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds")
Expand Down Expand Up @@ -1325,6 +1371,9 @@ class at the server level, which is too granular for ModelRunner.
)
hidden_states = model_output
else:
print("传递给model的seq_lens_this_time", self.forward_meta.seq_lens_this_time)
print("input_ids", self.forward_meta.input_ids.shape)
print("self.share_inputs[ids_remove_padding].shape:", self.share_inputs["ids_remove_padding"].shape)
model_output = self.model(
ids_remove_padding=self.share_inputs["ids_remove_padding"],
forward_meta=self.forward_meta,
Expand Down
2 changes: 1 addition & 1 deletion fastdeploy/worker/gpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def graph_optimize_and_warm_up_model(self) -> None:
if self.model_runner.graph_opt_level >= 1:
self.model_runner.sot_warmup()
# Trigger cuda graph capture
self.model_runner.capture_model()
self.model_runner.capture_model(capture_prefill=self.fd_config.graph_opt_config.cudagraph_only_prefill)

def check_health(self) -> bool:
""" """
Expand Down
Loading