Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
const int gqa_group_size,
const float* q_norm_weight,
const float* k_norm_weight,
const float rms_norm_eps) {
const float rms_norm_eps,
const bool rope_3d) {
using LoadT = AlignedVector<T, VecSize>;
using LoadFloat = AlignedVector<float, VecSize>;
using LoadInT = AlignedVector<InT, VecSize>;
Expand Down Expand Up @@ -109,8 +110,9 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
if (hi < num_heads + gqa_group_size) {
// q k rope
const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
}
float thread_m2 = 0.0f;
float warp_m2 = 0.0f;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
const bool use_neox_style,
const float* q_norm_weight,
const float* k_norm_weight,
const float rms_norm_eps) {
const float rms_norm_eps,
const bool rope_3d) {
int output_inner_dim = num_heads + 2 * kv_num_heads;
const uint32_t elem_nums =
use_neox_style ? token_num * (num_heads + 2 * kv_num_heads) * dim_head / 2
Expand All @@ -53,7 +54,6 @@ void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
const int blocksize = 128;
int grid_size = 1;
GetNumBlocks<128>(pack_num, &grid_size);

if (use_neox_style) {
PD_THROW(
"append_speculate_cache_rope_qk_norm not support neox rope yet");
Expand Down Expand Up @@ -82,7 +82,8 @@ void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
kv_num_heads,
q_norm_weight,
k_norm_weight,
rms_norm_eps);
rms_norm_eps,
rope_3d);
}
}

Expand Down Expand Up @@ -426,7 +427,6 @@ void SpeculateWriteCacheWithRoPEKernel(
? rotary_embs.get().data<float>() + max_seq_len * dim_head
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
}

if (q_norm_weight && k_norm_weight) {
if (cache_quant_type_str == "none") {
append_speculate_cache_rope_qk_norm(
Expand Down Expand Up @@ -457,11 +457,13 @@ void SpeculateWriteCacheWithRoPEKernel(
use_neox_rotary_style,
reinterpret_cast<const float*>(q_norm_weight.get().data<float>()),
reinterpret_cast<const float*>(k_norm_weight.get().data<float>()),
rms_norm_eps);
rms_norm_eps,
rope_3d);
} else {
PD_THROW(
"append_decode_cache_rope_qk_norm not support cachekv quant yet");
}

} else {
if (cache_quant_type_str == "none") {
append_speculate_cache_rope(
Expand Down
Loading