@@ -177,19 +177,19 @@ def weight_loader(self, param, loaded_weight, expert_id, shard_id: Optional[str]
177177 if shard_id is None :
178178 # 1.gate up fused in disk
179179 model_format = getattr (param , "model_format" , "" )
180- is_opensource_weight = model_format == "torch"
180+ is_torch_model = model_format == "torch"
181181 output_size = param [expert_id - self .expert_id_offset ].shape [SHARD_ID_TO_SHARDED_DIM ["gate" ]]
182182 per_rank = output_size // 2
183183 start = self .tp_rank * per_rank
184184 loaded_weight_shard_gate = slice_fn (
185- loaded_weight , is_opensource_weight ^ SHARD_ID_TO_SHARDED_DIM ["gate" ], start , start + per_rank
185+ loaded_weight , is_torch_model ^ SHARD_ID_TO_SHARDED_DIM ["gate" ], start , start + per_rank
186186 )
187187 self ._load_gate_up_weight (
188188 param , expert_id , loaded_weight_shard_gate , "gate" , SHARD_ID_TO_SHARDED_DIM ["gate" ], is_sharded = True
189189 )
190190 start_up = output_size // 2 * self .tp_size + self .tp_rank * per_rank
191191 loaded_weight_shard_up = slice_fn (
192- loaded_weight , is_opensource_weight ^ SHARD_ID_TO_SHARDED_DIM ["up" ], start_up , start_up + per_rank
192+ loaded_weight , is_torch_model ^ SHARD_ID_TO_SHARDED_DIM ["up" ], start_up , start_up + per_rank
193193 )
194194 self ._load_gate_up_weight (
195195 param , expert_id , loaded_weight_shard_up , "up" , SHARD_ID_TO_SHARDED_DIM ["up" ], is_sharded = True
@@ -207,18 +207,18 @@ def weight_loader(self, param, loaded_weight, expert_id, shard_id: Optional[str]
207207
208208 def _load_gate_up_weight (self , param , expert_id , loaded_weight , shard_id , shard_dim = None , is_sharded = False ):
209209 model_format = getattr (param , "model_format" , "" )
210- is_opensource_weight = model_format == "torch"
210+ is_torch_model = model_format == "torch"
211211 if self .tp_size > 1 and not is_sharded :
212- weight_shard_dim = is_opensource_weight ^ shard_dim
213- weight_dim = - 1 if weight_shard_dim else 0
212+ tp_shard_dim = is_torch_model ^ shard_dim
213+ weight_dim = - 1 if tp_shard_dim else 0
214214 if isinstance (loaded_weight , (np .ndarray , paddle .Tensor )):
215215 size = loaded_weight .shape [weight_dim ]
216216 else :
217217 size = loaded_weight .get_shape ()[weight_dim ]
218218 block_size = size // self .tp_size
219219 shard_offset = self .tp_rank * block_size
220220 shard_size = (self .tp_rank + 1 ) * block_size
221- loaded_weight = slice_fn (loaded_weight , weight_shard_dim , shard_offset , shard_size )
221+ loaded_weight = slice_fn (loaded_weight , tp_shard_dim , shard_offset , shard_size )
222222 loaded_weight = get_tensor (loaded_weight )
223223 expert_param = param [expert_id - self .expert_id_offset ]
224224 dim = - 1 if shard_dim else 0
@@ -249,18 +249,18 @@ def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_
249249
250250 def _load_down_weight (self , param , expert_id , loaded_weight , shard_id , shard_dim = None ):
251251 model_format = getattr (param , "model_format" , "" )
252- is_opensource_weight = model_format == "torch"
252+ is_torch_model = model_format == "torch"
253253 if self .tp_size > 1 and shard_dim is not None :
254- weight_shard_dim = is_opensource_weight ^ shard_dim
255- dim = - 1 if weight_shard_dim else 0
254+ tp_shard_dim = is_torch_model ^ shard_dim
255+ dim = - 1 if tp_shard_dim else 0
256256 if isinstance (loaded_weight , paddle .Tensor ):
257257 size = loaded_weight .shape [dim ]
258258 else :
259259 size = loaded_weight .get_shape ()[dim ]
260260 block_size = size // self .tp_size
261261 shard_offset = self .tp_rank * block_size
262262 shard_size = (self .tp_rank + 1 ) * block_size
263- loaded_weight = slice_fn (loaded_weight , weight_shard_dim , shard_offset , shard_size )
263+ loaded_weight = slice_fn (loaded_weight , tp_shard_dim , shard_offset , shard_size )
264264 loaded_weight = get_tensor (loaded_weight )
265265 expert_param = param [expert_id - self .expert_id_offset ]
266266 if hasattr (param , "tensor_track" ):
0 commit comments