|
| 1 | +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | +#pragma once |
| 15 | + |
| 16 | +#include "helper.h" |
| 17 | +#include "multiquery_decoder_attention_kernel.h" |
| 18 | +#include "utils.cuh" |
| 19 | + |
| 20 | +template <typename T> |
| 21 | +void DecodeMLAAttentionKernel( |
| 22 | + const AppendAttnMetaData &meta_data, |
| 23 | + const paddle::Tensor &q, // [token_num, num_heads, head_dim] |
| 24 | + const paddle::Tensor &cache_k, |
| 25 | + const paddle::Tensor &cache_v, |
| 26 | + const paddle::optional<paddle::Tensor> &attn_mask, |
| 27 | + const paddle::optional<paddle::Tensor> &shift_bias, |
| 28 | + const paddle::optional<paddle::Tensor> &smooth_weight, |
| 29 | + const paddle::Tensor &seq_lens_q, // q_seq_len is 1 |
| 30 | + const paddle::Tensor &seq_lens_kv, |
| 31 | + const paddle::Tensor &batch_id_per_token, |
| 32 | + const paddle::Tensor &cu_seqlens_q, |
| 33 | + const paddle::Tensor &block_table, |
| 34 | + int max_seq_len, |
| 35 | + int max_dec_len, |
| 36 | + float softmax_scale, |
| 37 | + float in_scale, |
| 38 | + bool causal, |
| 39 | + cudaStream_t &stream, |
| 40 | + paddle::Tensor *out) { |
| 41 | + const auto token_num = meta_data.token_nums; |
| 42 | + const auto block_size = meta_data.block_size; |
| 43 | + const auto bsz = meta_data.batch_size; |
| 44 | + const auto num_heads = meta_data.q_num_heads; |
| 45 | + const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads; |
| 46 | + const auto head_dim_qk = meta_data.head_dims; |
| 47 | + const auto head_dim_v = meta_data.head_dims_v; |
| 48 | + const float rope_scale = 0.0; |
| 49 | + const float rope_theta = 0.0; |
| 50 | + const uint32_t deal_each_time = get_cascade_attention_deal_each_time(); |
| 51 | + const uint32_t num_stage = get_cascade_attention_num_stages(); |
| 52 | + const uint32_t num_threads = get_cascade_attention_num_threads(); |
| 53 | + |
| 54 | + DISPATCH_CAUSAL( |
| 55 | + causal, |
| 56 | + CAUSAL, |
| 57 | + {DISPATCH_MLA_GROUP_SIZE( |
| 58 | + group_size, |
| 59 | + GROUP_SIZE, |
| 60 | + {DISPATCH_MLA_HEAD_DIM( |
| 61 | + head_dim_qk, |
| 62 | + HEAD_DIM_QK, |
| 63 | + {DISPATCH_MLA_HEAD_DIM( |
| 64 | + head_dim_v, |
| 65 | + HEAD_DIM_V, |
| 66 | + {DISPATCH_BLOCK_SIZE( |
| 67 | + block_size, |
| 68 | + BLOCK_SIZE, |
| 69 | + {DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME, { |
| 70 | + MultiQueryDecoderAttention<T, |
| 71 | + GROUP_SIZE, |
| 72 | + HEAD_DIM_QK, |
| 73 | + HEAD_DIM_V, |
| 74 | + BLOCK_SIZE, |
| 75 | + CAUSAL, |
| 76 | + 2, |
| 77 | + 16, |
| 78 | + DEAL_EACH_TIME>( |
| 79 | + meta_data, |
| 80 | + stream, |
| 81 | + q, |
| 82 | + cache_k, |
| 83 | + cache_v, |
| 84 | + attn_mask, |
| 85 | + shift_bias, |
| 86 | + smooth_weight, |
| 87 | + seq_lens_q, |
| 88 | + seq_lens_kv, |
| 89 | + batch_id_per_token, |
| 90 | + cu_seqlens_q, |
| 91 | + block_table, |
| 92 | + max_seq_len, |
| 93 | + max_dec_len, |
| 94 | + rope_scale, |
| 95 | + rope_theta, |
| 96 | + softmax_scale, |
| 97 | + in_scale, |
| 98 | + out); |
| 99 | + })})})})})}); |
| 100 | +} |
| 101 | + |
| 102 | +template void DecodeMLAAttentionKernel<paddle::bfloat16>( |
| 103 | + const AppendAttnMetaData &meta_data, |
| 104 | + const paddle::Tensor &q, // [token_num, num_heads, head_dim] |
| 105 | + const paddle::Tensor &cache_k, |
| 106 | + const paddle::Tensor &cache_v, |
| 107 | + const paddle::optional<paddle::Tensor> &attn_mask, |
| 108 | + const paddle::optional<paddle::Tensor> &shift_bias, |
| 109 | + const paddle::optional<paddle::Tensor> &smooth_weight, |
| 110 | + const paddle::Tensor &seq_lens_q, // q_seq_len is 1 |
| 111 | + const paddle::Tensor &seq_lens_kv, |
| 112 | + const paddle::Tensor &batch_id_per_token, |
| 113 | + const paddle::Tensor &cu_seqlens_q, |
| 114 | + const paddle::Tensor &block_table, |
| 115 | + int max_seq_len, |
| 116 | + int max_dec_len, |
| 117 | + float softmax_scale, |
| 118 | + float in_scale, |
| 119 | + bool causal, |
| 120 | + cudaStream_t &stream, |
| 121 | + paddle::Tensor *out); |
| 122 | + |
| 123 | +template void DecodeMLAAttentionKernel<paddle::float16>( |
| 124 | + const AppendAttnMetaData &meta_data, |
| 125 | + const paddle::Tensor &q, // [token_num, num_heads, head_dim] |
| 126 | + const paddle::Tensor &cache_k, |
| 127 | + const paddle::Tensor &cache_v, |
| 128 | + const paddle::optional<paddle::Tensor> &attn_mask, |
| 129 | + const paddle::optional<paddle::Tensor> &shift_bias, |
| 130 | + const paddle::optional<paddle::Tensor> &smooth_weight, |
| 131 | + const paddle::Tensor &seq_lens_q, // q_seq_len is 1 |
| 132 | + const paddle::Tensor &seq_lens_kv, |
| 133 | + const paddle::Tensor &batch_id_per_token, |
| 134 | + const paddle::Tensor &cu_seqlens_q, |
| 135 | + const paddle::Tensor &block_table, |
| 136 | + int max_seq_len, |
| 137 | + int max_dec_len, |
| 138 | + float softmax_scale, |
| 139 | + float in_scale, |
| 140 | + bool causal, |
| 141 | + cudaStream_t &stream, |
| 142 | + paddle::Tensor *out); |
0 commit comments