11"""
2- # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+ # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
33#
44# Licensed under the Apache License, Version 2.0 (the "License");
55# you may not use this file except in compliance with the License.
1414# limitations under the License.
1515"""
1616
17+ import os
18+
1719import paddle
1820from paddle import nn
1921from paddle .nn .quant import weight_quantize
2022
21- import fastdeploy
2223from fastdeploy .distributed .communication import tensor_model_parallel_all_reduce
2324from fastdeploy .model_executor .layers .moe .fused_moe_backend_base import MoEMethodBase
25+ from fastdeploy .model_executor .layers .moe .moe import get_moe_scores
2426from fastdeploy .model_executor .layers .utils import get_tensor
25- from fastdeploy .model_executor .ops .gpu import fused_expert_moe
27+ from fastdeploy .model_executor .ops .gpu import (
28+ fused_expert_moe ,
29+ moe_expert_dispatch ,
30+ moe_expert_ffn ,
31+ moe_expert_reduce ,
32+ )
2633from fastdeploy .model_executor .utils import TensorTracker , free_tensor , set_weight_attrs
2734
2835
@@ -54,7 +61,7 @@ def compute_ffn(
5461 """
5562 Paddle Cutlass compute Fused MoE.
5663 """
57- return fastdeploy . model_executor . ops . gpu . moe_expert_ffn (
64+ return moe_expert_ffn (
5865 permute_input ,
5966 token_nums_per_expert ,
6067 getattr (layer , self .added_weight_attrs [0 ]),
@@ -96,23 +103,62 @@ def apply_tp(
96103 """
97104 Paddle Cutlass compute Fused MoE.
98105 """
106+ if layer .topk_method == "noaux_tc" :
107+ gate_out = gate (x .cast ("float32" ))
108+
109+ gate_out , topk_weights , topk_idx = get_moe_scores (
110+ gate_out ,
111+ layer .n_group ,
112+ layer .topk_group ,
113+ layer .top_k ,
114+ layer .routed_scaling_factor ,
115+ layer .gate_correction_bias ,
116+ getattr (layer , "renormalize" , True ),
117+ )
118+
119+ (
120+ permute_input ,
121+ token_nums_per_expert ,
122+ permute_indices_per_token ,
123+ topk_weights ,
124+ topk_idx ,
125+ ) = moe_expert_dispatch (
126+ x ,
127+ gate_out ,
128+ layer .top_k ,
129+ False ,
130+ True ,
131+ )
132+
133+ ffn_out = self .compute_ffn (layer , permute_input , token_nums_per_expert , None )
134+
135+ fused_moe_out = moe_expert_reduce (
136+ ffn_out ,
137+ topk_weights ,
138+ permute_indices_per_token ,
139+ topk_idx ,
140+ None ,
141+ False ,
142+ 1.0 ,
143+ )
144+ else :
145+ fused_moe_out = fused_expert_moe (
146+ x ,
147+ gate .weight ,
148+ getattr (layer , self .added_weight_attrs [0 ]),
149+ getattr (layer , self .added_weight_attrs [1 ]),
150+ None ,
151+ (layer .up_gate_proj_weight_scale if hasattr (layer , "up_gate_proj_weight_scale" ) else None ),
152+ None ,
153+ (layer .down_proj_weight_scale if hasattr (layer , "down_proj_weight_scale" ) else None ),
154+ "weight_only_int8" ,
155+ layer .top_k ,
156+ True ,
157+ False ,
158+ )
99159
100- fused_moe_out = fused_expert_moe (
101- x ,
102- gate .weight ,
103- getattr (layer , self .added_weight_attrs [0 ]),
104- getattr (layer , self .added_weight_attrs [1 ]),
105- None ,
106- (layer .up_gate_proj_weight_scale if hasattr (layer , "up_gate_proj_weight_scale" ) else None ),
107- None ,
108- (layer .down_proj_weight_scale if hasattr (layer , "down_proj_weight_scale" ) else None ),
109- "weight_only_int8" ,
110- layer .top_k ,
111- True ,
112- False ,
113- )
114160 if layer .reduce_results and layer .tp_size > 1 :
115- tensor_model_parallel_all_reduce (fused_moe_out , layer .fd_config .parallel_config .tp_group )
161+ fused_moe_out = tensor_model_parallel_all_reduce (fused_moe_out , layer .fd_config .parallel_config .tp_group )
116162
117163 return fused_moe_out
118164
@@ -122,15 +168,14 @@ class MetaxCutlassWeightOnlyMoEMethod(MetaxCutlassMoEMethod):
122168 weight only for moe
123169 """
124170
125- def __init__ (self , quant_config = None ):
126- """
127- weight only for moe
128- """
171+ def __init__ (self , quant_config ):
129172 super ().__init__ (quant_config )
130- # print(f"[DEBUG] quant_config: {quant_config}")
131173 self .quant_config = quant_config
132174 self .moe_quant_type = self .quant_config .algo
133175 self .pack_num = 1
176+ self .weight_only_linear_arch = os .getenv ("FLAGS_weight_only_linear_arch" )
177+ if self .weight_only_linear_arch is not None :
178+ self .weight_only_linear_arch = int (self .weight_only_linear_arch )
134179
135180 def process_prequanted_weights (self , layer : nn .Layer , state_dict , is_rearrange : bool = False ):
136181 """
@@ -200,20 +245,20 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
200245 ]
201246 self .up_gate_proj_scale_shape = [layer .num_local_experts , layer .moe_intermediate_size * 2 ]
202247 self .down_proj_scale_shape = [layer .num_local_experts , layer .hidden_size ]
203-
204- if layer .fd_config .load_config .load_choices == "default_v1" :
248+ # TODO(bukejiyu): remove v1 loader check when v0 loader is removed
249+ if self . quant_config . is_checkpoint_bf16 and layer .fd_config .load_config .load_choices == "default_v1" :
205250 layer .up_gate_proj_weight = layer .create_parameter (
206- shape = [layer .num_experts , layer .hidden_size , layer .moe_intermediate_size * 2 ],
251+ shape = [layer .num_local_experts , layer .hidden_size , layer .moe_intermediate_size * 2 ],
207252 dtype = layer .weight_dtype ,
208253 default_initializer = paddle .nn .initializer .Constant (0 ),
209254 )
210255
211256 layer .down_proj_weight = layer .create_parameter (
212- shape = [layer .num_experts , layer .moe_intermediate_size , layer .hidden_size ],
257+ shape = [layer .num_local_experts , layer .moe_intermediate_size , layer .hidden_size ],
213258 dtype = layer .weight_dtype ,
214259 default_initializer = paddle .nn .initializer .Constant (0 ),
215260 )
216-
261+ extra_weight_attrs [ "weight_need_transpose" ] = extra_weight_attrs . get ( "model_format" ) == "torch"
217262 set_weight_attrs (
218263 layer .up_gate_proj_weight ,
219264 {
@@ -273,7 +318,7 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
273318 default_initializer = paddle .nn .initializer .Constant (0 ),
274319 ),
275320 )
276-
321+ extra_weight_attrs [ "weight_need_transpose" ] = not extra_weight_attrs . get ( "model_format" ) == "torch"
277322 moe_extra_weight_attrs = {** extra_weight_attrs , "SHARD_ID_TO_SHARDED_DIM" : {"gate" : 0 , "down" : 1 , "up" : 0 }}
278323 set_weight_attrs (layer .up_gate_proj_weight , moe_extra_weight_attrs )
279324 set_weight_attrs (layer .down_proj_weight , moe_extra_weight_attrs )
@@ -286,7 +331,7 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
286331
287332 def process_weights_after_loading (self , layer ):
288333 """ """
289- if not layer . fd_config . load_config . load_choices == "default_v1" :
334+ if not self . quant_config . is_checkpoint_bf16 :
290335 return
291336 weight_id_map = {"gate_up" : 0 , "down" : 1 }
292337 if (
@@ -316,9 +361,11 @@ def process_weights_after_loading(self, layer):
316361
317362 # 3.quantize weight
318363
319- for expert_id in range (layer .num_experts ):
364+ for expert_id in range (layer .num_local_experts ):
320365 weight [expert_id ], scale [expert_id ] = weight_quantize (
321- getattr (layer , unquantized_weight_name )[expert_id ], algo = self .moe_quant_type , arch = 80 , group_size = - 1
366+ getattr (layer , unquantized_weight_name )[expert_id ],
367+ algo = self .moe_quant_type ,
368+ arch = self .weight_only_linear_arch ,
322369 )
323370
324371 free_tensor (getattr (layer , unquantized_weight_name ))
@@ -360,7 +407,7 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict):
360407 weight_scale_list = []
361408 for i in range (layer .num_local_experts ):
362409 quant_weight , scale = weight_quantize (
363- weight_tensor [i ], algo = self .moe_quant_type , arch = 80 , group_size = - 1
410+ weight_tensor [i ], algo = self .moe_quant_type , arch = self . weight_only_linear_arch
364411 )
365412 quant_weight = paddle .transpose (quant_weight , [1 , 0 ])
366413 weight_list .append (quant_weight )
0 commit comments