Skip to content
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
a135ab5
delete dynamic control flow for decode
aquagull Jul 30, 2025
4add5d6
coda-style
aquagull Jul 30, 2025
755ab26
Merge branch 'PaddlePaddle:develop' into cuda_graph_vl
aquagull Jul 31, 2025
5ac10d9
fix scatter/gather typos and use input stream instead default stream
aquagull Aug 5, 2025
533856c
support 0-Size Tensor
aquagull Aug 5, 2025
4d64786
update runner and model
aquagull Aug 5, 2025
7529cfc
using static mem address as input
aquagull Aug 7, 2025
2e2a40d
fix mem leak
aquagull Aug 7, 2025
07675c5
refine code
aquagull Aug 7, 2025
8805721
update mm_buffer
aquagull Aug 8, 2025
8d75410
Merge branch 'develop' into cuda_graph_vl
aquagull Aug 8, 2025
75770a6
fix typo
Aug 8, 2025
e9ccc26
fix buffersize
Aug 11, 2025
5f14f13
Merge branch 'develop' into cuda_graph_vl
aquagull Aug 14, 2025
571e9d7
Merge branch 'cuda_graph_vl' of https://github.com/aquagull/FastDeplo…
Aug 15, 2025
3c98efc
Merge branch 'develop' into cuda_graph_vl
aquagull Aug 15, 2025
5fdd288
Merge remote-tracking branch 'paddle/develop' into cuda_graph_vl
aquagull Aug 19, 2025
5555161
fix unk token
aquagull Aug 21, 2025
8580d1d
Merge branch 'develop' into cuda_graph_vl
aquagull Aug 21, 2025
44f3718
refine code
Aug 22, 2025
e3a5389
refine code
Aug 22, 2025
f2b0f1f
refine
Aug 25, 2025
90177cf
support other arch
aquagull Aug 25, 2025
0f26ee4
fix toekn_type_ids buffer padding
Aug 26, 2025
03be126
open cudagraph in vlci
Aug 26, 2025
72e0fa9
fix
Aug 26, 2025
e9fd71e
Merge remote-tracking branch 'paddle/develop' into cuda_graph_vl
Aug 28, 2025
d067cd9
update
Aug 29, 2025
2994949
update
Aug 29, 2025
0b3aa07
update
Aug 29, 2025
f6d9c0c
Merge branch 'develop' into cuda_graph_vl
aquagull Sep 3, 2025
e63ad17
fix cmd
aquagull Sep 3, 2025
fe077b6
Merge branch 'develop' into cuda_graph_vl
aquagull Sep 3, 2025
b0b8cf3
update
aquagull Sep 8, 2025
9050777
Merge branch 'develop' into cuda_graph_vl
aquagull Sep 9, 2025
db67cd2
Merge branch 'develop' into cuda_graph_vl
yuanlehome Sep 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions custom_ops/gpu_ops/cpp_extensions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,8 @@ std::vector<paddle::Tensor> MoEDeepGEMMDePermute(
const paddle::Tensor &topk_idx, const paddle::Tensor &topk_weights);

void TextImageIndexOut(const paddle::Tensor &token_type_ids,
const paddle::Tensor &text_input,
const paddle::Tensor &image_input);
paddle::Tensor &text_input,
paddle::Tensor &image_input);

void TextImageGatherScatter(paddle::Tensor &input, paddle::Tensor &text_input,
paddle::Tensor &image_input,
Expand Down
2 changes: 1 addition & 1 deletion custom_ops/gpu_ops/get_padding_offset.cu
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
}

