|
42 | 42 | from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput |
43 | 43 | from sglang.srt.layers.radix_attention import RadixAttention |
44 | 44 | from sglang.srt.layers.sampler import Sampler |
| 45 | +from sglang.srt.layers.torchao_utils import torchao_quantize_param_data |
| 46 | +from sglang.srt.managers.schedule_batch import global_server_args_dict |
45 | 47 | from sglang.srt.model_executor.forward_batch_info import InputMetadata |
46 | 48 |
|
47 | 49 |
|
@@ -299,6 +301,7 @@ def __init__( |
299 | 301 | super().__init__() |
300 | 302 | self.config = config |
301 | 303 | self.quant_config = quant_config |
| 304 | + self.torchao_config = global_server_args_dict["torchao_config"] |
302 | 305 | self.model = LlamaModel(config, quant_config=quant_config) |
303 | 306 | self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) |
304 | 307 | self.logits_processor = LogitsProcessor(config) |
@@ -361,6 +364,25 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): |
361 | 364 | weight_loader = getattr(param, "weight_loader", default_weight_loader) |
362 | 365 | weight_loader(param, loaded_weight) |
363 | 366 |
|
| 367 | + if self.torchao_config: |
| 368 | + if name.endswith("proj.weight") and param.ndim == 2: |
| 369 | + params_dict[name] = torchao_quantize_param_data( |
| 370 | + param, self.torchao_config |
| 371 | + ) |
| 372 | + |
| 373 | + if self.torchao_config: |
| 374 | + # quantizing the loaded, stacked params, e.g. "...qkv_proj" |
| 375 | + stacked_params = set(entry[0] for entry in stacked_params_mapping) |
| 376 | + for param_suffix in stacked_params: |
| 377 | + for name in params_dict: |
| 378 | + if param_suffix in name: |
| 379 | + param = params_dict[name] |
| 380 | + params_dict[name] = torchao_quantize_param_data( |
| 381 | + param, self.torchao_config |
| 382 | + ) |
| 383 | + |
| 384 | + self.load_state_dict(params_dict, assign=True) |
| 385 | + |
364 | 386 |
|
365 | 387 | class Phi3ForCausalLM(LlamaForCausalLM): |
366 | 388 | pass |
|
0 commit comments