Skip to content

Commit d69f7d3

Browse files
committed
[Metax] adapt cutlass moe for DeepSeek
1 parent acd3317 commit d69f7d3

File tree

5 files changed

+253
-135
lines changed

5 files changed

+253
-135
lines changed

custom_ops/metax_ops/mc_fused_moe_helper.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
115
#include "fused_moe_helper.h"
216
#include "mctlass/numeric_conversion.h"
317
#include "mctlassEx/mctlassEx.h"

custom_ops/setup_ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,8 @@ def find_end_files(directory, end_str):
607607
"gpu_ops/text_image_gather_scatter.cu",
608608
"gpu_ops/text_image_index_out.cu",
609609
"gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
610+
"gpu_ops/limit_thinking_content_length_v1.cu",
611+
"gpu_ops/limit_thinking_content_length_v2.cu",
610612
"gpu_ops/append_attn/mla_cache_kernel.cu",
611613
"gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu",
612614
"gpu_ops/moe/tritonmoe_preprocess.cu",

fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_cutlass_metax_backend.py

Lines changed: 81 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.
@@ -14,15 +14,22 @@
1414
# limitations under the License.
1515
"""
1616

17+
import os
18+
1719
import paddle
1820
from paddle import nn
1921
from paddle.nn.quant import weight_quantize
2022

21-
import fastdeploy
2223
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
2324
from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase
25+
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
2426
from fastdeploy.model_executor.layers.utils import get_tensor
25-
from fastdeploy.model_executor.ops.gpu import fused_expert_moe
27+
from fastdeploy.model_executor.ops.gpu import (
28+
fused_expert_moe,
29+
moe_expert_dispatch,
30+
moe_expert_ffn,
31+
moe_expert_reduce,
32+
)
2633
from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs
2734

2835

@@ -54,7 +61,7 @@ def compute_ffn(
5461
"""
5562
Paddle Cutlass compute Fused MoE.
5663
"""
57-
return fastdeploy.model_executor.ops.gpu.moe_expert_ffn(
64+
return moe_expert_ffn(
5865
permute_input,
5966
token_nums_per_expert,
6067
getattr(layer, self.added_weight_attrs[0]),
@@ -96,23 +103,62 @@ def apply_tp(
96103
"""
97104
Paddle Cutlass compute Fused MoE.
98105
"""
106+
if layer.topk_method == "noaux_tc":
107+
gate_out = gate(x.cast("float32"))
108+
109+
gate_out, topk_weights, topk_idx = get_moe_scores(
110+
gate_out,
111+
layer.n_group,
112+
layer.topk_group,
113+
layer.top_k,
114+
layer.routed_scaling_factor,
115+
layer.gate_correction_bias,
116+
getattr(layer, "renormalize", True),
117+
)
118+
119+
(
120+
permute_input,
121+
token_nums_per_expert,
122+
permute_indices_per_token,
123+
topk_weights,
124+
topk_idx,
125+
) = moe_expert_dispatch(
126+
x,
127+
gate_out,
128+
layer.top_k,
129+
False,
130+
True,
131+
)
132+
133+
ffn_out = self.compute_ffn(layer, permute_input, token_nums_per_expert, None)
134+
135+
fused_moe_out = moe_expert_reduce(
136+
ffn_out,
137+
topk_weights,
138+
permute_indices_per_token,
139+
topk_idx,
140+
None,
141+
False,
142+
1.0,
143+
)
144+
else:
145+
fused_moe_out = fused_expert_moe(
146+
x,
147+
gate.weight,
148+
getattr(layer, self.added_weight_attrs[0]),
149+
getattr(layer, self.added_weight_attrs[1]),
150+
None,
151+
(layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
152+
None,
153+
(layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
154+
"weight_only_int8",
155+
layer.top_k,
156+
True,
157+
False,
158+
)
99159

100-
fused_moe_out = fused_expert_moe(
101-
x,
102-
gate.weight,
103-
getattr(layer, self.added_weight_attrs[0]),
104-
getattr(layer, self.added_weight_attrs[1]),
105-
None,
106-
(layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
107-
None,
108-
(layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
109-
"weight_only_int8",
110-
layer.top_k,
111-
True,
112-
False,
113-
)
114160
if layer.reduce_results and layer.tp_size > 1:
115-
tensor_model_parallel_all_reduce(fused_moe_out, layer.fd_config.parallel_config.tp_group)
161+
fused_moe_out = tensor_model_parallel_all_reduce(fused_moe_out, layer.fd_config.parallel_config.tp_group)
116162

117163
return fused_moe_out
118164

@@ -122,15 +168,14 @@ class MetaxCutlassWeightOnlyMoEMethod(MetaxCutlassMoEMethod):
122168
weight only for moe
123169
"""
124170

