Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions custom_ops/metax_ops/mc_fused_moe_helper.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "fused_moe_helper.h"
#include "mctlass/numeric_conversion.h"
#include "mctlassEx/mctlassEx.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,70 @@ def get_kv_cache_shape(
self.kv_lora_rank + self.qk_rope_head_dim,
)

def compute_flash_mla(
    self,
    query: paddle.Tensor,
    latent_cache: paddle.Tensor,
    forward_meta: ForwardMeta,
) -> paddle.Tensor:
    """Run FlashMLA attention over the paged latent KV cache.

    Args:
        query: Packed query tokens for all active requests, concatenated
            along dim 0; last dim is kv_lora_rank + qk_rope_head_dim per head.
        latent_cache: Paged latent KV cache; transposed below so its layout
            matches what ``flash_mla_with_kvcache`` expects
            (axis 2 is then the KV-head axis).
        forward_meta: Per-step metadata carrying ``seq_lens_decoder`` and
            ``seq_lens_this_time``.

    Returns:
        Attention output of shape ``[num_tokens, num_heads, kv_lora_rank]``.

    Raises:
        ValueError: if ``latent_cache`` is None.
        RuntimeError: if ``get_mla_metadata`` yields empty scheduler metadata.
    """
    from flash_mla_paddle import flash_mla_with_kvcache, get_mla_metadata

    if latent_cache is None:
        raise ValueError("latent_cache must not be None for FlashMLA attention")

    # Keep only the requests that contribute at least one token this step.
    seq_lens_decoder = forward_meta.seq_lens_decoder.squeeze(-1)
    seq_lens_this_time = forward_meta.seq_lens_this_time.squeeze(-1)
    non_zero_index = paddle.nonzero(seq_lens_this_time).flatten()
    seq_lens_decoder = seq_lens_decoder[non_zero_index]
    seq_lens_this_time = seq_lens_this_time[non_zero_index]

    latent_cache = latent_cache.transpose([0, 2, 1, 3])
    block_tables = self.attention_metadata.block_tables[non_zero_index]
    seq_lens = seq_lens_decoder + seq_lens_this_time
    batch_size = block_tables.shape[0]
    seq_len_q = seq_lens_this_time.max()
    num_heads_q = self.num_heads
    num_heads_kv = latent_cache.shape[2]
    head_dim_v = self.kv_lora_rank
    head_dim_qk = self.kv_lora_rank + self.qk_rope_head_dim

    # Ragged batch: right-pad every request's queries up to seq_len_q so the
    # packed tokens can be viewed as a dense [batch, seq_len_q, ...] tensor.
    is_ragged = seq_len_q != seq_lens_this_time.min()
    if is_ragged:
        chunks = query.split(list(seq_lens_this_time), axis=0)
        chunks = [
            paddle.concat(
                [
                    chunk,
                    # dtype must match the query, otherwise concat fails (or
                    # silently upcasts) for fp16/bf16 queries.
                    paddle.zeros((seq_len_q - chunk.shape[0], chunk.shape[1]), dtype=chunk.dtype),
                ],
                axis=0,
            )
            for chunk in chunks
        ]
        query = paddle.to_tensor(chunks)

    query = query.reshape([batch_size, seq_len_q, num_heads_q, head_dim_qk])

    tile_scheduler_metadata, num_splits = get_mla_metadata(
        seq_lens, seq_len_q * num_heads_q // num_heads_kv, num_heads_kv
    )

    if tile_scheduler_metadata.shape[0] == 0:
        # A bare assert would vanish under `python -O`; fail loudly with the
        # inputs that produced the empty schedule instead of print+assert.
        raise RuntimeError(
            "get_mla_metadata returned empty tile scheduler metadata: "
            f"seq_lens={seq_lens}, seq_len_q={seq_len_q}, "
            f"num_heads_q={num_heads_q}, num_heads_kv={num_heads_kv}"
        )

    out = flash_mla_with_kvcache(
        query,
        latent_cache,
        block_tables,
        seq_lens,
        head_dim_v,
        tile_scheduler_metadata,
        num_splits,
        softmax_scale=self.attn_softmax_scale,
        causal=self.causal,
    )[0]

    if is_ragged:
        # Strip the padding rows appended above, restoring the packed layout.
        trimmed = [xi.split([num, seq_len_q - num], axis=0)[0] for xi, num in zip(out, seq_lens_this_time)]
        out = paddle.concat(trimmed, axis=0)
    else:
        out = out.reshape([-1, num_heads_q, head_dim_v])

    return out

