diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py
index cfe5b46c64c5..b0426aa4f26e 100644
--- a/tests/lora/test_lora_manager.py
+++ b/tests/lora/test_lora_manager.py
@@ -124,6 +124,7 @@ def test_replace_submodules(default_vllm_config, dist_init, dummy_model):
             max_lora_rank=8, max_cpu_loras=8, max_loras=8, lora_dtype=DEFAULT_DTYPE
         ),
         torch.device(DEVICES[0]),
+        default_vllm_config,
     )
     model = manager.model
     assert isinstance(model.get_submodule("dense1"), ColumnParallelLinearWithLoRA)
@@ -280,6 +281,7 @@ def test_lora_model_manager(default_vllm_config, dist_init, dummy_model, device)
             max_lora_rank=8, max_cpu_loras=3, max_loras=2, lora_dtype=DEFAULT_DTYPE
         ),
         device=device,
+        vllm_config=default_vllm_config,
     )
     assert all(x is None for x in manager.lora_index_to_id)
     assert manager.add_adapter(model_lora1)
@@ -348,6 +350,7 @@ def test_lora_lru_cache_model_manager(
             max_lora_rank=8, max_cpu_loras=3, max_loras=2, lora_dtype=DEFAULT_DTYPE
         ),
         device=device,
+        vllm_config=default_vllm_config,
     )
     assert all(x is None for x in manager.lora_index_to_id)
     assert manager.add_adapter(model_lora1)
@@ -441,6 +444,7 @@ def test_lru_lora_model_manager(default_vllm_config, dist_init, dummy_model, dev
             max_lora_rank=8, max_cpu_loras=2, max_loras=2, lora_dtype=DEFAULT_DTYPE
         ),
         device=device,
+        vllm_config=default_vllm_config,
     )
 
     assert all(x is None for x in manager.lora_index_to_id)