125-
def __init__(self, quant_config=None):
126-
"""
127-
weight only for moe
128-
"""
171+
def __init__(self, quant_config):
129172
super().__init__(quant_config)
130-
# print(f"[DEBUG] quant_config: {quant_config}")
131173
self.quant_config = quant_config
132174
self.moe_quant_type = self.quant_config.algo
133175
self.pack_num = 1
176+
self.weight_only_linear_arch = os.getenv("FLAGS_weight_only_linear_arch")
177+
if self.weight_only_linear_arch is not None:
178+
self.weight_only_linear_arch = int(self.weight_only_linear_arch)
134179

135180
def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
136181
"""
@@ -200,20 +245,20 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
200245
]
201246
self.up_gate_proj_scale_shape = [layer.num_local_experts, layer.moe_intermediate_size * 2]
202247
self.down_proj_scale_shape = [layer.num_local_experts, layer.hidden_size]
203-
204-
if layer.fd_config.load_config.load_choices == "default_v1":
248+
# TODO(bukejiyu): remove v1 loader check when v0 loader is removed
249+
if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
205250
layer.up_gate_proj_weight = layer.create_parameter(
206-
shape=[layer.num_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
251+
shape=[layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
207252
dtype=layer.weight_dtype,
208253
default_initializer=paddle.nn.initializer.Constant(0),
209254
)
210255

211256
layer.down_proj_weight = layer.create_parameter(
212-
shape=[layer.num_experts, layer.moe_intermediate_size, layer.hidden_size],
257+
shape=[layer.num_local_experts, layer.moe_intermediate_size, layer.hidden_size],
213258
dtype=layer.weight_dtype,
214259
default_initializer=paddle.nn.initializer.Constant(0),
215260
)
216-
261+
extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
217262
set_weight_attrs(
218263
layer.up_gate_proj_weight,
219264
{
@@ -273,7 +318,7 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
273318
default_initializer=paddle.nn.initializer.Constant(0),
274319
),
275320
)
276-
321+
extra_weight_attrs["weight_need_transpose"] = not extra_weight_attrs.get("model_format") == "torch"
277322
moe_extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}
278323
set_weight_attrs(layer.up_gate_proj_weight, moe_extra_weight_attrs)
279324
set_weight_attrs(layer.down_proj_weight, moe_extra_weight_attrs)
@@ -286,7 +331,7 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
286331

287332
def process_weights_after_loading(self, layer):
288333
""" """
289-
if not layer.fd_config.load_config.load_choices == "default_v1":
334+
if not self.quant_config.is_checkpoint_bf16:
290335
return
291336
weight_id_map = {"gate_up": 0, "down": 1}
292337
if (
@@ -316,9 +361,11 @@ def process_weights_after_loading(self, layer):
316361

317362
# 3.quantize weight
318363

319-
for expert_id in range(layer.num_experts):
364+
for expert_id in range(layer.num_local_experts):
320365
weight[expert_id], scale[expert_id] = weight_quantize(
321-
getattr(layer, unquantized_weight_name)[expert_id], algo=self.moe_quant_type, arch=80, group_size=-1
366+
getattr(layer, unquantized_weight_name)[expert_id],
367+
algo=self.moe_quant_type,
368+
arch=self.weight_only_linear_arch,
322369
)
323370

324371
free_tensor(getattr(layer, unquantized_weight_name))
@@ -360,7 +407,7 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict):
360407
weight_scale_list = []
361408
for i in range(layer.num_local_experts):
362409
quant_weight, scale = weight_quantize(
363-
weight_tensor[i], algo=self.moe_quant_type, arch=80, group_size=-1
410+
weight_tensor[i], algo=self.moe_quant_type, arch=self.weight_only_linear_arch
364411
)
365412
quant_weight = paddle.transpose(quant_weight, [1, 0])
366413
weight_list.append(quant_weight)

fastdeploy/model_executor/pre_and_post_process.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,12 @@
5050
elif current_platform.is_maca():
5151
from fastdeploy.model_executor.ops.gpu import (
5252
get_padding_offset,
53+
limit_thinking_content_length_v1,
54+
limit_thinking_content_length_v2,
5355
save_output,
5456
set_stop_value_multi_ends,
57+
speculate_limit_thinking_content_length_v1,
58+
speculate_limit_thinking_content_length_v2,
5559
step_paddle,
5660
update_inputs,
5761
update_inputs_v1,
@@ -770,7 +774,9 @@ def rebuild_padding(
770774
seq_lens_decoder,
771775
seq_lens_encoder,
772776
output_padding_offset,
777+
first_token_out,
773778
max_input_length,
779+
enable_logprob,
774780
)
775781
else:
776782
raise RuntimeError("Not supported platform")

0 commit comments

Comments (0)