def forward_extend(
self,
q: paddle.Tensor,
Expand Down Expand Up @@ -258,81 +322,6 @@ def forward_extend(

return fmha_out

def _run_single_flash_mla(self, query, latent_cache, block_tables, seq_lens, draft_token_num):
    """Invoke the FlashMLA kernel for a batch with a uniform query length.

    Every request contributes exactly ``draft_token_num`` query tokens, so
    the packed query reshapes to a dense [batch, draft_token_num, heads, dim]
    tensor without any padding.
    """
    from flash_mla_paddle import flash_mla_with_kvcache, get_mla_metadata

    num_q_heads = self.num_heads
    num_kv_heads = latent_cache.shape[2]
    dim_qk = self.kv_lora_rank + self.qk_rope_head_dim
    dim_v = self.kv_lora_rank

    batched_query = query.reshape([-1, draft_token_num, num_q_heads, dim_qk])
    sched_metadata, num_splits = get_mla_metadata(
        seq_lens, draft_token_num * num_q_heads // num_kv_heads, num_kv_heads
    )

    attn_out = flash_mla_with_kvcache(
        batched_query,
        latent_cache,
        block_tables,
        seq_lens,
        dim_v,
        sched_metadata,
        num_splits,
        softmax_scale=self.attn_softmax_scale,
        causal=True,
    )[0]

    # Flatten back to the packed [num_tokens, heads, dim_v] layout.
    return attn_out.reshape([-1, num_q_heads, dim_v])

def compute_flash_mla(self, query, latent_cache, forward_meta):
    """Decode-path MLA attention dispatcher for the FlashMLA kernel.

    FlashMLA expects every request in the batch to have the same number of
    query tokens. When token counts are uniform the kernel is called
    directly; otherwise (speculative decoding) the packed queries are
    scattered into a zero-padded dense buffer first, and only the rows that
    correspond to real tokens are gathered back from the result.
    """
    block_tables = self.attention_metadata.block_tables
    seq_lens_decoder = forward_meta.seq_lens_decoder
    seq_lens_this_time = forward_meta.seq_lens_this_time
    assert block_tables is not None and seq_lens_decoder is not None and seq_lens_this_time is not None
    assert block_tables.shape[0] == seq_lens_decoder.shape[0]

    query = query.reshape([-1, self.num_heads, self.kv_lora_rank + self.qk_rope_head_dim])
    latent_cache = latent_cache.transpose([0, 2, 1, 3])

    # Keep only the requests that contribute at least one token this step.
    seq_lens_decoder = seq_lens_decoder.squeeze(-1)
    seq_lens_this_time = seq_lens_this_time.squeeze(-1)
    non_zero_index = paddle.nonzero(seq_lens_this_time).flatten()
    seq_lens_decoder = seq_lens_decoder[non_zero_index]
    seq_lens_this_time = seq_lens_this_time[non_zero_index]
    block_tables = block_tables[non_zero_index]

    max_seq_lens_this_time = seq_lens_this_time.max().item()
    min_seq_lens_this_time = seq_lens_this_time.min().item()

    if max_seq_lens_this_time == min_seq_lens_this_time:
        # Uniform query length across the batch: no padding needed.
        return self._run_single_flash_mla(
            query, latent_cache, block_tables, seq_lens_decoder + seq_lens_this_time, max_seq_lens_this_time
        )
    else:
        # Ragged batch: reserve max_draft_token_num query slots per request
        # (+1 covers the base token on top of the drafted tokens).
        max_draft_token_num = self.speculate_max_draft_token_num + 1
        seq_lens_this_time_cpu = seq_lens_this_time.cpu()
        bsz = seq_lens_this_time_cpu.shape[0]
        qk_head_dim = self.kv_lora_rank + self.qk_rope_head_dim
        # Dense zero-padded query buffer the kernel will consume.
        batched_query = paddle.zeros(
            [bsz * max_draft_token_num, self.num_heads, qk_head_dim], dtype=query.dtype
        ).to(query.place)
        full_token_index = paddle.arange(bsz * max_draft_token_num, dtype="int32").reshape(
            [bsz, max_draft_token_num]
        )
        # token_mapping_index[i] = flattened slot in batched_query that holds
        # the i-th real (non-padding) token.
        token_mapping_index = []
        for group_id in range(bsz):
            seq_len = seq_lens_this_time_cpu[group_id]
            token_mapping_index.append(full_token_index[group_id, :seq_len])
        token_mapping_index = paddle.concat(token_mapping_index)
        assert token_mapping_index.shape[0] == query.shape[0]
        # Scatter real tokens into their padded slots.
        batched_query[token_mapping_index] = query
        # NOTE(review): after padding, every request is reported to the kernel
        # as having max_draft_token_num query tokens; the padded rows are
        # discarded by the gather below.
        seq_lens_this_time = paddle.full_like(seq_lens_this_time, fill_value=max_draft_token_num)
        out = self._run_single_flash_mla(
            batched_query, latent_cache, block_tables, seq_lens_decoder + seq_lens_this_time, max_draft_token_num
        )
        # Gather back only the rows for real tokens (packed layout).
        return out[token_mapping_index]

def forward_decode(
self,
q: paddle.Tensor,
Expand Down Expand Up @@ -388,57 +377,7 @@ def forward_mixed(
"""
Mixed模式的前向传播
"""
metadata = self.attention_metadata
speculate_decoder = self.speculative_method is not None

latent_cache = forward_meta.caches[layer.layer_id] if hasattr(forward_meta, "caches") else None

if k is not None:
prefill_mla_write_cache(
compressed_kv,
k_pe,
latent_cache,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
"none",
self.max_seq_len,
)

# FA
fmha_out = self.flash_attn_func(
q,
k,
v,
forward_meta.cu_seqlens_q,
forward_meta.cu_seqlens_k,
metadata.max_enc_len_this_time,
metadata.max_enc_len_this_time,
causal=self.causal,
**self.flash_attn_kwargs,
)[0]

return fmha_out

# Decode
if k is None:
decode_mla_write_cache(
compressed_kv,
k_pe,
latent_cache,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_encoder,
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
"none",
self.max_seq_len,
speculate_decoder,
)

# 多头潜在注意力计算
fmha_out = self.compute_flash_mla(q, latent_cache, forward_meta)

return fmha_out
return self.forward_extend(q, k, v, qkv, compressed_kv, k_pe, layer, forward_meta)
else:
return self.forward_decode(q, k, v, qkv, compressed_kv, k_pe, layer, forward_meta)
Loading
Loading