
Commit 7e751c9

[BugFix] Fix chunked prefill (#3759)
* add error traceback info
* update error msg
* update code
* default enable chunked prefill
* update code
* update code
* add envs
* update code
* update enable chunked_prefill
* update code
* update code
* update code
* update code
* update code

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
1 parent 27f2e7a commit 7e751c9

4 files changed: 29 additions and 25 deletions

.github/workflows/_base_test.yml

Lines changed: 0 additions & 1 deletion
@@ -134,7 +134,6 @@ jobs:
           -e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
           -e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
           -e "FLASK_PORT=${FLASK_PORT}" \
-          -e "FD_FORCE_CHUNKED_PREFILL=1" \
           -v "${MODEL_CACHE_DIR}:/MODELDATA" \
           -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
           -v "${CACHE_DIR}/.cache:/root/.cache" \

fastdeploy/config.py

Lines changed: 5 additions & 14 deletions
@@ -1233,23 +1233,14 @@ def postprocess(self):
 
         self.paddle_commit_id = paddle.version.commit
 
-        if self.cache_config.enable_chunked_prefill:
-            self.force_chunked_prefill = int(envs.FD_FORCE_CHUNKED_PREFILL)
-            if (
-                self.speculative_config is not None
-                and self.speculative_config.method in ["mtp"]
-                and not self.force_chunked_prefill
-            ):
-                self.cache_config.enable_chunked_prefill = False
-
         if self.max_num_batched_tokens is None:
-            if self.cache_config.enable_chunked_prefill:
-                self.max_num_batched_tokens = 2048
+            if int(envs.ENABLE_V1_KVCACHE_SCHEDULER):
+                self.max_num_batched_tokens = 8192  # if set to max_model_len, it's easy to be OOM
             else:
-                if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
-                    self.max_num_batched_tokens = self.max_model_len
+                if self.cache_config.enable_chunked_prefill:
+                    self.max_num_batched_tokens = 2048
                 else:
-                    self.max_num_batched_tokens = 8192  # if set to max_model_len, it's easy to be OOM
+                    self.max_num_batched_tokens = self.max_model_len
 
         if self.long_prefill_token_threshold == 0:
             self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
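The net effect in config.py is a reordered default for max_num_batched_tokens: the V1 KV-cache scheduler check now comes first, and the MTP-specific disabling of chunked prefill no longer lives in postprocess(). A minimal standalone sketch of the new branch order follows; the function name and flattened parameters are illustrative only and are not part of FDConfig:

    # Sketch of the default resolution after this commit (illustrative names only).
    def resolve_max_num_batched_tokens(
        max_num_batched_tokens,        # user-provided value, or None
        enable_v1_kvcache_scheduler,   # ENABLE_V1_KVCACHE_SCHEDULER env flag
        enable_chunked_prefill,        # cache_config.enable_chunked_prefill
        max_model_len,
    ):
        if max_num_batched_tokens is not None:
            return max_num_batched_tokens   # an explicit setting always wins
        if enable_v1_kvcache_scheduler:
            return 8192                     # capped; using max_model_len risks OOM
        if enable_chunked_prefill:
            return 2048                     # chunk budget for chunked prefill
        return max_model_len                # legacy scheduler, unchunked prefill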

fastdeploy/engine/args_utils.py

Lines changed: 22 additions & 8 deletions
@@ -15,11 +15,11 @@
 """
 
 import json
-import os
 from dataclasses import asdict, dataclass
 from dataclasses import fields as dataclass_fields
 from typing import Any, Dict, List, Optional
 
+from fastdeploy import envs
 from fastdeploy.config import (
     CacheConfig,
     EarlyStopConfig,
@@ -243,7 +243,7 @@ class EngineArgs:
     Ports for rdma communication.
     """
 
-    enable_chunked_prefill: bool = True
+    enable_chunked_prefill: bool = False
     """
     Flag to enable chunked prefilling.
     """
@@ -981,22 +981,36 @@ def create_engine_config(self) -> FDConfig:
 
         if not model_cfg.is_unified_ckpt and hasattr(model_cfg, "tensor_parallel_size"):
             self.tensor_parallel_size = model_cfg.tensor_parallel_size
+
+        speculative_cfg = self.create_speculative_config()
+        if not self.enable_chunked_prefill:
+            if (
+                current_platform.is_cuda()
+                and self.splitwise_role == "mixed"
+                and (speculative_cfg is None or speculative_cfg.method not in ["mtp"])
+            ):
+                # default enable chunked prefill
+                self.enable_chunked_prefill = True
+
+        self.disable_chunked_prefill = int(envs.FD_DISABLE_CHUNKED_PREFILL)
+        if self.disable_chunked_prefill:
+            self.enable_chunked_prefill = False
+
         if self.max_num_batched_tokens is None:
-            if self.enable_chunked_prefill:
-                self.max_num_batched_tokens = 2048
+            if int(envs.ENABLE_V1_KVCACHE_SCHEDULER):
+                self.max_num_batched_tokens = 8192  # if set to max_model_len, it's easy to be OOM
             else:
-                if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
-                    self.max_num_batched_tokens = self.max_model_len
+                if self.enable_chunked_prefill:
+                    self.max_num_batched_tokens = 2048
                 else:
-                    self.max_num_batched_tokens = 8192  # if set to max_model_len, it's easy to be OOM
+                    self.max_num_batched_tokens = self.max_model_len
 
         all_dict = asdict(self)
         all_dict["model_cfg"] = model_cfg
         cache_cfg = CacheConfig(all_dict)
         load_cfg = LoadConfig(all_dict)
         parallel_cfg = ParallelConfig(all_dict)
         scheduler_cfg = self.create_scheduler_config()
-        speculative_cfg = self.create_speculative_config()
         graph_opt_cfg = self.create_graph_optimization_config()
         graph_opt_cfg.update_use_cudagraph(self.use_cudagraph)
         moba_attention_config = self.create_moba_attention_config()
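In args_utils.py the dataclass default flips to False, and create_engine_config() now decides at runtime whether to turn chunked prefill on: it is enabled by default on CUDA when splitwise_role is "mixed" and no MTP speculative decoding is configured, while FD_DISABLE_CHUNKED_PREFILL can still force it off. A minimal sketch of that decision, with an illustrative helper name and flattened arguments rather than the actual EngineArgs API:

    # Sketch of the default-on decision introduced here (illustrative names only).
    def should_enable_chunked_prefill(
        user_enabled,                # EngineArgs.enable_chunked_prefill (now defaults to False)
        is_cuda,                     # current_platform.is_cuda()
        splitwise_role,              # e.g. "mixed"
        speculative_method,          # e.g. None or "mtp"
        fd_disable_chunked_prefill,  # envs.FD_DISABLE_CHUNKED_PREFILL
    ):
        enabled = user_enabled
        # Default-enable on CUDA, mixed splitwise role, and no MTP speculative decoding.
        if (
            not enabled
            and is_cuda
            and splitwise_role == "mixed"
            and speculative_method not in ["mtp"]
        ):
            enabled = True
        # The environment switch always wins and forces chunked prefill off.
        if fd_disable_chunked_prefill:
            enabled = False
        return enabled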

fastdeploy/envs.py

Lines changed: 2 additions & 2 deletions
@@ -93,8 +93,8 @@
     # enable multi api server
     "FD_ENABLE_MULTI_API_SERVER": lambda: bool(int(os.getenv("FD_ENABLE_MULTI_API_SERVER", "0"))),
     "FD_FOR_TORCH_MODEL_FORMAT": lambda: bool(int(os.getenv("FD_FOR_TORCH_MODEL_FORMAT", "0"))),
-    # force enable chunked prefill
-    "FD_FORCE_CHUNKED_PREFILL": lambda: bool(int(os.getenv("FD_FORCE_CHUNKED_PREFILL", "0"))),
+    # force disable default chunked prefill
+    "FD_DISABLE_CHUNKED_PREFILL": lambda: bool(int(os.getenv("FD_DISABLE_CHUNKED_PREFILL", "0"))),
 }
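With FD_FORCE_CHUNKED_PREFILL removed, opting out of the new default goes through the replacement switch. An illustrative snippet, assuming the variable is set before FastDeploy evaluates its environment table in fastdeploy/envs.py:

    import os

    # Illustrative opt-out: with this set, envs.FD_DISABLE_CHUNKED_PREFILL evaluates
    # to True and the default-on chunked prefill is turned back off.
    os.environ["FD_DISABLE_CHUNKED_PREFILL"] = "1"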