Skip to content
Merged
Show file tree
Hide file tree
Changes from 44 commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
8351e83
success run ngram
gongshaotian Aug 20, 2025
02e8384
Revert "[Code Simplification] remove cum_offsets (#3410)"
lizexu123 Aug 21, 2025
1444ba6
success run ngram5 tp4 42bs
lizexu123 Aug 21, 2025
892c0c2
success run ngram5 tp4 42bs
lizexu123 Aug 22, 2025
18d9823
merge develop
gongshaotian Sep 1, 2025
64ea2f7
mtp draft commit
gongshaotian Sep 2, 2025
3263006
enable target model in cuda graph
littledgg Sep 8, 2025
5b75ade
Merge pull request #1 from littledgg/mtp
gongshaotian Sep 8, 2025
4772a4f
add decorator for target model
zeroRains Sep 9, 2025
4a0a6df
enable draft model in cudagraph v0.5
littledgg Sep 10, 2025
ec4a2df
revert revrt cum_offset
littledgg Sep 12, 2025
529214c
Merge pull request #3 from littledgg/mtp
gongshaotian Sep 12, 2025
2dd98da
enable target model in cudagraph v0.9 And clean debug code
littledgg Sep 12, 2025
1d3ef67
Revert "success run ngram"
Sep 12, 2025
349988f
add reverted code
Sep 12, 2025
15d3103
enable target model in cudagraph v0.9
littledgg Sep 12, 2025
7f11653
solve comment
littledgg Sep 12, 2025
bb9c911
Merge pull request #4 from littledgg/mtp
gongshaotian Sep 12, 2025
d1115a7
merge remote mtp
Sep 12, 2025
77e64ed
merge develop & solve conflict
Sep 15, 2025
235b0ba
fix bid < 0
Sep 16, 2025
3516be4
Enable Target Model Padding And Draft Model in cudagraph
littledgg Sep 16, 2025
c6cdc17
Merge branch 'mtp' of https://github.com/gongshaotian/FastDeploy into…
littledgg Sep 16, 2025
167fb58
solve problem
littledgg Sep 16, 2025
4c10571
Merge pull request #5 from littledgg/mtp
gongshaotian Sep 16, 2025
89c6c83
delete rebuild padding debug note
Sep 16, 2025
fdf49de
Merge branch 'mtp' of https://github.com/gongshaotian/FastDeploy into…
Sep 16, 2025
834639a
fast compile
Sep 17, 2025
4c09b0b
Add capture list for mtp
littledgg Sep 17, 2025
9f71c0e
Merge pull request #6 from littledgg/mtp
gongshaotian Sep 18, 2025
fc6ce99
success run 256 tp1 mtp
Sep 18, 2025
8c306d8
Enable Lite TP2 Bsz256
littledgg Sep 18, 2025
cf01a97
Merge pull request #7 from littledgg/mtp
gongshaotian Sep 19, 2025
e28327a
realy enable tp2 bsz 256
littledgg Sep 22, 2025
a44e2d9
fix problem
littledgg Sep 22, 2025
00de438
Merge pull request #8 from littledgg/mtp
gongshaotian Sep 23, 2025
678152f
Solve problem for Draft model in cudagraph
littledgg Sep 23, 2025
d841cc6
Solve comment
littledgg Sep 23, 2025
3bf990c
Merge pull request #9 from littledgg/mtp
gongshaotian Sep 24, 2025
2eaf778
replace emptytensor as zeros
Sep 24, 2025
24fa8cb
Merge branch 'mtp' of https://github.com/gongshaotian/FastDeploy into…
Sep 24, 2025
1a4190b
Solve comments
littledgg Sep 24, 2025
d3e7df9
Revert "fast compile"
littledgg Sep 24, 2025
96d85a0
Merge pull request #10 from littledgg/mtp
gongshaotian Sep 24, 2025
1c23a3e
merge develop
littledgg Sep 24, 2025
f814026
Merge pull request #11 from littledgg/mtp
gongshaotian Sep 24, 2025
beaaaec
Merge branch 'mtp' of https://github.com/gongshaotian/FastDeploy into…
Sep 24, 2025
ce11adb
fix bug
littledgg Sep 24, 2025
4c06088
Merge pull request #12 from littledgg/mtp
gongshaotian Sep 24, 2025
c885ba6
Merge branch 'mtp' of https://github.com/gongshaotian/FastDeploy into…
Sep 24, 2025
d23206c
fix merge bug
Sep 24, 2025
4a8f947
fix typo
Sep 25, 2025
2137520
fix bug
Sep 25, 2025
a4323aa
merge develop
Oct 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions custom_ops/gpu_ops/append_attn/append_attention_func.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2418,6 +2418,9 @@ __global__ void merge_multi_chunks_v2_kernel(
__shared__ float md_smem[bdy * 2];
for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
const uint32_t bid = batch_id_per_token[qid];
if(bid == -1){
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

注意下编码规范

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

注意下编码规范

这里能把 bid 从 uint32_t 切换成 int 吗?取值范围变小了有无风险?

continue;
}
const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
const int seq_len_q = seq_lens_q[bid];
if (seq_len_q == 0) continue;
Expand All @@ -2437,6 +2440,8 @@ __global__ void merge_multi_chunks_v2_kernel(
const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
if (num_chunks_this_seq <= 1) {
continue;
}else if (!ENABLE_PREFILL){
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

continue;
}

using LoadT = AlignedVector<T, vec_size>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,7 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
if (block_idx < 0) {
printf(
"Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
"%d %d %d %d\n",
block_idx,
write_seq_id,
ori_bi,
seq_lens_decoder[ori_bi],
token_id,
cu_seqlens_q[ori_bi]);
return ; // NOTE(gongshaotian): For CUDAGraph padding
}
const int block_offset = write_seq_id % block_size;

Expand Down Expand Up @@ -390,15 +382,7 @@ __global__ void append_speculate_cache_rope_kernel(
const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
if (block_idx < 0) {
printf(
"Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
"%d %d %d %d\n",
block_idx,
write_seq_id,
ori_bi,
seq_lens_decoder[ori_bi],
token_id,
cu_seqlens_q[ori_bi]);
return ; // NOTE(gongshaotian): For CUDAGraph padding
}
const int block_offset = write_seq_id % block_size;

Expand Down Expand Up @@ -525,15 +509,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
if (block_idx < 0) {
printf(
"Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
"%d %d %d %d\n",
block_idx,
write_seq_id,
ori_bi,
seq_lens_decoder[ori_bi],
token_id,
cu_seqlens_q[ori_bi]);
return ; // NOTE(gongshaotian): For CUDAGraph padding
}
const int block_offset = write_seq_id % block_size;

Expand Down
2 changes: 1 addition & 1 deletion custom_ops/gpu_ops/cpp_extensions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,7 @@ void SpeculateVerify(
const paddle::Tensor &output_cum_offsets,
const paddle::Tensor &actual_candidate_len,
const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode);
int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode, bool accept_all_drafts);

void SpeculateUpdate(const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
Expand Down
1 change: 0 additions & 1 deletion custom_ops/gpu_ops/rebuild_padding.cu
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@ std::vector<paddle::Tensor> rebuild_padding(
int pack_num = elem_nums / PackSize;
const int blocksize = 128;
const int grid_size = (pack_num + blocksize - 1) / blocksize;

if (output_padding_offset) {
RebuildAppendPaddingKernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, cu_stream>>>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ std::vector<paddle::DataType> SpeculateGetPaddingOffsetInferDtype(
PD_BUILD_STATIC_OP(speculate_get_padding_offset)
.Inputs({"input_ids",
"draft_tokens",
"cum_offsets"
"token_num",
"seq_len",
"seq_lens_encoder"})
Expand Down
32 changes: 25 additions & 7 deletions custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ __global__ void speculate_verify(
const int *output_cum_offsets, const int *actual_candidate_len,
const int real_bsz, const int max_draft_tokens, const int end_length,
const int max_seq_len, const int max_candidate_len, const int verify_window,
const bool prefill_one_step_stop, const bool benchmark_mode) {
const bool prefill_one_step_stop, const bool benchmark_mode, const bool accept_all_drafts) {
const int bid = threadIdx.x;
// verify and set stop flags
int accept_num_now = 1;
Expand Down Expand Up @@ -101,6 +101,24 @@ __global__ void speculate_verify(
if (seq_lens_encoder[bid] != 0) {
break;
}
if (accept_all_drafts) {
// accept all draft tokens
step_idx[bid]++;
auto accept_token = draft_tokens_now[i + 1];
accept_tokens[bid * max_draft_tokens + i] = accept_token;

if (is_in_end(accept_token, end_tokens, end_length) ||
step_idx[bid] >= max_dec_len[bid]) {
stop_flags[bid] = true;
stop_flag_now_int = 1;
if (step_idx[bid] >= max_dec_len[bid])
accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
break;
} else {
accept_num_now++;
}
continue;
}
if (USE_TOPK) {
if (verify_tokens_now[i * max_candidate_len] ==
draft_tokens_now[i + 1]) {
Expand Down Expand Up @@ -249,7 +267,7 @@ void SpeculateVerify(
const paddle::Tensor &output_cum_offsets,
const paddle::Tensor &actual_candidate_len,
const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode) {
int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode, bool accept_all_drafts) {
// printf("Enter speculate update\n");
auto bsz = accept_tokens.shape()[0];
int real_bsz = seq_lens_this_time.shape()[0];
Expand Down Expand Up @@ -292,7 +310,7 @@ void SpeculateVerify(
is_block_step.data<bool>(), output_cum_offsets.data<int>(),
actual_candidate_len.data<int>(), real_bsz, max_draft_tokens,
end_length, max_seq_len, max_candidate_len, verify_window,
prefill_one_step_stop, benchmark_mode);
prefill_one_step_stop, benchmark_mode, accept_all_drafts);
} else {
speculate_verify<false, true>
<<<1, BlockSize, 0, accept_tokens.stream()>>>(
Expand All @@ -308,7 +326,7 @@ void SpeculateVerify(
end_tokens.data<int64_t>(), is_block_step.data<bool>(),
output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
real_bsz, max_draft_tokens, end_length, max_seq_len,
max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode);
max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode, accept_all_drafts);
}
} else {
if (enable_topp) {
Expand All @@ -326,7 +344,7 @@ void SpeculateVerify(
end_tokens.data<int64_t>(), is_block_step.data<bool>(),
output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
real_bsz, max_draft_tokens, end_length, max_seq_len,
max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode);
max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode, accept_all_drafts);
} else {
speculate_verify<false, false>
<<<1, BlockSize, 0, accept_tokens.stream()>>>(
Expand All @@ -342,7 +360,7 @@ void SpeculateVerify(
end_tokens.data<int64_t>(), is_block_step.data<bool>(),
output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
real_bsz, max_draft_tokens, end_length, max_seq_len,
max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode);
max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode, accept_all_drafts);
}
}

Expand All @@ -357,7 +375,7 @@ PD_BUILD_STATIC_OP(speculate_verify)
"actual_candidate_len", "actual_draft_token_nums", "topp"})
.Outputs({"accept_tokens_out", "accept_num_out", "step_idx_out",
"stop_flags_out"})
.Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool", "benchmark_mode: bool"})
.Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool", "benchmark_mode: bool","accept_all_drafts: bool"})
.SetInplaceMap({{"accept_tokens", "accept_tokens_out"},
{"accept_num", "accept_num_out"},
{"step_idx", "step_idx_out"},
Expand Down
5 changes: 5 additions & 0 deletions fastdeploy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1434,6 +1434,11 @@ def __init__(

if self.graph_opt_config.cudagraph_only_prefill:
self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=512)
elif self.speculative_config.method == "mtp":
max_shape = self.parallel_config.max_num_seqs * (self.speculative_config.num_model_steps + 1)
if max_shape % 2 == 1:
max_shape = max_shape + 1
self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=min(512, max_shape))
else:
self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=self.scheduler_config.max_num_seqs)

Expand Down
2 changes: 1 addition & 1 deletion fastdeploy/model_executor/forward_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def format_str(obj):
"shape": obj.shape,
"dtype": str(obj.dtype),
"place": str(obj.place),
# "content": obj if obj.numel()<10 else "Too big to show"
"content": obj if obj.numel() < 70 else "Too big to show",
}
return tensor_info
elif isinstance(obj, (list, tuple)):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
entry.num_finished_warmup += 1
entry.runnable(**kwargs)
logger.debug(
f"[CUDA GRAPH] Warm up for batch size {entry.real_shape}, "
f"[CUDA GRAPH][ID:{id(self)}] Warm up for batch size {entry.real_shape}, "
f"finished ({n + 1}/{entry.num_finished_warmup}) times"
)

Expand All @@ -138,15 +138,15 @@ def __call__(self, **kwargs) -> List[paddle.Tensor] | paddle.Tensor:
real_shape = ids_remove_padding.shape[0]
padding_real_shape = self.real_shape_to_captured_size[real_shape]
logger.debug(
f"[CUDA GRAPH] The actual real shape obtained by CUDAGraph is :{real_shape}, "
f"The padded shape is :{padding_real_shape}"
f"[CUDA GRAPH][ID:{id(self)}] The actual real shape obtained by CUDAGraph is :{real_shape}, "
f"The padded shape is :{padding_real_shape}, If Padding :{real_shape != padding_real_shape}"
)

entry = self.concrete_size_entries.get(padding_real_shape)
assert entry is not None, f"real shape:{padding_real_shape} is not in cuda graph capture list."
if entry.runnable is None:
entry.runnable = self.runnable
logger.debug(f"[CUDA GRAPH] New entry lazy initialize with real shape {padding_real_shape}")
logger.debug(f"[CUDA GRAPH][ID:{id(self)}] New entry lazy initialize with real shape {padding_real_shape}")

if not entry.use_cudagraph:
return entry.runnable(**kwargs)
Expand All @@ -161,7 +161,7 @@ def __call__(self, **kwargs) -> List[paddle.Tensor] | paddle.Tensor:
entry.num_finished_warmup += 1
entry.runnable(**kwargs)
logger.debug(
f"[CUDA GRAPH] Warm up for real shape {padding_real_shape}, "
f"[CUDA GRAPH][ID:{id(self)}] Warm up for real shape {padding_real_shape}, "
f"finished ({n + 1}/{entry.num_finished_warmup}) times"
)

Expand Down Expand Up @@ -196,11 +196,11 @@ def __call__(self, **kwargs) -> List[paddle.Tensor] | paddle.Tensor:

# For CUDAGraph debug
# self._save_cudagrpah_dot_files(entry)
logger.debug(f"[CUDA GRAPH] CUDAGraph captured for real shape {padding_real_shape}")
logger.debug(f"[CUDA GRAPH][ID:{id(self)}] CUDAGraph captured for real shape {padding_real_shape}")

# Replay
entry.cuda_graph.replay()
logger.debug(f"[CUDA GRAPH] CUDAGraph replayed for real shape {padding_real_shape}")
logger.debug(f"[CUDA GRAPH][ID:{id(self)}] CUDAGraph replayed for real shape {padding_real_shape}")
if len(entry.output_buffers) == 1:
return entry.output_buffers[0]
return entry.output_buffers
Expand All @@ -213,8 +213,9 @@ def _create_entry_dict(self):
for shape in self.cudagraph_capture_sizes:
self.concrete_size_entries[shape] = ConcreteSizeEntry(real_shape=shape)

logger.info(
f"[CUDA GRAPH] CUDAGraph capture list {self.cudagraph_capture_sizes}, " "Created all real shape entry."
logger.debug(
f"[CUDA GRAPH][ID:{id(self)}] CUDAGraph capture list {self.cudagraph_capture_sizes}, "
"Created all real shape entry."
)

def clear_graph(self):
Expand All @@ -223,7 +224,7 @@ def clear_graph(self):
for id, entry in self.concrete_size_entries.items():
if entry.cuda_graph:
del entry.cuda_graph
logger.debug(f"[CUDA GRAPH] The CUDAGraph with shape {id} has been cleared.")
logger.debug(f"[CUDA GRAPH][ID:{id(self)}] The CUDAGraph with shape {id} has been cleared.")

del self.concrete_size_entries
paddle.device.cuda.empty_cache()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def __init__(self, runnable: Callable, fd_config: FDConfig):
self.runnable = runnable
self.fd_config = fd_config

self.max_captre_batch = fd_config.graph_opt_config.cudagraph_capture_sizes[0]
self.max_captre_size = fd_config.graph_opt_config.cudagraph_capture_sizes[0]
if self.fd_config.graph_opt_config.graph_opt_level > 0:
# 1. Prepare cuda graph input buffers (contain output of subgraphs)

Expand All @@ -138,9 +138,9 @@ def __call__(self, **kwargs):
)

assert kwargs["forward_meta"].ids_remove_padding is not None
batch_size = kwargs["forward_meta"].ids_remove_padding.shape[0]
real_shape = kwargs["forward_meta"].ids_remove_padding.shape[0]

if (not kwargs["forward_meta"].step_use_cudagraph) or (batch_size > self.max_captre_batch):
if (not kwargs["forward_meta"].step_use_cudagraph) or (real_shape > self.max_captre_size):
return self.runnable(**kwargs)
else:
return self.cudagraph_piecewise_backend.__call__(**kwargs)
Expand Down
2 changes: 2 additions & 0 deletions fastdeploy/model_executor/layers/sample/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,7 @@ def forward_cuda(
sampling_metadata: SamplingMetadata,
max_model_len: int,
share_inputs: List[paddle.Tensor],
accept_all_drafts: bool = False,
) -> paddle.Tensor:
""" """

Expand Down Expand Up @@ -517,6 +518,7 @@ def forward_cuda(
self.speculative_verify_window,
True, # enable_topp
self.speculative_benchmark_mode,
accept_all_drafts,
)

return None
Expand Down
10 changes: 9 additions & 1 deletion fastdeploy/model_executor/models/ernie4_5_mtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.graph_optimization.decorator import (
support_graph_optimization,
)
from fastdeploy.model_executor.layers.mtp_linear import ParallelEHProjection
from fastdeploy.model_executor.layers.normalization import RMSNorm
from fastdeploy.model_executor.models.ernie4_5_moe import Ernie4_5_DecoderLayer
Expand Down Expand Up @@ -234,6 +237,7 @@ def get_tensor_parallel_split_mappings(num_layers, moe_num_experts, moe_layer_st
return mappings


@support_graph_optimization
class Ernie4_5_MTPModel(nn.Layer):
"""
Ernie4_5_MTPModel
Expand Down Expand Up @@ -457,6 +461,10 @@ def forward(
"""
forward
"""
hidden_states = self.ernie(ids_remove_padding, previous_hidden_states, forward_meta)
hidden_states = self.ernie(
ids_remove_padding=ids_remove_padding,
previous_hidden_states=previous_hidden_states,
forward_meta=forward_meta,
)

return hidden_states
27 changes: 14 additions & 13 deletions fastdeploy/spec_decode/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,24 +33,25 @@ class Proposer(ABC):
the speculative decoding framework
"""

def __init__(self, cfg: FDConfig):
def __init__(self, fd_config: FDConfig):
"""
Init Speculative proposer
"""
cfg.parallel_config.tp_group = None
self.cfg = deepcopy(cfg)
cfg.parallel_config.tp_group = dist.get_group(
cfg.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
fd_config.parallel_config.tp_group = None
self.fd_config = deepcopy(fd_config)
fd_config.parallel_config.tp_group = dist.get_group(
fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
)
self.cfg.parallel_config.tp_group = dist.get_group(
cfg.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
self.fd_config.parallel_config.tp_group = dist.get_group(
fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
)
self.parallel_config = self.cfg.parallel_config
self.model_config = self.cfg.model_config
self.speculative_config = self.cfg.speculative_config
self.cache_config = self.cfg.cache_config
self.quant_config = self.cfg.quant_config
self.scheduler_config = self.cfg.scheduler_config
self.parallel_config = self.fd_config.parallel_config
self.model_config = self.fd_config.model_config
self.speculative_config = self.fd_config.speculative_config
self.cache_config = self.fd_config.cache_config
self.quant_config = self.fd_config.quant_config
self.graph_opt_config = self.fd_config.graph_opt_config
self.scheduler_config = self.fd_config.scheduler_config

self.max_num_seqs = self.scheduler_config.max_num_seqs
self.max_model_len = self.parallel_config.max_model_len
Expand Down
Loading
Loading