@@ -576,7 +580,7 @@ def test_lru_cache_worker_adapter_manager(
     worker_adapter_manager.max_num_seqs = 4
     worker_adapter_manager.max_num_batched_tokens = 2
 
-    worker_adapter_manager.create_lora_manager(dummy_model)
+    worker_adapter_manager.create_lora_manager(dummy_model, vllm_config)
 
     mapping = LoRAMapping([], [])
     worker_adapter_manager.set_active_adapters(
@@ -680,7 +684,7 @@ def test_worker_adapter_manager(
 
     worker_adapter_manager = WorkerLoRAManager(vllm_config, device, EMBEDDING_MODULES)
     worker_adapter_manager.vocab_size = dummy_model_gate_up.unpadded_vocab_size
-    worker_adapter_manager.create_lora_manager(dummy_model_gate_up)
+    worker_adapter_manager.create_lora_manager(dummy_model_gate_up, vllm_config)
 
     dummy_lora_files = f"{tmp_path}/lora_adapter"
     os.makedirs(dummy_lora_files, exist_ok=True)
@@ -800,6 +804,7 @@ def test_packed_loras(default_vllm_config, dist_init, dummy_model_gate_up, devic
             max_lora_rank=8, max_cpu_loras=2, max_loras=2, lora_dtype=DEFAULT_DTYPE
         ),
         device=device,
+        vllm_config=default_vllm_config,
     )
 
     model = manager.model
diff --git a/tools/pre_commit/mypy.py b/tools/pre_commit/mypy.py
index 7c7b0ada60d6..c5366e4f746b 100755
--- a/tools/pre_commit/mypy.py
+++ b/tools/pre_commit/mypy.py
@@ -25,11 +25,7 @@
 
 # After fixing errors resulting from changing follow_imports
 # from "skip" to "silent", remove its directory from SEPARATE_GROUPS.
-SEPARATE_GROUPS = [
-    "tests",
-    # v0 related
-    "vllm/lora",
-]
+SEPARATE_GROUPS = ["tests"]
 
 # TODO(woosuk): Include the code from Megatron and HuggingFace.
 EXCLUDE = [
diff --git a/vllm/config/lora.py b/vllm/config/lora.py
index bf47887c7e0e..5f853b513338 100644
--- a/vllm/config/lora.py
+++ b/vllm/config/lora.py
@@ -1,10 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from collections.abc import Callable
 from typing import TYPE_CHECKING, Any, Literal
 
 import torch
-from pydantic import ConfigDict, Field, model_validator
+from pydantic import ConfigDict, Field, field_validator, model_validator
 from typing_extensions import Self
 
 from vllm import envs
@@ -40,7 +41,7 @@ class LoRAConfig:
     parallelism.
     Enabling this will use the fully sharded layers. At high sequence length,
     max rank or tensor parallel size, this is likely faster."""
-    max_cpu_loras: int | None = None
+    max_cpu_loras: int = None  # type: ignore[assignment]
     """Maximum number of LoRAs to store in CPU memory. Must be >= than
     `max_loras`."""
     lora_dtype: torch.dtype | LoRADType = "auto"
@@ -98,15 +99,18 @@ def compute_hash(self) -> str:
         hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
         return hash_str
 
+    @field_validator("max_cpu_loras", mode="wrap")
+    def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
+        if value is None:
+            return value
+        return handler(value)
+
     @model_validator(mode="after")
     def _validate_lora_config(self) -> Self:
         if self.max_cpu_loras is None:
             self.max_cpu_loras = self.max_loras
         elif self.max_cpu_loras < self.max_loras:
-            raise ValueError(
-                f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
-                f"max_loras ({self.max_loras})."
-            )
+            raise ValueError(f"{self.max_cpu_loras=} must be >= {self.max_loras=}.")
         if envs.VLLM_LORA_ENABLE_DUAL_STREAM and not current_platform.is_cuda_alike():
             raise ValueError("Dual CUDA streams are only supported on CUDA platforms.")
         if envs.VLLM_LORA_ENABLE_DUAL_STREAM and self.fully_sharded_loras:
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 1b3803139217..1206f2ecfad1 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1989,7 +1989,7 @@ def create_engine_config(
                 specialize_active_lora=self.specialize_active_lora,
                 max_cpu_loras=self.max_cpu_loras
                 if self.max_cpu_loras and self.max_cpu_loras > 0
-                else None,
+                else self.max_loras,
             )
             if self.enable_lora
             else None
diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py
index 284ac54997fb..57c0355a4fcc 100644
--- a/vllm/lora/layers/fused_moe.py
+++ b/vllm/lora/layers/fused_moe.py
@@ -34,6 +34,9 @@ def __init__(self, base_layer: FusedMoE) -> None:
         assert not self.base_layer.use_ep, (
             "EP support for Fused MoE LoRA is not implemented yet."
         )
+        assert self.base_layer.quant_method.moe_quant_config is not None, (
+            "Fused MoE LoRA requires the base layer to have a moe_quant_config."
+        )
         assert not self.base_layer.quant_method.is_monolithic, (
             "Monolithic kernels are not supported for Fused MoE LoRA."
         )
diff --git a/vllm/lora/layers/utils.py b/vllm/lora/layers/utils.py
index 1b8083f5c4d1..cd222a0f9788 100644
--- a/vllm/lora/layers/utils.py
+++ b/vllm/lora/layers/utils.py
@@ -90,7 +90,7 @@ def try_get_optimal_moe_lora_config(
     top_k: int,
     dtype: str | None,
     M: int,
-) -> dict[str, int | None]:
+) -> dict[str, int]:
     # LoRA shrink/expand operates on bf16/fp16 adapters regardless of the
     # base MoE weight's block-wise quantization, so block_shape is omitted
     # from the config lookup — the non-quantized branch in get_default_config
diff --git a/vllm/lora/lora_model.py b/vllm/lora/lora_model.py
index 7c1dd39bb5e3..7dffbe2384b1 100644
--- a/vllm/lora/lora_model.py
+++ b/vllm/lora/lora_model.py
@@ -196,6 +196,11 @@ def check_unexpected_modules(modules: dict):
             from tensorizer import TensorDeserializer
 
             tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
+            if tensorizer_config.tensorizer_dir is None:
+                raise ValueError(
+                    "'LoRAModel.from_local_checkpoint' expects 'tensorizer_dir' "
+                    "in 'tensorizer_config_dict', but it was not found."
+                )
             lora_tensor_path = os.path.join(
                 tensorizer_config.tensorizer_dir, "adapter_model.tensors"
             )
diff --git a/vllm/lora/model_manager.py b/vllm/lora/model_manager.py
index 52ff8ebc91f3..1839875c7360 100644
--- a/vllm/lora/model_manager.py
+++ b/vllm/lora/model_manager.py
@@ -33,6 +33,7 @@
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.models import (
     SupportsLoRA,
+    SupportsMultiModal,
     is_pooling_model,
     supports_multimodal,
 )
@@ -49,6 +50,12 @@
 DEFAULT_LANGUAGE_WRAPPER_KEY = "language_model"
 
 
+class SupportsLoRAModel(nn.Module, SupportsLoRA): ...
+
+
+class SupportsLoRAMultiModalModel(SupportsLoRAModel, SupportsMultiModal): ...
+
+
 class AdapterLRUCache(LRUCache[int, T]):
     def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
         super().__init__(capacity)
@@ -65,13 +72,13 @@ class LoRAModelManager:
 
     def __init__(
         self,
-        model: SupportsLoRA,
+        model: SupportsLoRAModel,
         max_num_seqs: int,
         max_num_batched_tokens: int,
         vocab_size: int,
         lora_config: LoRAConfig,
         device: torch.device,
-        vllm_config: VllmConfig | None = None,
+        vllm_config: VllmConfig,
     ):
         """Create a LoRAModelManager and adapter for a given model.
 
@@ -84,20 +91,26 @@ def __init__(
             vocab_size: the vocab size of the model.
             lora_config: the LoRA configuration.
         """
 
-        self.model: SupportsLoRA = model
+        self.model: SupportsLoRAModel = model
         self.supported_lora_modules = get_supported_lora_modules(self.model)
         assert self.supported_lora_modules, (
             f"No supported LoRA modules found in {self.model.__class__.__name__}."
         )
-        self._registered_adapters: dict[int, LoRAModel] = {}
-        # Dict instead of a set for compatibility with LRUCache.
-        self._active_adapters: dict[int, None] = {}
+        assert lora_config.max_cpu_loras >= lora_config.max_loras, (
+            f"The capacity of the manager ({lora_config.max_cpu_loras=}) must be "
+            f"greater than or equal to the number of LoRA slots ({lora_config.max_loras=})."
+        )
+        self._registered_adapters = AdapterLRUCache[LoRAModel](
+            lora_config.max_cpu_loras, self.deactivate_adapter
+        )
+        self._active_adapters = AdapterLRUCache[None](
+            lora_config.max_loras, self._deactivate_adapter
+        )
         self.adapter_type = "LoRA"
         self.lora_config = lora_config
         self.device = device
         self.max_num_seqs = max_num_seqs
-        assert self.capacity >= self.lora_slots
         self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
         self.lora_index_to_id: list[int | None] = [None] * self.lora_slots
         self.vocab_size = vocab_size
@@ -879,40 +892,9 @@ def get_adapter(self, adapter_id: int) -> LoRAModel | None:
         return self._registered_adapters.get(adapter_id)
 
 
-class LoRALRUCache(AdapterLRUCache[LoRAModel]):
-    def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], bool]):
-        super().__init__(capacity, deactivate_lora_fn)
-
-
 class LRUCacheLoRAModelManager(LoRAModelManager):
     """A model manager that manages multiple LoRAs with LRU cache."""
 
-    def __init__(
-        self,
-        model: nn.Module,
-        max_num_seqs: int,
-        max_num_batched_tokens: int,
-        vocab_size: int,
-        lora_config: LoRAConfig,
-        device: torch.device,
-        vllm_config: VllmConfig | None = None,
-    ):
-        super().__init__(
-            model,
-            max_num_seqs,
-            max_num_batched_tokens,
-            vocab_size,
-            lora_config,
-            device,
-            vllm_config,
-        )
-        self._registered_adapters: LoRALRUCache = LoRALRUCache(
-            self.capacity, self.deactivate_adapter
-        )
-        self._active_adapters: LoRALRUCache = LoRALRUCache(
-            self.lora_slots, self._deactivate_adapter
-        )
-
     def list_adapters(self) -> dict[int, LoRAModel]:
         """List all registered LoRAModels."""
         return dict(self._registered_adapters.cache)
@@ -972,7 +954,7 @@ def _pin_lora_in_gpu_cache(self, lora_id: int):
 
 
 def create_lora_manager(
-    model: nn.Module,
+    model: SupportsLoRAModel,
     max_num_seqs: int,
     max_num_batched_tokens: int,
     vocab_size: int,
@@ -983,7 +965,7 @@
     **kwargs,
 ) -> LoRAModelManager:
     """Create a LoRA adapter for a given model."""
-    if not isinstance(model, SupportsLoRA):
+    if not isinstance(model, SupportsLoRAModel):
         raise ValueError(f"Model {type(model)} is not supported for LoRA.")
     lora_manager = lora_manager_cls(
         model=model,
diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py
index 975c3d8fc0a7..90b8d59c5f1f 100644
--- a/vllm/lora/peft_helper.py
+++ b/vllm/lora/peft_helper.py
@@ -91,6 +91,11 @@ def from_local_dir(
             tensorizer_args = tensorizer_config._construct_tensorizer_args()
             from tensorizer.stream_io import open_stream
 
+            if tensorizer_config.tensorizer_dir is None:
+                raise ValueError(
+                    "'PEFTHelper.from_local_dir' expects 'tensorizer_dir' "
+                    "in 'tensorizer_config_dict', but it was not found."
+                )
             lora_config_path = os.path.join(
                 tensorizer_config.tensorizer_dir, "adapter_config.json"
             )
diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py
index 2991447a6ad4..423119a26704 100644
--- a/vllm/lora/utils.py
+++ b/vllm/lora/utils.py
@@ -171,13 +171,18 @@ def parse_fine_tuned_lora_name(
     # LoRA weight qualified name usually starts with `base_model.model.`,
     # so we remove the prefix `base_model.model.` to make the following
     # mapping correctly.
-    if name.startswith("base_model.model."):
-        name = name.replace("base_model.model.", "")
-        name = weights_mapper._map_name(name) if weights_mapper else name
-        # recover the prefix `base_model.model.`
-        name = "base_model.model." + name
-    else:
-        name = weights_mapper._map_name(name) if weights_mapper else name
+    if weights_mapper is not None:
+        prefix = ""
+        if name.startswith("base_model.model."):
+            prefix = "base_model.model."
+            name = name.removeprefix(prefix)
+        mapped_name = weights_mapper._map_name(name)
+        if mapped_name is None:
+            raise ValueError(f"Cannot map the lora weight name: {name}")
+        name = mapped_name
+        if prefix:
+            # recover the prefix `base_model.model.`
+            name = prefix + name
 
     # In some situations, we may not start with `base_model.model.`.
     # If we don't (e.g., ibm-granite/granite-speech-3.3-8b),
diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py
index 5b9bf2d76fbb..6d90fb01eac9 100644
--- a/vllm/lora/worker_manager.py
+++ b/vllm/lora/worker_manager.py
@@ -45,6 +45,9 @@ def __init__(
             vllm_config.scheduler_config.max_num_batched_tokens
         )
         self.vocab_size = vllm_config.model_config.get_vocab_size()
+        assert vllm_config.lora_config is not None, (
+            "LoRA config must be provided in vLLM config when using LoRA support."
+        )
         self.lora_config = vllm_config.lora_config
 
         # Use get_text_config() in case of multimodal models
@@ -81,7 +84,7 @@ def is_enabled(self) -> bool:
     def create_lora_manager(
         self,
         model: torch.nn.Module,
-        vllm_config: VllmConfig | None = None,
+        vllm_config: VllmConfig,
     ) -> Any:
         lora_manager = create_lora_manager(
             model,
@@ -235,7 +238,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
     def create_lora_manager(
         self,
         model: torch.nn.Module,
-        vllm_config: VllmConfig | None = None,
+        vllm_config: VllmConfig,
     ) -> Any:
         lora_manager = create_lora_manager(
             model,
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 6a4c1f3c47ef..9b0994867f57 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -268,10 +268,13 @@ def __init__(
         self.quant_config = quant_config
         self.prefix = prefix
         self.allow_fp8_block_shape_mismatch = False
+        self.quant_method: QuantizeMethodBase
         if quant_config is None:
-            self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod()
+            self.quant_method = UnquantizedLinearMethod()
+        elif quant_method := quant_config.get_quant_method(self, prefix=prefix):
+            self.quant_method = quant_method
         else:
-            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
+            raise ValueError("All linear layers should support quant method.")
         self.return_bias = return_bias
         self.disable_tp = disable_tp
         self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
@@ -335,8 +338,6 @@ def __init__(
             disable_tp=disable_tp,
         )
 
-        # All the linear layer supports quant method.
-        assert self.quant_method is not None
         self.quant_method.create_weights(
             self,
             self.input_size,
@@ -389,8 +390,6 @@ def forward(
         x: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
         bias = self.bias if not self.skip_bias_add else None
-        assert self.quant_method is not None
-
         output = self.quant_method.apply(self, x, bias)
 
         if not self.return_bias:
@@ -474,7 +473,6 @@ def __init__(
         self._maybe_allow_fp8_block_shape_mismatch()
 
         self.gather_output = gather_output
-        assert self.quant_method is not None
         self.quant_method.create_weights(
             layer=self,
             input_size_per_partition=self.input_size_per_partition,
@@ -583,7 +581,6 @@ def forward(
         bias = self.bias if not self.skip_bias_add else None
 
         # Matrix multiply.
-        assert self.quant_method is not None
         output_parallel = self.quant_method.apply(self, input_, bias)
 
         if self.gather_output and self.tp_size > 1:
@@ -1463,7 +1460,6 @@ def __init__(
         self.input_is_parallel = input_is_parallel
         self.reduce_results = reduce_results
 
-        assert self.quant_method is not None
         self.quant_method.create_weights(
             layer=self,
             input_size_per_partition=self.input_size_per_partition,
@@ -1553,7 +1549,6 @@ def forward(
             input_parallel = split_input[self.tp_rank].contiguous()
 
         # Matrix multiply.
-        assert self.quant_method is not None
         # Only fuse bias add into GEMM for rank 0 (this ensures that
         # bias will not get added more than once in TP>1 case)
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias