diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index ff08f182ad..4ee301d566 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -557,12 +557,7 @@ def get_weight_iter(config):
             return iter
 
         def model_load_weights(model, iter):
-            model.load_weights(iter)
-            for _, module in self.model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if quant_method is not None:
-                    with device_loading_context(module, target_device):
-                        quant_method.process_weights_after_loading(module)
+            DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device)
             return model
 
         with set_default_torch_dtype(self.model_config.dtype):
diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py
index 94b02c6f57..b093c56d37 100644
--- a/python/sglang/srt/model_loader/loader.py
+++ b/python/sglang/srt/model_loader/loader.py
@@ -374,20 +374,27 @@ def load_model(
                     self.load_config,
                 )
 
-            model.load_weights(self._get_all_weights(model_config, model))
+            self.load_weights_and_postprocess(
+                model, self._get_all_weights(model_config, model), target_device
+            )
 
-            for _, module in model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if quant_method is not None:
-                    # When quant methods need to process weights after loading
-                    # (for repacking, quantizing, etc), they expect parameters
-                    # to be on the global target device. This scope is for the
-                    # case where cpu offloading is used, where we will move the
-                    # parameters onto device for processing and back off after.
-                    with device_loading_context(module, target_device):
-                        quant_method.process_weights_after_loading(module)
         return model.eval()
 
+    @staticmethod
+    def load_weights_and_postprocess(model, weights, target_device):
+        model.load_weights(weights)
+
+        for _, module in model.named_modules():
+            quant_method = getattr(module, "quant_method", None)
+            if quant_method is not None:
+                # When quant methods need to process weights after loading
+                # (for repacking, quantizing, etc), they expect parameters
+                # to be on the global target device. This scope is for the
+                # case where cpu offloading is used, where we will move the
+                # parameters onto device for processing and back off after.
+                with device_loading_context(module, target_device):
+                    quant_method.process_weights_after_loading(module)
+
 
 class LayeredModelLoader(DefaultModelLoader):
     """Model loader that loads weights layer by layer so that one can quantize a