Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 36 additions & 15 deletions custom_ops/gpu_ops/append_attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ void AppendAttentionKernel(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& out_linear_shifts,
const paddle::optional<paddle::Tensor>& out_linear_smooths,
const paddle::optional<paddle::Tensor>& mask_offset,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const paddle::optional<paddle::Tensor>& sinks,
const float rms_norm_eps,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
Expand All @@ -90,7 +90,8 @@ void AppendAttentionKernel(
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
const bool speculate_decoder,
const int sliding_window) {
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
Expand Down Expand Up @@ -146,6 +147,7 @@ void AppendAttentionKernel(
cache_v_zp,
out_linear_shifts,
out_linear_smooths,
sinks,
seq_lens_this_time,
seq_lens_decoder,
seq_lens_encoder,
Expand All @@ -169,7 +171,8 @@ void AppendAttentionKernel(
lambda_is_decoder,
lambda_enable_prefill,
lambda_stream,
&fmha_out);
&fmha_out,
sliding_window);
};

if (max_enc_len_this_time > 0) {
Expand Down Expand Up @@ -428,6 +431,7 @@ std::vector<paddle::Tensor> AppendAttention(
const paddle::optional<paddle::Tensor>& kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const paddle::optional<paddle::Tensor>& sinks,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
Expand All @@ -443,7 +447,8 @@ std::vector<paddle::Tensor> AppendAttention(
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
const bool speculate_decoder,
const int sliding_window) {
AppendAttnMetaData meta_data;

const auto& qkv_dims = qkv.dims();
Expand Down Expand Up @@ -550,10 +555,10 @@ std::vector<paddle::Tensor> AppendAttention(
cache_v_zp,
out_linear_shifts,
out_linear_smooths,
mask_offset,
kv_signal_data,
q_norm_weight,
k_norm_weight,
sinks,
rms_norm_eps,
cache_quant_type_str,
use_neox_rotary_style,
Expand All @@ -568,7 +573,8 @@ std::vector<paddle::Tensor> AppendAttention(
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
speculate_decoder);
speculate_decoder,
sliding_window);
};


Expand Down Expand Up @@ -630,6 +636,7 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
const paddle::optional<paddle::Tensor>& kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const paddle::optional<paddle::Tensor>& sinks,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
Expand All @@ -645,7 +652,8 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
const bool speculate_decoder,
const int sliding_window) {
AppendAttnMetaData meta_data;

const auto& qkv_dims = qkv.dims();
Expand Down Expand Up @@ -704,10 +712,10 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
cache_v_zp,
out_linear_shifts,
out_linear_smooths,
mask_offset,
kv_signal_data,
q_norm_weight,
k_norm_weight,
sinks,
rms_norm_eps,
cache_quant_type_str,
use_neox_rotary_style,
Expand All @@ -722,7 +730,8 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
speculate_decoder);
speculate_decoder,
sliding_window);
};

phi::dtype::float16 fp16_dtype;
Expand Down Expand Up @@ -797,6 +806,7 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
const paddle::optional<std::vector<int64_t>>& sinks_shape,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
Expand All @@ -812,7 +822,8 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
const bool speculate_decoder,
const int sliding_window) {
const int token_num = qkv_shape[0];
const int kv_num_heads = key_cache_shape[1];
int head_dim = key_cache_shape[3];
Expand Down Expand Up @@ -860,6 +871,7 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
const paddle::optional<paddle::DataType>& sinks_dtype,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
Expand All @@ -875,7 +887,8 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
const bool speculate_decoder,
const int sliding_window) {
if (compute_dtype == "bf16") {
if (out_linear_in_scale > 0.0) {
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
Expand Down Expand Up @@ -942,6 +955,7 @@ std::vector<std::vector<int64_t>> AppendAttentionWithOutputInferShape(
const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
const paddle::optional<std::vector<int64_t>>& sinks_shape,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
Expand All @@ -957,7 +971,8 @@ std::vector<std::vector<int64_t>> AppendAttentionWithOutputInferShape(
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
const bool speculate_decoder,
const int sliding_window) {
return {fmha_out_shape};
}

Expand Down Expand Up @@ -998,6 +1013,7 @@ std::vector<paddle::DataType> AppendAttentionWithOutputInferDtype(
const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
const paddle::optional<paddle::DataType>& sinks_dtype,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
Expand All @@ -1013,7 +1029,8 @@ std::vector<paddle::DataType> AppendAttentionWithOutputInferDtype(
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
const bool speculate_decoder,
const int sliding_window) {
return {fmha_out_dtype};
}

Expand Down Expand Up @@ -1054,7 +1071,8 @@ PD_BUILD_STATIC_OP(append_attention)
paddle::Optional("mask_offset"),
paddle::Optional("kv_signal_data"),
paddle::Optional("q_norm_weight"),
paddle::Optional("k_norm_weight")})
paddle::Optional("k_norm_weight"),
paddle::Optional("sinks")})
.Outputs({"fmha_out"})
.Attrs({"rms_norm_eps: float",
"compute_type: std::string",
Expand All @@ -1072,6 +1090,7 @@ PD_BUILD_STATIC_OP(append_attention)
"speculate_max_draft_token_num: int",
"causal: bool",
"speculate_decoder: bool",
"sliding_window: int",
})
.SetKernelFn(PD_KERNEL(AppendAttention))
.SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape))
Expand Down Expand Up @@ -1113,7 +1132,8 @@ PD_BUILD_STATIC_OP(append_attention_with_output)
paddle::Optional("mask_offset"),
paddle::Optional("kv_signal_data"),
paddle::Optional("q_norm_weight"),
paddle::Optional("k_norm_weight")})
paddle::Optional("k_norm_weight"),
paddle::Optional("sinks")})
.Outputs({"fmha_out_out"})
.SetInplaceMap({{"fmha_out", "fmha_out_out"}})
.Attrs({"rms_norm_eps: float",
Expand All @@ -1132,6 +1152,7 @@ PD_BUILD_STATIC_OP(append_attention_with_output)
"speculate_max_draft_token_num: int",
"causal: bool",
"speculate_decoder: bool",
"sliding_window: int",
})
.SetKernelFn(PD_KERNEL(AppendAttentionWithOutput))
.SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionWithOutputInferShape))
Expand Down
39 changes: 31 additions & 8 deletions custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ void CascadeAppendAttentionC16Kernel(
shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
sinks, // [num_heads]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
Expand All @@ -58,7 +60,8 @@ void CascadeAppendAttentionC16Kernel(
const bool is_decoder,
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out) {
paddle::Tensor* out,
const int sliding_window) {
const auto token_num = meta_data.token_nums;
const auto block_size = meta_data.block_size;
const auto bsz = meta_data.batch_size;
Expand Down Expand Up @@ -99,6 +102,7 @@ void CascadeAppendAttentionC16Kernel(
attn_mask,
shift_bias,
smooth_weight,
sinks,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
Expand All @@ -118,7 +122,8 @@ void CascadeAppendAttentionC16Kernel(
speculate_max_draft_token_num,
is_decoder,
stream,
out);
out,
sliding_window);
})})})})})})
}

Expand All @@ -142,6 +147,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::bfloat16
shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
sinks, // [num_heads]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
Expand All @@ -164,7 +171,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::bfloat16
const bool is_decoder,
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out);
paddle::Tensor* out,
const int sliding_window);

template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::float8_e4m3fn>(
const AppendAttnMetaData& meta_data,
Expand All @@ -186,6 +194,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::float8_e
shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
sinks, // [num_heads]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
Expand All @@ -208,7 +218,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::float8_e
const bool is_decoder,
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out);
paddle::Tensor* out,
const int sliding_window);

template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, int8_t>(
const AppendAttnMetaData& meta_data,
Expand All @@ -230,6 +241,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, int8_t>(
shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
sinks, // [num_heads]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
Expand All @@ -252,7 +265,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, int8_t>(
const bool is_decoder,
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out);
paddle::Tensor* out,
const int sliding_window);

template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float16>(
const AppendAttnMetaData& meta_data,
Expand All @@ -274,6 +288,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float16>(
shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
sinks, // [num_heads]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
Expand All @@ -296,7 +312,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float16>(
const bool is_decoder,
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out);
paddle::Tensor* out,
const int sliding_window);

template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float8_e4m3fn>(
const AppendAttnMetaData& meta_data,
Expand All @@ -318,6 +335,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float8_e4
shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
sinks, // [num_heads]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
Expand All @@ -340,7 +359,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float8_e4
const bool is_decoder,
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out);
paddle::Tensor* out,
const int sliding_window);

template void CascadeAppendAttentionC16Kernel<paddle::float16, int8_t>(
const AppendAttnMetaData& meta_data,
Expand All @@ -362,6 +382,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, int8_t>(
shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
sinks, // [num_heads]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
Expand All @@ -384,4 +406,5 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, int8_t>(
const bool is_decoder,
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out);
paddle::Tensor* out,
const int sliding_window);
Loading
Loading