
Commit d40fe1c

Add torchao quant for mixtral
Summary:

Similar to sgl-project#1341, we add torchao quantization to the Mixtral model.

Test Plan:

Note: compile is not working yet, and I can't install torch nightly locally and make it work either. I'll wait for the PyTorch 2.5 release, which happens in mid-October, or check again later.

python3 -m sglang.bench_latency --model Qwen/Qwen1.5-MoE-A2.7B --batch-size 1 --input 128 --output 8

Warmup ...
Prefill. latency: 0.05532 s, throughput: 2313.73 token/s
Decode. latency: 0.00896 s, throughput: 111.65 token/s
Decode. latency: 0.00833 s, throughput: 120.04 token/s
Decode. latency: 0.00869 s, throughput: 115.06 token/s
Decode. latency: 0.00842 s, throughput: 118.79 token/s
Decode. median latency: 0.00855 s, median throughput: 116.89 token/s
Total. latency: 0.090 s, throughput: 1471.26 token/s
Benchmark ...
Prefill. latency: 0.04294 s, throughput: 2980.61 token/s
Decode. latency: 0.00839 s, throughput: 119.12 token/s
Decode. latency: 0.00828 s, throughput: 120.78 token/s
Decode. latency: 0.00857 s, throughput: 116.64 token/s
Decode. latency: 0.00853 s, throughput: 117.19 token/s
Decode. latency: 0.00859 s, throughput: 116.39 token/s
Decode. median latency: 0.00853 s, median throughput: 117.17 token/s
Total. latency: 0.111 s, throughput: 1226.84 token/s

python3 -m sglang.bench_latency --model Qwen/Qwen1.5-MoE-A2.7B --batch-size 1 --input 128 --output 8 --torchao-config int4wo-128

Warmup ...
Prefill. latency: 0.06413 s, throughput: 1996.05 token/s
Decode. latency: 0.00764 s, throughput: 130.84 token/s
Decode. latency: 0.00748 s, throughput: 133.73 token/s
Decode. latency: 0.00725 s, throughput: 137.84 token/s
Decode. latency: 0.00721 s, throughput: 138.74 token/s
Decode. median latency: 0.00737 s, median throughput: 135.76 token/s
Total. latency: 0.094 s, throughput: 1408.61 token/s
Benchmark ...
Prefill. latency: 0.05239 s, throughput: 2443.43 token/s
Decode. latency: 0.00739 s, throughput: 135.25 token/s
Decode. latency: 0.00720 s, throughput: 138.90 token/s
Decode. latency: 0.00718 s, throughput: 139.21 token/s
Decode. latency: 0.00722 s, throughput: 138.42 token/s
Decode. latency: 0.00745 s, throughput: 134.30 token/s
Decode. median latency: 0.00731 s, median throughput: 136.82 token/s
Total. latency: 0.111 s, throughput: 1223.51 token/s

A100, no compile:

python3 -m sglang.bench_latency --model Qwen/Qwen1.5-MoE-A2.7B --batch-size 1 --input 128 --output 8 --torchao-config fp8wo

max_total_num_tokens=199454
Warmup ...
Prefill. latency: 0.06958 s, throughput: 1839.60 token/s
Decode. latency: 0.02343 s, throughput: 42.68 token/s
Decode. latency: 0.02342 s, throughput: 42.70 token/s
Decode. latency: 0.02368 s, throughput: 42.23 token/s
Decode. latency: 0.02337 s, throughput: 42.80 token/s
Decode. median latency: 0.02342 s, median throughput: 42.69 token/s
Total. latency: 0.163 s, throughput: 807.48 token/s
Benchmark ...
Prefill. latency: 0.05767 s, throughput: 2219.36 token/s
Decode. latency: 0.02293 s, throughput: 43.61 token/s
Decode. latency: 0.02026 s, throughput: 49.36 token/s
Decode. latency: 0.02029 s, throughput: 49.29 token/s
Decode. latency: 0.02024 s, throughput: 49.41 token/s
Decode. latency: 0.02026 s, throughput: 49.36 token/s
Decode. median latency: 0.02025 s, median throughput: 49.39 token/s
Total. latency: 0.222 s, throughput: 611.87 token/s

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 70b6802 commit d40fe1c

File tree

3 files changed, +45 -2 lines changed

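For reference, the torchao_quantize_param_data helper used in the diffs below wraps torchao's quantize_ API. The following is a minimal sketch, assuming torchao's public quantize_ entry point and inferring the config-string handling from the flags exercised in the test plan (int4wo-128, fp8wo); it is not the exact implementation in sglang/srt/layers/torchao_utils.py:

    import torch
    from torchao.quantization import (
        float8_weight_only,
        int4_weight_only,
        int8_weight_only,
        quantize_,
    )

    def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
        # torchao quantizes nn.Linear modules in place, so wrap the bare 2-D
        # weight tensor in a throwaway Linear before calling quantize_.
        dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
        dummy_linear.weight = torch.nn.Parameter(param, requires_grad=False)
        if "int8wo" in torchao_config:
            quantize_(dummy_linear, int8_weight_only())
        elif "int4wo" in torchao_config:
            # e.g. "int4wo-128" selects a group size of 128
            group_size = int(torchao_config.split("-")[-1])
            quantize_(dummy_linear, int4_weight_only(group_size=group_size))
        elif "fp8wo" in torchao_config:
            quantize_(dummy_linear, float8_weight_only())
        return dummy_linear.weight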

python/sglang/srt/models/llama.py

Lines changed: 2 additions & 2 deletions
@@ -416,8 +416,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             stacked_params = set(entry[0] for entry in stacked_params_mapping)
             for param_suffix in stacked_params:
                 for name in params_dict:
-                    if param_suffix in name:
-                        param = params_dict[name]
+                    param = params_dict[name]
+                    if param_suffix in name and name.endswith("proj.weight") and param.ndim == 2:
                         params_dict[name] = torchao_quantize_param_data(
                             param, self.torchao_config
                         )
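The behavioral change in this hunk: the old loop quantized any parameter whose name contained a stacked suffix, while the new loop fetches the tensor first and only quantizes 2-D projection weights, skipping biases and other 1-D tensors. A self-contained illustration of the new predicate (the helper name and weight names below are made up for the example):

    import torch

    def should_quantize(name: str, param: torch.Tensor, param_suffix: str) -> bool:
        # Mirrors the condition in the new code path above.
        return param_suffix in name and name.endswith("proj.weight") and param.ndim == 2

    weight = torch.empty(3072, 1024)  # 2-D projection weight: quantized
    bias = torch.empty(3072)          # 1-D bias: skipped
    assert should_quantize("model.layers.0.self_attn.qkv_proj.weight", weight, "qkv_proj")
    assert not should_quantize("model.layers.0.self_attn.qkv_proj.bias", bias, "qkv_proj")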

python/sglang/srt/models/mixtral.py

Lines changed: 22 additions & 0 deletions
@@ -41,6 +41,8 @@
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.torchao_utils import torchao_quantize_param_data
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -296,6 +298,7 @@ def __init__(
         super().__init__()
         self.config = config
         self.quant_config = quant_config
+        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
@@ -375,6 +378,25 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     param, "weight_loader", default_weight_loader
                 )
                 weight_loader(param, loaded_weight)
+                if self.torchao_config:
+                    if name.endswith("proj.weight") and param.ndim == 2:
+                        params_dict[name] = torchao_quantize_param_data(
+                            param, self.torchao_config
+                        )
+
+        if self.torchao_config:
+            # quantizing the loaded, stacked params, e.g. "...qkv_proj"
+            stacked_params = set(entry[0] for entry in stacked_params_mapping)
+            for param_suffix in stacked_params:
+                for name in params_dict:
+                    param = params_dict[name]
+                    if param_suffix in name and name.endswith("proj.weight") and param.ndim == 2:
+                        params_dict[name] = torchao_quantize_param_data(
+                            param, self.torchao_config
+                        )
+
+            self.load_state_dict(params_dict, assign=True)
+


 EntryClass = MixtralForCausalLM
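With this wiring in place, the --torchao-config flag from the test plan should apply to Mixtral checkpoints as well; a hypothetical invocation (the model path is an assumption, not taken from this commit):

    python3 -m sglang.bench_latency --model mistralai/Mixtral-8x7B-Instruct-v0.1 --batch-size 1 --input 128 --output 8 --torchao-config int4wo-128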

python/sglang/srt/models/qwen2_moe.py

Lines changed: 21 additions & 0 deletions
@@ -47,6 +47,8 @@
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.torchao_utils import torchao_quantize_param_data
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -359,6 +361,7 @@ def __init__(
         super().__init__()
         self.config = config
         self.quant_config = quant_config
+        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = Qwen2MoeModel(config, cache_config, quant_config)
         self.lm_head = ParallelLMHead(
             config.vocab_size, config.hidden_size, quant_config=quant_config
@@ -450,6 +453,24 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     param, "weight_loader", default_weight_loader
                 )
                 weight_loader(param, loaded_weight)
+                if self.torchao_config:
+                    if name.endswith("proj.weight") and param.ndim == 2:
+                        params_dict[name] = torchao_quantize_param_data(
+                            param, self.torchao_config
+                        )
+        if self.torchao_config:
+            # quantizing the loaded, stacked params, e.g. "...qkv_proj"
+            stacked_params = set(entry[0] for entry in stacked_params_mapping)
+            stacked_params.union(set(entry[0] for entry in expert_params_mapping))
+            for param_suffix in stacked_params:
+                for name in params_dict:
+                    param = params_dict[name]
+                    if param_suffix in name and param.ndim == 2:
+                        params_dict[name] = torchao_quantize_param_data(
+                            param, self.torchao_config
+                        )
+
+            self.load_state_dict(params_dict, assign=True)


 EntryClass = Qwen2MoeForCausalLM
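For context on entry[0] in the stacked-param loops: stacked_params_mapping and expert_params_mapping follow the vLLM-style (stacked_param_name, shard_name, shard_id) tuple convention. The entries below are illustrative, not copied from the model files; the last two lines also show that Python's set.union returns a new set rather than mutating its receiver:

    # Illustrative (stacked_param_name, shard_name, shard_id) entries.
    stacked_params_mapping = [
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]

    stacked_params = set(entry[0] for entry in stacked_params_mapping)
    assert stacked_params == {"qkv_proj", "gate_up_proj"}

    # set.union returns a new set; the receiver is left unchanged.
    combined = stacked_params.union({"experts.w13_weight"})
    assert "experts.w13_weight" in combined and "experts.w13_weight" not in stacked_params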
