Skip to content

Commit 799fcde

Browse files
committed
update
1 parent 497034f commit 799fcde

File tree

1 file changed

+11
-11
lines changed
  • fastdeploy/model_executor/layers/moe

1 file changed

+11
-11
lines changed

fastdeploy/model_executor/layers/moe/moe.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -177,19 +177,19 @@ def weight_loader(self, param, loaded_weight, expert_id, shard_id: Optional[str]
177177
if shard_id is None:
178178
# 1.gate up fused in disk
179179
model_format = getattr(param, "model_format", "")
180-
is_opensource_weight = model_format == "torch"
180+
is_torch_model = model_format == "torch"
181181
output_size = param[expert_id - self.expert_id_offset].shape[SHARD_ID_TO_SHARDED_DIM["gate"]]
182182
per_rank = output_size // 2
183183
start = self.tp_rank * per_rank
184184
loaded_weight_shard_gate = slice_fn(
185-
loaded_weight, is_opensource_weight ^ SHARD_ID_TO_SHARDED_DIM["gate"], start, start + per_rank
185+
loaded_weight, is_torch_model ^ SHARD_ID_TO_SHARDED_DIM["gate"], start, start + per_rank
186186
)
187187
self._load_gate_up_weight(
188188
param, expert_id, loaded_weight_shard_gate, "gate", SHARD_ID_TO_SHARDED_DIM["gate"], is_sharded=True
189189
)
190190
start_up = output_size // 2 * self.tp_size + self.tp_rank * per_rank
191191
loaded_weight_shard_up = slice_fn(
192-
loaded_weight, is_opensource_weight ^ SHARD_ID_TO_SHARDED_DIM["up"], start_up, start_up + per_rank
192+
loaded_weight, is_torch_model ^ SHARD_ID_TO_SHARDED_DIM["up"], start_up, start_up + per_rank
193193
)
194194
self._load_gate_up_weight(
195195
param, expert_id, loaded_weight_shard_up, "up", SHARD_ID_TO_SHARDED_DIM["up"], is_sharded=True
@@ -207,18 +207,18 @@ def weight_loader(self, param, loaded_weight, expert_id, shard_id: Optional[str]
207207

208208
def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None, is_sharded=False):
209209
model_format = getattr(param, "model_format", "")
210-
is_opensource_weight = model_format == "torch"
210+
is_torch_model = model_format == "torch"
211211
if self.tp_size > 1 and not is_sharded:
212-
weight_shard_dim = is_opensource_weight ^ shard_dim
213-
weight_dim = -1 if weight_shard_dim else 0
212+
tp_shard_dim = is_torch_model ^ shard_dim
213+
weight_dim = -1 if tp_shard_dim else 0
214214
if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
215215
size = loaded_weight.shape[weight_dim]
216216
else:
217217
size = loaded_weight.get_shape()[weight_dim]
218218
block_size = size // self.tp_size
219219
shard_offset = self.tp_rank * block_size
220220
shard_size = (self.tp_rank + 1) * block_size
221-
loaded_weight = slice_fn(loaded_weight, weight_shard_dim, shard_offset, shard_size)
221+
loaded_weight = slice_fn(loaded_weight, tp_shard_dim, shard_offset, shard_size)
222222
loaded_weight = get_tensor(loaded_weight)
223223
expert_param = param[expert_id - self.expert_id_offset]
224224
dim = -1 if shard_dim else 0
@@ -249,18 +249,18 @@ def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_
249249

250250
def _load_down_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None):
251251
model_format = getattr(param, "model_format", "")
252-
is_opensource_weight = model_format == "torch"
252+
is_torch_model = model_format == "torch"
253253
if self.tp_size > 1 and shard_dim is not None:
254-
weight_shard_dim = is_opensource_weight ^ shard_dim
255-
dim = -1 if weight_shard_dim else 0
254+
tp_shard_dim = is_torch_model ^ shard_dim
255+
dim = -1 if tp_shard_dim else 0
256256
if isinstance(loaded_weight, paddle.Tensor):
257257
size = loaded_weight.shape[dim]
258258
else:
259259
size = loaded_weight.get_shape()[dim]
260260
block_size = size // self.tp_size
261261
shard_offset = self.tp_rank * block_size
262262
shard_size = (self.tp_rank + 1) * block_size
263-
loaded_weight = slice_fn(loaded_weight, weight_shard_dim, shard_offset, shard_size)
263+
loaded_weight = slice_fn(loaded_weight, tp_shard_dim, shard_offset, shard_size)
264264
loaded_weight = get_tensor(loaded_weight)
265265
expert_param = param[expert_id - self.expert_id_offset]
266266
if hasattr(param, "tensor_track"):

0 commit comments

Comments
 (0)