import json
import os
import shutil
import unittest

import paddle
from paddle.distributed import fleet

from fastdeploy.config import (
    CacheConfig,
    FDConfig,
    GraphOptimizationConfig,
    LoadConfig,
    ModelConfig,
    ParallelConfig,
)
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
from fastdeploy.model_executor.layers.quantization.w4a8 import W4A8Config
from fastdeploy.scheduler import SchedulerConfig
from fastdeploy.worker.worker_process import init_distributed_environment
from tests.utils import OpPerformanceTester

paddle.set_default_dtype("bfloat16")
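

# This test builds a FusedMoE layer with randomly generated w4a8 parameters
# (int4 weights packed into int8, int8 activations) and benchmarks its decode
# throughput with OpPerformanceTester; no real checkpoint is loaded.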
class FuseMoEWrapper(paddle.nn.Layer):
    def __init__(
        self,
        model_config: ModelConfig,
        tp_size: int = 1,
        tp_rank: int = 0,
        ep_size: int = 1,
        ep_rank: int = 0,
        prefix: str = "layer0",
        nnodes: int = 1,
    ):
        super().__init__()
        self.model_config = model_config

        self.tp_size = tp_size
        self.ep_size = ep_size
        self.ep_rank = ep_rank

        self.prefix = prefix
        self.fd_config = FDConfig(
            model_config=self.model_config,
            parallel_config=ParallelConfig(
                {
                    "tensor_parallel_size": self.tp_size,
                    "expert_parallel_size": self.ep_size,
                    "expert_parallel_rank": self.ep_rank,
                    "data_parallel_size": self.ep_size,
                }
            ),
            quant_config=W4A8Config(is_permuted=False, hadamard_block_size=128),
            # quant_config=W4AFP8Config(weight_scale_dict=None, act_scale_dict=None, is_permuted=False, hadamard_block_size=128),
            scheduler_config=SchedulerConfig({}),
            cache_config=CacheConfig({}),
            graph_opt_config=GraphOptimizationConfig({}),
            load_config=LoadConfig({}),
            ips=",".join(["0"] * nnodes),
        )
        self.fd_config.parallel_config.tp_group = None
        self.fd_config.parallel_config.tensor_parallel_rank = tp_rank
        self.fd_config.parallel_config.expert_parallel_size = self.ep_size
        if self.ep_size > 1:
            self.fd_config.parallel_config.ep_group = fleet.get_hybrid_communicate_group().get_model_parallel_group()
        self.fd_config.scheduler_config.splitwise_role = "mixed"
        self.fd_config.model_config.moe_phase.phase = "decode"
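
        # Assumption: splitwise_role="mixed" together with moe_phase "decode"
        # steers FusedMoE onto its decode-time dispatch path for this benchmark.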

        weight_key_map = {
            "gate_weight_key": f"{self.prefix}.gate.weight",
            "gate_correction_bias_key": f"{self.prefix}.moe_statics.e_score_correction_bias",
            "up_gate_proj_expert_weight_key": f"{self.prefix}.experts.{{}}.up_gate_proj.weight",
            "down_proj_expert_weight_key": f"{self.prefix}.experts.{{}}.down_proj.weight",
            "up_gate_proj_expert_weight_scale_key": f"{self.prefix}.experts.{{}}.up_gate_proj.weight_scale",
            "down_proj_expert_weight_scale_key": f"{self.prefix}.experts.{{}}.down_proj.weight_scale",
            "up_gate_proj_expert_in_scale_key": f"{self.prefix}.experts.{{}}.up_gate_proj.activation_scale",
            "down_proj_expert_in_scale_key": f"{self.prefix}.experts.{{}}.down_proj.activation_scale",
        }
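        # The "{}" placeholder in each expert key is filled in with the expert
        # index below, mirroring checkpoint tensor names such as
        # "layer0.experts.3.up_gate_proj.weight".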

        self.fused_moe = FusedMoE(
            fd_config=self.fd_config,
            moe_intermediate_size=self.fd_config.model_config.moe_intermediate_size,
            num_experts=self.fd_config.model_config.moe_num_experts,
            top_k=self.fd_config.model_config.moe_k,
            # Use an out-of-range layer_idx to avoid invoking clean_low_latency_buffer in mixed EP.
            layer_idx=666,
            weight_key_map=weight_key_map,
            topk_method="noaux_tc",
            topk_group=4,
            n_group=8,
            gate_correction_bias=paddle.zeros([self.fd_config.model_config.moe_num_experts], paddle.float32),
            # gate_correction_bias=gate_correction_bias_real_data
        )
        self.pack_num = 2
        moe_layer = self.fused_moe
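
        # With w4a8, two int4 values are packed into one int8 byte, so the
        # packed weight shapes below divide the input dimension of each
        # projection by pack_num.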
        up_gate_proj_weight_shape = [
            moe_layer.num_local_experts,
            moe_layer.hidden_size // self.pack_num,
            moe_layer.moe_intermediate_size * 2,
        ]
        down_proj_weight_shape = [
            moe_layer.num_local_experts,
            moe_layer.moe_intermediate_size // self.pack_num,
            moe_layer.hidden_size,
        ]
        up_gate_proj_weight_scale_shape = [
            moe_layer.num_local_experts,
            moe_layer.moe_intermediate_size * 2,
        ]
        down_proj_weight_scale_shape = [
            moe_layer.num_local_experts,
            moe_layer.hidden_size,
        ]

        up_gate_proj_weight = (paddle.randn(up_gate_proj_weight_shape, paddle.bfloat16) * 100).cast(paddle.int8)
        down_proj_weight = (paddle.randn(down_proj_weight_shape, paddle.bfloat16) * 100).cast(paddle.int8)

        up_gate_proj_weight_scale = paddle.randn(up_gate_proj_weight_scale_shape, paddle.bfloat16)
        down_proj_weight_scale = paddle.randn(down_proj_weight_scale_shape, paddle.bfloat16)

        up_gate_proj_in_scale = paddle.randn([self.fd_config.model_config.moe_num_experts, 1], paddle.float32)
        down_proj_in_scale = paddle.randn([self.fd_config.model_config.moe_num_experts, 1], paddle.float32)
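
        # The random tensors above stand in for real quantized parameters:
        # weight scales are per expert and per output channel, activation
        # scales are per expert. Outputs are numerically meaningless; only
        # kernel performance is measured. Below, everything is packed into a
        # state_dict keyed by the checkpoint-style names.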
        local_expert_ids = list(
            range(moe_layer.expert_id_offset, moe_layer.expert_id_offset + moe_layer.num_local_experts)
        )
        state_dict = {}
        up_gate_proj_expert_weight_key = moe_layer.weight_key_map.get("up_gate_proj_expert_weight_key")
        up_gate_proj_expert_weight_scale_key = moe_layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key")
        up_gate_proj_expert_in_scale_key = moe_layer.weight_key_map.get("up_gate_proj_expert_in_scale_key")
        down_proj_expert_weight_key = moe_layer.weight_key_map.get("down_proj_expert_weight_key")
        down_proj_expert_weight_scale_key = moe_layer.weight_key_map.get("down_proj_expert_weight_scale_key")
        down_proj_expert_in_scale_key = moe_layer.weight_key_map.get("down_proj_expert_in_scale_key")

        for expert_idx in local_expert_ids:
            up_gate_proj_expert_weight_key_name = up_gate_proj_expert_weight_key.format(expert_idx)
            up_gate_proj_expert_weight_scale_key_name = up_gate_proj_expert_weight_scale_key.format(expert_idx)
            down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx)
            down_proj_expert_weight_scale_key_name = down_proj_expert_weight_scale_key.format(expert_idx)

            state_dict[up_gate_proj_expert_weight_key_name] = up_gate_proj_weight[
                expert_idx - moe_layer.expert_id_offset
            ]
            state_dict[up_gate_proj_expert_weight_scale_key_name] = up_gate_proj_weight_scale[
                expert_idx - moe_layer.expert_id_offset
            ]
            state_dict[down_proj_expert_weight_key_name] = down_proj_weight[expert_idx - moe_layer.expert_id_offset]
            state_dict[down_proj_expert_weight_scale_key_name] = down_proj_weight_scale[
                expert_idx - moe_layer.expert_id_offset
            ]

        # Activation scales are loaded for every expert, not just the local ones.
        for expert_idx in range(self.fd_config.model_config.moe_num_experts):
            up_gate_proj_expert_in_scale_key_name = up_gate_proj_expert_in_scale_key.format(expert_idx)
            down_proj_expert_in_scale_key_name = down_proj_expert_in_scale_key.format(expert_idx)
            state_dict[up_gate_proj_expert_in_scale_key_name] = up_gate_proj_in_scale[expert_idx]
            state_dict[down_proj_expert_in_scale_key_name] = down_proj_in_scale[expert_idx]

        moe_layer.load_state_dict(state_dict)


class TestW4A8FusedMoE(unittest.TestCase):
    def setUp(self) -> None:
        self.architectures = ["Ernie4_5_MoeForCausalLM"]
        self.hidden_size = 8192
        self.moe_intermediate_size = 3584
        self.moe_num_experts = 64
        self.moe_k = 8
        self.hidden_act = "silu"
        self.num_attention_heads = 64
        self.num_hidden_layers = 54
        self.model_config = self.build_model_config()

    def build_model_config(self) -> ModelConfig:
        model_name_or_path = self.build_config_json()
        return ModelConfig(
            {
                "model": model_name_or_path,
                "max_model_len": 2048,
            }
        )

    def build_config_json(self) -> str:
        config_dict = {
            "architectures": self.architectures,
            "hidden_size": self.hidden_size,
            "moe_intermediate_size": self.moe_intermediate_size,
            "moe_num_experts": self.moe_num_experts,
            "moe_k": self.moe_k,
            "hidden_act": self.hidden_act,
            "num_attention_heads": self.num_attention_heads,
            "num_hidden_layers": self.num_hidden_layers,
            "dtype": "bfloat16",
        }

        # Write a minimal config.json into a rank-local temp dir so ModelConfig
        # can be constructed without a real checkpoint.
        tmp_dir = f"./tmp_w4a8_moe_{paddle.distributed.get_rank()}"
        os.makedirs(tmp_dir, exist_ok=True)
        with open(os.path.join(tmp_dir, "config.json"), "w") as f:
            json.dump(config_dict, f)
        self.model_name_or_path = os.path.join(os.getcwd(), tmp_dir)
        return self.model_name_or_path

    def test_fused_moe(self):
        init_distributed_environment()

        gating = paddle.nn.Linear(self.model_config.hidden_size, self.model_config.moe_num_experts)
        # Its dtype is bfloat16 by default, but the gate's forward input is float32.
        gating.to(dtype=paddle.float32)
        gating.weight.set_value(paddle.rand(gating.weight.shape, dtype=paddle.float32))

        # For multi-GPU expert parallelism, use the world size and rank instead:
        # ep_size = paddle.distributed.get_world_size()
        # ep_rank = paddle.distributed.get_rank()
        ep_size = 1
        ep_rank = 0

        tp_size = 1
        tp_rank = 0

        nnodes = (ep_size + 7) // 8  # assumes 8 GPUs per node

        # This line must be kept; otherwise the uniformity (of expert routing) is affected!
        paddle.seed(ep_rank + 100)
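
        # Estimate the int4 weight bytes read per token: top_k experts, each
        # touching up_gate_proj (2 * hidden * inter) plus down_proj
        # (inter * hidden) weights, i.e. 3 * hidden * inter values at half a
        # byte each, hence the factor 3 / 2.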
        fused_moe = FuseMoEWrapper(self.model_config, tp_size, tp_rank, ep_size, ep_rank, nnodes=nnodes).fused_moe
        weight_size = fused_moe.top_k * fused_moe.hidden_size * fused_moe.moe_intermediate_size * 3 / 2

        tester = OpPerformanceTester(
            op_name="w4a8-moe",
            op_fn=fused_moe,
            num_layers=self.model_config.num_hidden_layers,
            weight_size=weight_size,
            gate=gating,
        )

        tester.benchmark(
            input_size=self.model_config.hidden_size,
            batch_sizes=[10, 20, 40, 60, 80, 100, 128],
        )
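
        # Note: to benchmark under expert parallelism, launch one process per
        # GPU (e.g. via `python -m paddle.distributed.launch`) and switch to
        # the commented-out ep_size/ep_rank lines above.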

    def tearDown(self) -> None:
        if self.model_name_or_path:
            print("Removing tmp model config dir")
            shutil.rmtree(self.model_name_or_path)


if __name__ == "__main__":
    unittest.main()