Skip to content

Commit 1e88754

Browse files
authored
support set dy-C8 from args (PaddlePaddle#4475)
1 parent 741a012 commit 1e88754

File tree

6 files changed

+33
-11
lines changed

6 files changed

+33
-11
lines changed

fastdeploy/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ def __init__(
405405
# model for mtp/eagle/draft_model
406406
self.model: Optional[str] = None
407407
# quantization of model
408-
self.quantization: Optional[str] = None
408+
self.quantization: Optional[Dict[str, Any]] = None
409409
# allocate more blocks to prevent mtp from finishing the block earlier than the main model
410410
# Fixed now
411411
self.num_gpu_block_expand_ratio: Optional[float] = 1

fastdeploy/engine/args_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
DeprecatedOptionWarning,
4242
FlexibleArgumentParser,
4343
is_port_available,
44+
parse_quantization,
4445
)
4546

4647

@@ -138,7 +139,7 @@ class EngineArgs:
138139
"""
139140
dynamic load weight strategy
140141
"""
141-
quantization: str = None
142+
quantization: Optional[Dict[str, Any]] = None
142143
guided_decoding_backend: str = "off"
143144
"""
144145
Guided decoding backend.
@@ -558,7 +559,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
558559
)
559560
model_group.add_argument(
560561
"--quantization",
561-
type=str,
562+
type=parse_quantization,
562563
default=EngineArgs.quantization,
563564
help="Quantization name for the model, currently support "
564565
"'wint8', 'wint4',"

fastdeploy/engine/engine.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from __future__ import annotations
1818

19+
import json
1920
import multiprocessing
2021
import os
2122
import re
@@ -450,7 +451,7 @@ def _start_worker_service(self):
450451
f" --kv_cache_ratio {self.cfg.cache_config.kv_cache_ratio}"
451452
f" --expert_parallel_size {self.cfg.parallel_config.expert_parallel_size}"
452453
f" --data_parallel_size {self.cfg.parallel_config.data_parallel_size}"
453-
f" --quantization {self.cfg.model_config.quantization}"
454+
f" --quantization '{json.dumps(self.cfg.model_config.quantization)}'"
454455
f" --ori_vocab_size {ori_vocab_size}"
455456
f" --speculative_config '{self.cfg.speculative_config.to_json_string()}'"
456457
f" --graph_optimization_config '{self.cfg.graph_opt_config.to_json_string()}'"

fastdeploy/rl/rollout_config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# limitations under the License.
1515
"""
1616

17+
from typing import Any, Dict, Optional
18+
1719
from fastdeploy.worker.worker_process import initialize_fd_config
1820

1921

@@ -52,7 +54,7 @@ def __init__(
5254
expert_parallel_size: int = 1,
5355
enable_expert_parallel: bool = False,
5456
ori_vocab_size: int = None,
55-
quantization: str = "None",
57+
quantization: Optional[Dict[str, Any]] = None,
5658
guided_decoding_backend: str = "off",
5759
disable_any_whitespace: bool = True,
5860
enable_logprob: bool = False,

fastdeploy/utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import asyncio
1919
import codecs
2020
import importlib
21+
import json
2122
import logging
2223
import os
2324
import random
@@ -757,6 +758,16 @@ def status(self) -> dict:
757758
}
758759

759760

def parse_quantization(value: str) -> dict:
    """
    Parse a ``--quantization`` CLI value into a configuration dictionary.

    Accepts either a JSON object string (parsed and returned as a dict) or a
    bare quantization name such as ``"wint8"``, which is wrapped as
    ``{"quantization": value}``.

    Args:
        value: Raw command-line string for the quantization option.

    Returns:
        dict: The parsed JSON object, or ``{"quantization": value}`` when the
        input is not a JSON object.
    """
    try:
        parsed = json.loads(value)
    except ValueError:
        # Not valid JSON (e.g. "wint8"): treat the raw string as a
        # quantization method name.
        return {"quantization": value}
    if isinstance(parsed, dict):
        return parsed
    # Valid JSON but not an object (e.g. "123", "null", "true"): downstream
    # code indexes the result with ["quantization"], so fall back to wrapping
    # the original string instead of returning a scalar.
    return {"quantization": value}
770+
760771
# 日志使用全局访问点(兼容原有使用方式)
761772
def get_logger(name, file_name=None, without_formater=False, print_to_console=False):
762773
"""全局函数包装器,保持向后兼容"""

fastdeploy/worker/worker_process.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
from fastdeploy.inter_communicator import IPCSignal
4646
from fastdeploy.model_executor.layers.quantization import get_quantization_config
4747
from fastdeploy.platforms import current_platform
48-
from fastdeploy.utils import get_logger
48+
from fastdeploy.utils import get_logger, parse_quantization
4949
from fastdeploy.worker.worker_base import WorkerBase
5050

5151
logger = get_logger("worker_process", "worker_process.log")
@@ -616,8 +616,8 @@ def parse_args():
616616

617617
parser.add_argument(
618618
"--quantization",
619-
type=str,
620-
default="None",
619+
type=json.loads,
620+
default=None,
621621
help="Quantization name for the model, currently support "
622622
"'wint4', 'wint8',"
623623
"default is None. The priority of this configuration "
@@ -719,6 +719,9 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
719719
Returns:
720720
FDConfig: Initialized FastDeploy configuration object
721721
"""
722+
# RL rollout
723+
if args.quantization is not None and isinstance(args.quantization, str):
724+
args.quantization = parse_quantization(args.quantization)
722725
paddle.set_default_dtype(args.dtype)
723726
model_config = ModelConfig(vars(args))
724727
device_config = DeviceConfig(vars(args))
@@ -789,10 +792,14 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
789792

790793
if quantization_config is not None:
791794
quant_config_name = quantization_config["quantization"]
792-
elif args.quantization != "None":
795+
elif args.quantization is not None:
793796
quantization_config = {}
794-
quant_config_name = args.quantization
795-
quantization_config["quantization"] = quant_config_name
797+
try:
798+
quantization_config.update(args.quantization)
799+
quant_config_name = quantization_config["quantization"]
800+
except:
801+
quant_config_name = args.quantization["quantization"]
802+
quantization_config["quantization"] = quant_config_name
796803
# Only v1 loader sets is_checkpoint_bf16=True during dynamic quantization.
797804
if load_config.load_choices == "default_v1":
798805
quantization_config["is_checkpoint_bf16"] = True

0 commit comments

Comments
 (0)