PD_BUILD_STATIC_OP(get_padding_offset)
.Inputs({"input_ids", "token_num", "cum_offsets", "seq_len"})
.Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
.Outputs({"x_remove_padding",
"batch_id_per_token",
"cu_seqlens_q",
Expand Down
12 changes: 12 additions & 0 deletions custom_ops/gpu_ops/moe/moe_dispatch.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ void MoeDispatchKernel(
paddle::Tensor *topk_idx, paddle::Tensor *expert_idx_per_token) {
using namespace phi;

if (num_rows == 0){
return;
}
Comment on lines +39 to +41
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

c++代码最好也格式化一下

typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
Expand Down Expand Up @@ -170,6 +173,15 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
auto expert_idx_per_token =
GetEmptyTensor({num_rows * moe_topk}, paddle::DataType::INT32, place);

if (token_rows == 0){
return {permute_input,
tokens_expert_prefix_sum,
permute_indices_per_token,
topk_weight,
topk_idx,
expert_idx_per_token};
}

switch (input_type) {
case paddle::DataType::BFLOAT16:
MoeDispatchKernel<paddle::DataType::BFLOAT16>(
Expand Down
4 changes: 3 additions & 1 deletion custom_ops/gpu_ops/moe/moe_ffn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,9 @@ const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype()
(quant_method == "w4afp8") ? paddle::DataType::BFLOAT16 :
permute_input.dtype();
auto ffn_out = paddle::empty_like(permute_input, t_type);

if(permute_input.numel() == 0){
return ffn_out;
}
switch (t_type) {
case paddle::DataType::BFLOAT16:
MoeFFNKernel<paddle::DataType::BFLOAT16>(permute_input,
Expand Down
4 changes: 4 additions & 0 deletions custom_ops/gpu_ops/moe/moe_reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ paddle::Tensor MoeExpertReduceFunc(

auto output = GetEmptyTensor({num_rows, hidden_size}, input_type, place);

if(num_rows == 0){
return output;
}

switch (input_type) {
case paddle::DataType::BFLOAT16:
MoeReduceKernel<paddle::DataType::BFLOAT16>(
Expand Down
24 changes: 15 additions & 9 deletions custom_ops/gpu_ops/text_image_gather_scatter.cu
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ __global__ void text_image_scatter_kernel(
constexpr int HalfVecSize = VecSize / 2;
using T_Vec = AlignedVector<T, VecSize>;
T_Vec input_ptr_vec;
T_Vec text_imgaes_vec;
T_Vec text_images_vec;

int64_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t step = blockDim.x * gridDim.x * VecSize;
Expand All @@ -76,16 +76,20 @@ __global__ void text_image_scatter_kernel(
Load<T, VecSize>(input_ptr + input_load_offset, &input_ptr_vec);
#pragma unroll
for(int vi = 0; vi < VecSize; ++vi) {
text_imgaes_vec[vi] = input_ptr_vec[vi];
text_images_vec[vi] = input_ptr_vec[vi];
}

if (token_type_ids_num == 0) {
int64_t text_load_offset = text_index[token_idx] * hidden_size + hidden_offset;
Store<T,VecSize>(text_imgaes_vec, text_gather_ptr + text_load_offset);
Store<T,VecSize>(text_images_vec, text_gather_ptr + text_load_offset);

} else {
} else if(token_type_ids_num == 1){
int64_t image_load_offset = image_index[token_idx] * hidden_size + hidden_offset;
Store<T,VecSize>(text_imgaes_vec, image_gather_ptr + image_load_offset);
Store<T,VecSize>(text_images_vec, image_gather_ptr + image_load_offset);

} else {
// skip cuda graph padding value
continue;
}
}
}
Expand Down Expand Up @@ -120,9 +124,12 @@ __global__ void text_image_gather_kernel(
int64_t text_load_offset = text_index[token_idx] * hidden_size + hidden_offset;
Load<T,VecSize>(text_gather_ptr + text_load_offset, &text_imgaes_vec);

} else {
} else if (token_type_ids_num == 1){
int64_t image_load_offset = image_index[token_idx] * hidden_size + hidden_offset;
Load<T,VecSize>(image_gather_ptr + image_load_offset, &text_imgaes_vec);
} else {
// skip cuda graph padding value
continue;
}

#pragma unroll
Expand Down Expand Up @@ -154,7 +161,6 @@ void LaunchTextImageGatherScatter(
const int64_t token_num = in_dims[0];
const int64_t hidden_size = in_dims[1];


const int VecSize = 16 / sizeof(data_t);
const int64_t tot_element_num = token_num * hidden_size;

Expand All @@ -168,7 +174,7 @@ void LaunchTextImageGatherScatter(
PADDLE_ENFORCE_GPU_SUCCESS(GetGridSize(tot_pack_num, block_size, kNumWaves, &grid_size_x));
dim3 grid_dim = dim3(grid_size_x, 1, 1);
if (is_scatter) {
text_image_scatter_kernel<DataType_, 8><<<grid_dim, block_size>>>(
text_image_scatter_kernel<DataType_, VecSize><<<grid_dim, block_size, 0, stream>>>(
reinterpret_cast<DataType_*>(input.data<data_t>()),
reinterpret_cast<DataType_*>(text_input.data<data_t>()),
reinterpret_cast<DataType_*>(image_input.data<data_t>()),
Expand All @@ -179,7 +185,7 @@ void LaunchTextImageGatherScatter(
tot_element_num
);
} else {
text_image_gather_kernel<DataType_, 8><<<grid_dim, block_size>>>(
text_image_gather_kernel<DataType_, VecSize><<<grid_dim, block_size, 0, stream>>>(
reinterpret_cast<DataType_*>(input.data<data_t>()),
reinterpret_cast<DataType_*>(text_input.data<data_t>()),
reinterpret_cast<DataType_*>(image_input.data<data_t>()),
Expand Down
20 changes: 12 additions & 8 deletions custom_ops/gpu_ops/text_image_index_out.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

template <int VecSize>
__global__ void text_image_index_out_kernel(
int32_t* token_type_ids,
const int32_t* token_type_ids,
int32_t* text_index,
int32_t* image_index,
const int64_t token_num
Expand All @@ -31,23 +31,27 @@ __global__ void text_image_index_out_kernel(
if (token_type_ids[i] == 0) {
text_index[i] = text_count;
text_count += 1;
} else {
} else if (token_type_ids[i] == 1) {
image_index[i] = images_count;
images_count += 1;
} else {
// skip cuda graph padding value
continue;
}
}
}

void TextImageIndexOut(
const paddle::Tensor& token_type_ids,
const paddle::Tensor& text_index,
const paddle::Tensor& image_index) {
paddle::Tensor& text_index,
paddle::Tensor& image_index) {

const int64_t token_num = token_type_ids.shape()[0];
text_image_index_out_kernel<1><<<1, 1>>>(
const_cast<int32_t*>(token_type_ids.data<int32_t>()),
const_cast<int32_t*>(text_index.data<int32_t>()),
const_cast<int32_t*>(image_index.data<int32_t>()),
auto stream = token_type_ids.stream();
text_image_index_out_kernel<1><<<1, 1, 0, stream>>>(
token_type_ids.data<int32_t>(),
text_index.data<int32_t>(),
image_index.data<int32_t>(),
token_num
);
}
Expand Down
32 changes: 32 additions & 0 deletions fastdeploy/model_executor/graph_optimization/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,35 @@ def clear_grpah_opt_backend(self, fd_config):
fd_config.graph_opt_config.graph_opt_level < 1
), "Currently unable to update weights in static graph mode."
self.graph_opt_backend.clear_cudagraph_piecewise_backend()


def mm_buffer(buffer_meta):
def decorator(cls):
original_init = cls.__init__

def __init__(self, fd_config: FDConfig, **kwargs):
original_init(self, fd_config=fd_config, **kwargs)

def _resolve_path(root, path: str):
cur = root
for p in path.split("."):
cur = getattr(cur, p)
return cur

if not hasattr(self, "_mm_buffers"):
self._mm_buffers = {}
for name, meta in buffer_meta.items():
shape = [_resolve_path(fd_config, s) if isinstance(s, str) else s for s in meta["shape"]]
dtype = meta["dtype"]
if "." in meta["dtype"]:
dtype = _resolve_path(fd_config, meta["dtype"])
self._mm_buffers[name] = paddle.full(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_mm_buffers这个变量名也改一下

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shape=shape,
dtype=dtype,
fill_value=meta.get("value", 0),
)

cls.__init__ = __init__
return cls

return decorator
Loading