diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py
index a37e88a387f..ad7c07dc8cd 100644
--- a/vllm/model_executor/models/transformers.py
+++ b/vllm/model_executor/models/transformers.py
@@ -166,6 +166,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         # Initialize buffers (e.g. rotary embedding inverse frequency)
         self.init_buffers(self.model)
 
+        # Initialize parameters
+        self.init_parameters(self.model)
+
         # Move remaining meta tensors to device (should happen last)
         self.meta_to_empty(self.model)
 
@@ -298,6 +301,33 @@ def init_buffers(self, module: nn.Module):
         for child in module.children():
             self.init_buffers(child)
 
+    def init_parameters(self, module: nn.Module):
+        """
+        Replace every parameter of `module` (and, recursively, of its
+        children) that is still on the `meta` device with an uninitialized
+        parameter of the same shape and dtype on `self.device_config.device`.
+
+        If a `parameter` is on the `meta` device, then its parent
+        `module` is the original module created by:
+
+        ```python
+        with torch.device("meta"):
+            self.model: PreTrainedModel = AutoModel.from_config(...)
+        ```
+        """
+        for name, param in module.named_parameters(recurse=False):
+            if param.device == torch.device("meta"):
+                # Preserve `requires_grad`: `nn.Parameter` defaults to
+                # `requires_grad=True`, which unfreezes frozen parameters
+                # and raises for non-floating-point dtypes.
+                new_param = nn.Parameter(
+                    torch.empty_like(param.data,
+                                     device=self.device_config.device),
+                    requires_grad=param.requires_grad)
+                setattr(module, name, new_param)
+        for child in module.children():
+            self.init_parameters(child)
+
     def meta_to_empty(self, module: nn.Module):
         tensors = list(chain(module.buffers(), module.parameters()))
         if tensors and all(t.device == torch.device("meta") for t in tensors):
@@ -342,6 +372,7 @@ def forward(
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         params_dict = dict(self.named_parameters())
+
         loaded_params = set[str]()
         for name, loaded_weight in weights:
             # Use "model" instead of base_model_prefix because