
Commit dc59172

[loader]supoort wint2 backend (PaddlePaddle#6139)
* support wint2
* update
1 parent f18f3b9 commit dc59172

20 files changed

Lines changed: 86 additions & 11 deletions

fastdeploy/config.py

Lines changed: 1 addition & 0 deletions
@@ -1225,6 +1225,7 @@ def __init__(
         args,
     ):
         self.load_choices: Union[str, LoadChoices] = LoadChoices.DEFAULT.value
+        self.is_pre_sharded: bool = False
         self.dynamic_load_weight: bool = False
         self.load_strategy: Optional[Literal["ipc", "ipc_snapshot", "meta", "normal", "rsync"]] = "normal"
         self.rsync_config: Optional[Dict[str, Any]] = None

fastdeploy/model_executor/layers/embeddings.py

Lines changed: 1 addition & 1 deletion
@@ -255,7 +255,7 @@ def weight_loader(self, param, loaded_weight, shard_id=None):
         else:
             loaded_weight = loaded_weight.cast(param.dtype)

-        if output_dim is None:
+        if output_dim is None or self.fd_config.load_config.is_pre_sharded:
            assert (
                param.shape == loaded_weight.shape
            ), f"Shape mismatch: param {param.shape} vs loaded_weight {loaded_weight.shape}"

fastdeploy/model_executor/layers/linear.py

Lines changed: 2 additions & 2 deletions
@@ -556,7 +556,7 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N
             loaded_weight = get_tensor(loaded_weight)
             loaded_weight = loaded_weight.transpose([1, 0])
         # Tensor parallelism splits the weight along the output_dim
-        if self.tp_size > 1 and output_dim is not None:
+        if self.tp_size > 1 and output_dim is not None and not self.fd_config.load_config.is_pre_sharded:
             dim = -1 if output_dim else 0
             if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
                 size = loaded_weight.shape[dim]
@@ -713,7 +713,7 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N
             loaded_weight = get_tensor(loaded_weight)
             loaded_weight = loaded_weight.transpose([1, 0])
         # Tensor parallelism splits the weight along the output_dim
-        if self.tp_size > 1 and output_dim is not None:
+        if self.tp_size > 1 and output_dim is not None and not self.fd_config.load_config.is_pre_sharded:
             block_size = self._get_shard_size_mapping(loaded_shard_id, head_dim)
             shard_id = self.local_rank if loaded_shard_id == "q" else self.local_rank // self.num_kv_head_replicas
             shard_offset = shard_id * block_size
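The embedding, linear, and MoE weight loaders in this commit all gain the same guard: when the checkpoint on disk is already split per rank, the tensor-parallel slice has already been taken, so the loader copies the tensor whole instead of slicing it again. A minimal, self-contained sketch of that branch (hypothetical helper name, NumPy in place of Paddle tensors, not the FastDeploy implementation):

# Simplified illustration of "slice along output_dim unless the checkpoint is pre-sharded".
import numpy as np

def load_shard(loaded_weight: np.ndarray, tp_rank: int, tp_size: int,
               output_dim: bool = True, is_pre_sharded: bool = False) -> np.ndarray:
    # A pre-sharded checkpoint already holds this rank's slice: copy it whole.
    if tp_size <= 1 or is_pre_sharded:
        return loaded_weight
    dim = -1 if output_dim else 0
    return np.split(loaded_weight, tp_size, axis=dim)[tp_rank]

full = np.arange(4 * 8).reshape(4, 8)      # un-sharded weight
shard = full[:, 4:]                        # what a rank1 shard file would contain
print(load_shard(full, tp_rank=1, tp_size=2).shape)                        # (4, 4)
print(load_shard(shard, tp_rank=1, tp_size=2, is_pre_sharded=True).shape)  # (4, 4)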

fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py

Lines changed: 35 additions & 0 deletions
@@ -21,6 +21,7 @@

 import fastdeploy
 from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce
+from fastdeploy.model_executor.utils import set_weight_attrs
 from fastdeploy.utils import ceil_div

 from ..quantization.quant_base import QuantMethodBase
@@ -154,6 +155,22 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
                 default_initializer=paddle.nn.initializer.Constant(0),
             ),
         )
+        for weight_name in [
+            "up_gate_proj_weight",
+            "down_proj_weight",
+            "up_gate_proj_weight_scale",
+            "down_proj_weight_scale",
+            "up_gate_proj_super_scales",
+            "down_proj_super_scales",
+            "up_gate_proj_code_scale",
+            "down_proj_code_scale",
+            "up_gate_proj_code_zp",
+            "down_proj_code_zp",
+        ]:
+            set_weight_attrs(
+                getattr(layer, weight_name),
+                extra_weight_attrs,
+            )


 class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
@@ -164,6 +181,24 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
     def __init__(self, quant_config):
         super().__init__(quant_config)

+    def process_weights_after_loading(self, layer):
+        if self.quant_config.is_checkpoint_bf16:
+            # dynamic quantize
+            return
+        w1_shape = layer.up_gate_proj_weight.shape
+        up_gate_proj_weight = layer.up_gate_proj_weight.reshape(
+            [w1_shape[0], w1_shape[1] // 16, 16, w1_shape[2] // 8, 8]
+        )
+        up_gate_proj_weight = paddle.transpose(up_gate_proj_weight, perm=[0, 3, 1, 4, 2])
+        up_gate_proj_weight = up_gate_proj_weight.reshape(w1_shape)
+        layer.up_gate_proj_weight.data = up_gate_proj_weight
+
+        w2_shape = layer.down_proj_weight.shape
+        down_proj_weight = layer.down_proj_weight.reshape([w2_shape[0], w2_shape[1] // 16, 16, w2_shape[2] // 8, 8])
+        down_proj_weight = paddle.transpose(down_proj_weight, perm=[0, 3, 1, 4, 2])
+        down_proj_weight = down_proj_weight.reshape(w2_shape)
+        layer.down_proj_weight.data = down_proj_weight
+
     def process_loaded_weights(self, layer, weights) -> None:
         """
         process_loaded_weights
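The new process_weights_after_loading re-lays the packed wint2 weights into 16x8 tiles via a reshape / transpose / reshape round trip that leaves the overall shape (and element count) unchanged; only the in-memory tile layout differs. A standalone sketch of that re-layout on dummy data (made-up shapes, not repository code):

# Mirror the tile permutation above on a small tensor to show it is shape-preserving.
import paddle

w = paddle.arange(2 * 32 * 16, dtype="int32").reshape([2, 32, 16])
e, k, n = w.shape

# Split the last two dims into 16x8 tiles, swap the tile-grid and in-tile axes,
# then flatten back to the original shape.
tiled = w.reshape([e, k // 16, 16, n // 8, 8])
tiled = paddle.transpose(tiled, perm=[0, 3, 1, 4, 2])  # -> [e, n//8, k//16, 8, 16]
relaid = tiled.reshape([e, k, n])

print(w.shape, relaid.shape)        # both [2, 32, 16]
print(bool((w == relaid).all()))    # False: same values, rearranged in memory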

fastdeploy/model_executor/layers/moe/moe.py

Lines changed: 3 additions & 3 deletions
@@ -316,7 +316,7 @@ def weight_loader(
         )

     def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None, is_sharded=False):
-        if self.tp_size > 1 and not is_sharded:
+        if self.tp_size > 1 and not is_sharded and not self.fd_config.load_config.is_pre_sharded:
             tp_shard_dim = shard_dim
             weight_dim = -1 if tp_shard_dim else 0
             size = loaded_weight.shape[weight_dim]
@@ -371,7 +371,7 @@ def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_
             h2d_copy(dst=expert_param, src=loaded_weight)

     def _load_down_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None):
-        if self.tp_size > 1 and shard_dim is not None:
+        if self.tp_size > 1 and shard_dim is not None and not self.fd_config.load_config.is_pre_sharded:
             tp_shard_dim = shard_dim
             dim = -1 if tp_shard_dim else 0
             size = loaded_weight.shape[dim]
@@ -397,7 +397,7 @@ def _load_down_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim
             h2d_copy(dst=expert_param, src=loaded_weight)

     def _load_fused_experts_weight(self, param, loaded_weight):
-        if self.tp_size > 1 and self.moe_quant_type != "mxfp4":
+        if self.tp_size > 1 and self.moe_quant_type != "mxfp4" and not self.fd_config.load_config.is_pre_sharded:
             dim = -1
             if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
                 size = loaded_weight.shape[dim]

fastdeploy/model_executor/load_weight_utils.py

Lines changed: 15 additions & 3 deletions
@@ -76,6 +76,20 @@ def load_weights_from_cache(model, weights_iterator):
         model_sublayer.process_weights_after_loading()


+def get_model_path(fd_config: FDConfig):
+    model_path = fd_config.model_config.model
+    rank_dirs = [
+        f for f in os.listdir(model_path) if f.startswith("rank") and os.path.isdir(os.path.join(model_path, f))
+    ]
+    if len(rank_dirs) > 1:
+        local_rank = fd_config.parallel_config.tensor_parallel_rank
+        if fd_config.parallel_config.tensor_parallel_size != len(rank_dirs):
+            raise ValueError(f"Your model only supports loading with tp{len(rank_dirs)}")
+        model_path = os.path.join(model_path, f"rank{local_rank}")
+        fd_config.load_config.is_pre_sharded = True
+    return model_path
+
+
 def get_weight_iterator(model_path: str):
     files_list, ordered_weight_map, use_safetensors, is_key_ordered = get_all_weights_file(model_path)
     if use_safetensors:
@@ -404,10 +418,8 @@ def load_pre_sharded_checkpoint(model_path: str, local_rank: int):
     """
     load_pre_sharded_checkpoint
     """
-
     state_dict = {}
-    safetensor_files, _, _, _ = get_all_weights_file(os.path.join(model_path, f"rank{local_rank}"))
-    weights_iterator = safetensors_weights_iterator(safetensor_files)
+    weights_iterator = get_weight_iterator(os.path.join(model_path, f"rank{local_rank}"))
     for name, weight in weights_iterator:
         state_dict[name] = weight.clone()
     return state_dict
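get_model_path treats a checkpoint directory that contains more than one rank* sub-directory as pre-sharded: it checks that tensor_parallel_size matches the number of rank directories, points each rank at its own sub-directory, and flips load_config.is_pre_sharded, which is what disables the re-slicing guards above. A standalone sketch of that detection (hypothetical function name and temporary paths, simplified away from FDConfig):

# Resolve a per-rank sub-directory when the checkpoint layout is pre-sharded.
import os
import tempfile

def resolve_model_path(model_path: str, tp_rank: int, tp_size: int) -> tuple[str, bool]:
    rank_dirs = [
        d for d in os.listdir(model_path)
        if d.startswith("rank") and os.path.isdir(os.path.join(model_path, d))
    ]
    if len(rank_dirs) > 1:
        if tp_size != len(rank_dirs):
            raise ValueError(f"Your model only supports loading with tp{len(rank_dirs)}")
        return os.path.join(model_path, f"rank{tp_rank}"), True
    return model_path, False

# Build a fake pre-sharded layout with two rank directories and resolve rank 1.
with tempfile.TemporaryDirectory() as root:
    for r in range(2):
        os.makedirs(os.path.join(root, f"rank{r}"))
    print(resolve_model_path(root, tp_rank=1, tp_size=2))  # ('.../rank1', True)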

fastdeploy/model_executor/model_loader/default_loader_v1.py

Lines changed: 3 additions & 1 deletion
@@ -20,6 +20,7 @@

 from fastdeploy.config import FDConfig, LoadConfig, ModelConfig
 from fastdeploy.model_executor.load_weight_utils import (
+    get_model_path,
     get_weight_iterator,
     is_weight_cache_enabled,
     load_weights_from_cache,
@@ -51,7 +52,8 @@ def clean_memory_fragments(self) -> None:
     @save_model()
     @measure_time()
     def load_weights(self, model, fd_config: FDConfig, enable_cache: bool = False) -> None:
-        weights_iterator = get_weight_iterator(fd_config.model_config.model)
+        model_path = get_model_path(fd_config)
+        weights_iterator = get_weight_iterator(model_path)
         if enable_cache:
             load_weights_from_cache(model, weights_iterator)
         else:

fastdeploy/model_executor/models/deepseek_v3.py

Lines changed: 1 addition & 0 deletions
@@ -677,6 +677,7 @@ def load_weights(self, weights_iterator) -> None:
         params_dict = dict(self.named_parameters())
         process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()), self.fd_config)
         for loaded_weight_name, loaded_weight in weights_iterator:
+            logger.debug(f"Loading weight: {loaded_weight_name}")
             loaded_weight_name = loaded_weight_name.replace("deepseek_v3", "model")
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in loaded_weight_name:

fastdeploy/model_executor/models/ernie4_5_moe.py

Lines changed: 1 addition & 0 deletions
@@ -633,6 +633,7 @@ def load_weights(self, weights_iterator) -> None:
         )

         for loaded_weight_name, loaded_weight in weights_iterator:
+            logger.debug(f"Loading weight: {loaded_weight_name}")
             loaded_weight_name = loaded_weight_name.replace("model", "ernie")
             for param_name, weight_name, exp_id, shard_id, is_moe in all_param_mapping:
                 loaded_weight_name = checkpoint_to_fd_key_fn(loaded_weight_name, is_moe)

fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py

Lines changed: 1 addition & 0 deletions
@@ -705,6 +705,7 @@ def load_weights(self, weights_iterator) -> None:
         expert_id = None
         shard_id = None
         for loaded_weight_name, loaded_weight in weights_iterator:
+            logger.debug(f"Loading weight: {loaded_weight_name}")
             loaded_weight_name = (
                 self.process_weights_before_loading_fn(loaded_weight_name)
                 if getattr(self, "process_weights_before_loading_fn", None)

0 commit comments
