Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions fastdeploy/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,7 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N
if weight_need_transpose:
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
# Clear the flag so the already-transposed fused weight is not transposed again when weight_loader is called iteratively
param.weight_need_transpose = False
# Loaded weight is already fused on disk.
shard_offsets = [
Expand Down Expand Up @@ -638,6 +639,7 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N
if weight_need_transpose:
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
# Clear the flag so the already-transposed fused weight is not transposed again when weight_loader is called iteratively
param.weight_need_transpose = False
# Loaded weight is already fused on disk
shard_offsets = [
Expand Down
4 changes: 4 additions & 0 deletions fastdeploy/model_executor/layers/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
"""
from typing import Dict, List, Type

from fastdeploy.utils import parse_quantization

from .quant_base import QuantConfigBase

QUANTIZATION_METHODS: List[str] = [
Expand All @@ -35,6 +37,8 @@


def parse_quant_config(args, model_config, is_ernie, is_v1_loader):
if args.quantization is not None and isinstance(args.quantization, str):
args.quantization = parse_quantization(args.quantization)
# 1.model_config.is_quantized
# TODO(bukejiyu): model_config.is_quantized is v0-only and needs to be removed in the future
if model_config.model_format == "torch":
Expand Down
2 changes: 1 addition & 1 deletion fastdeploy/model_executor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def rename_offline_ckpt_suffix_to_fd_suffix(
}
moe_quant_type = ""
dense_quant_type = ""
if fd_config.quant_config is None:
if fd_config.quant_config is not None:
if fd_config.quant_config.name() == "mix_quant":
moe_quant_type = fd_config.quant_config.moe_quant_type
dense_quant_type = fd_config.quant_config.dense_quant_type
Expand Down
4 changes: 1 addition & 3 deletions fastdeploy/worker/worker_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.model_executor.layers.quantization import parse_quant_config
from fastdeploy.platforms import current_platform
from fastdeploy.utils import get_logger, parse_quantization
from fastdeploy.utils import get_logger
from fastdeploy.worker.worker_base import WorkerBase

logger = get_logger("worker_process", "worker_process.log")
Expand Down Expand Up @@ -643,8 +643,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
FDConfig: Initialized FastDeploy configuration object
"""
# RL rollout
if args.quantization is not None and isinstance(args.quantization, str):
args.quantization = parse_quantization(args.quantization)
paddle.set_default_dtype(args.dtype)
model_config = ModelConfig(vars(args))
device_config = DeviceConfig(vars(args))
Expand Down
Loading