9 changes: 7 additions & 2 deletions tests/lora/test_lora_manager.py
@@ -122,6 +122,7 @@ def test_replace_submodules(default_vllm_config, dist_init, dummy_model):
max_lora_rank=8, max_cpu_loras=8, max_loras=8, lora_dtype=DEFAULT_DTYPE
),
torch.device(DEVICES[0]),
default_vllm_config,
)
model = manager.model
assert isinstance(model.get_submodule("dense1"), ColumnParallelLinearWithLoRA)
@@ -149,6 +150,7 @@ def test_lora_model_manager(default_vllm_config, dist_init, dummy_model, device)
max_lora_rank=8, max_cpu_loras=3, max_loras=2, lora_dtype=DEFAULT_DTYPE
),
device=device,
vllm_config=default_vllm_config,
)
assert all(x is None for x in manager.lora_index_to_id)
assert manager.add_adapter(model_lora1)
@@ -217,6 +219,7 @@ def test_lora_lru_cache_model_manager(
max_lora_rank=8, max_cpu_loras=3, max_loras=2, lora_dtype=DEFAULT_DTYPE
),
device=device,
vllm_config=default_vllm_config,
)
assert all(x is None for x in manager.lora_index_to_id)
assert manager.add_adapter(model_lora1)
@@ -310,6 +313,7 @@ def test_lru_lora_model_manager(default_vllm_config, dist_init, dummy_model, dev
max_lora_rank=8, max_cpu_loras=2, max_loras=2, lora_dtype=DEFAULT_DTYPE
),
device=device,
vllm_config=default_vllm_config,
)
assert all(x is None for x in manager.lora_index_to_id)

@@ -445,7 +449,7 @@ def test_lru_cache_worker_adapter_manager(
worker_adapter_manager.max_num_seqs = 4
worker_adapter_manager.max_num_batched_tokens = 2

worker_adapter_manager.create_lora_manager(dummy_model)
worker_adapter_manager.create_lora_manager(dummy_model, vllm_config)

mapping = LoRAMapping([], [])
worker_adapter_manager.set_active_adapters(
@@ -549,7 +553,7 @@ def test_worker_adapter_manager(

worker_adapter_manager = WorkerLoRAManager(vllm_config, device, EMBEDDING_MODULES)
worker_adapter_manager.vocab_size = dummy_model_gate_up.unpadded_vocab_size
worker_adapter_manager.create_lora_manager(dummy_model_gate_up)
worker_adapter_manager.create_lora_manager(dummy_model_gate_up, vllm_config)

dummy_lora_files = f"{tmp_path}/lora_adapter"
os.makedirs(dummy_lora_files, exist_ok=True)
@@ -669,6 +673,7 @@ def test_packed_loras(default_vllm_config, dist_init, dummy_model_gate_up, devic
max_lora_rank=8, max_cpu_loras=2, max_loras=2, lora_dtype=DEFAULT_DTYPE
),
device=device,
vllm_config=default_vllm_config,
)
model = manager.model

1 change: 0 additions & 1 deletion tools/pre_commit/mypy.py
@@ -28,7 +28,6 @@
SEPARATE_GROUPS = [
"tests",
# v0 related
"vllm/lora",
"vllm/model_executor",
]

16 changes: 10 additions & 6 deletions vllm/config/lora.py
@@ -1,10 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Literal

import torch
from pydantic import ConfigDict, Field, model_validator
from pydantic import ConfigDict, Field, field_validator, model_validator
from typing_extensions import Self

from vllm.config.utils import config
@@ -38,7 +39,7 @@ class LoRAConfig:
parallelism. Enabling this will use the fully sharded layers. At high
sequence length, max rank or tensor parallel size, this is likely faster.
"""
max_cpu_loras: int | None = None
max_cpu_loras: int = None
"""Maximum number of LoRAs to store in CPU memory. Must be >= than
`max_loras`."""
lora_dtype: torch.dtype | LoRADType = "auto"
@@ -88,15 +89,18 @@ def compute_hash(self) -> str:
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str

@field_validator("max_cpu_loras", mode="wrap")
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
if value is None:
return value
return handler(value)

@model_validator(mode="after")
def _validate_lora_config(self) -> Self:
if self.max_cpu_loras is None:
self.max_cpu_loras = self.max_loras
elif self.max_cpu_loras < self.max_loras:
raise ValueError(
f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
f"max_loras ({self.max_loras})."
)
raise ValueError(f"{self.max_cpu_loras=} must be >= {self.max_loras=}.")

return self

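For context on the wrap-mode validator added above: a `field_validator(..., mode="wrap")` runs in place of Pydantic's own type validation, so it can let `None` pass through a field that is now annotated as plain `int`, and the after-model validator then back-fills the value from `max_loras`. A minimal standalone sketch of the same pattern (the class and defaults below are illustrative, not vLLM's actual config):

from collections.abc import Callable
from typing import Any

from pydantic import BaseModel, field_validator, model_validator


class ExampleLoRAConfig(BaseModel):
    max_loras: int = 1
    # Annotated as int, but None is tolerated until validation completes.
    max_cpu_loras: int = None  # type: ignore[assignment]

    @field_validator("max_cpu_loras", mode="wrap")
    @classmethod
    def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
        # Bypass the int validator for None; defer to it for everything else.
        if value is None:
            return value
        return handler(value)

    @model_validator(mode="after")
    def _fill_from_max_loras(self):
        if self.max_cpu_loras is None:
            self.max_cpu_loras = self.max_loras
        elif self.max_cpu_loras < self.max_loras:
            raise ValueError(f"{self.max_cpu_loras=} must be >= {self.max_loras=}.")
        return self


print(ExampleLoRAConfig(max_loras=2).max_cpu_loras)                   # 2 (back-filled)
print(ExampleLoRAConfig(max_loras=2, max_cpu_loras=4).max_cpu_loras)  # 4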
2 changes: 1 addition & 1 deletion vllm/engine/arg_utils.py
@@ -1777,7 +1777,7 @@ def create_engine_config(
specialize_active_lora=self.specialize_active_lora,
max_cpu_loras=self.max_cpu_loras
if self.max_cpu_loras and self.max_cpu_loras > 0
else None,
else self.max_loras,
)
if self.enable_lora
else None
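The conditional above means a missing or non-positive `--max-cpu-loras` now resolves to `max_loras` at engine-config creation time instead of being passed through as `None`. A small sketch of the resulting behavior (the helper name is illustrative):

def resolve_max_cpu_loras(max_cpu_loras: int | None, max_loras: int) -> int:
    # Mirrors the arg_utils.py expression: fall back to max_loras when the
    # CLI value is unset or not a positive integer.
    return max_cpu_loras if max_cpu_loras and max_cpu_loras > 0 else max_loras


print(resolve_max_cpu_loras(None, 4))  # 4
print(resolve_max_cpu_loras(0, 4))     # 4
print(resolve_max_cpu_loras(8, 4))     # 8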
8 changes: 7 additions & 1 deletion vllm/lora/layers/fused_moe.py
@@ -49,6 +49,9 @@
assert not self.base_layer.use_ep, (
"EP support for Fused MoE LoRA is not implemented yet."
)
assert self.base_layer.quant_method.moe_quant_config is not None, (
"Fused MoE LoRA requires the base layer to have a moe_quant_config."
)
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.device = _get_lora_device(base_layer)
@@ -57,7 +60,7 @@
self._w13_slices = 2 if base_layer.moe_config.is_act_and_mul else 1
self._inject_lora_into_fused_moe()

def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None]:
def _normalize_keys(self, config: dict[str, int]) -> dict[str, int]:
normalized_config = {}
for key, value in config.items():
if key.islower():
@@ -107,6 +110,9 @@
moe_intermediate_size=intermediate_size, # lora_b_stacked.shape[-2],
)
else: # fall back to the default config
assert layer.quant_method.moe_quant_config is not None, (
"Fused MoE LoRA requires the base layer to have a moe_quant_config."
)
get_config_func = functools.partial(
try_get_optimal_moe_lora_config,
w1_shape=layer.w13_weight.size(),
@@ -117,14 +123,14 @@
M=M,
block_shape=layer.quant_method.moe_quant_config.block_shape,
)
shrink_config = get_config_func(

Check failure on line 126 in vllm/lora/layers/fused_moe.py (GitHub Actions / pre-commit): Incompatible types in assignment (expression has type "dict[str, int]", variable has type "dict[str, int | None]") [assignment]
op_type=f"fused_moe_lora_{op_prefix}_shrink"
)
expand_config = get_config_func(

Check failure on line 129 in vllm/lora/layers/fused_moe.py (GitHub Actions / pre-commit): Incompatible types in assignment (expression has type "dict[str, int]", variable has type "dict[str, int | None]") [assignment]
op_type=f"fused_moe_lora_{op_prefix}_expand"
)
shrink_config = self._normalize_keys(shrink_config)

Check failure on line 132 in vllm/lora/layers/fused_moe.py (GitHub Actions / pre-commit): Argument 1 to "_normalize_keys" of "FusedMoEWithLoRA" has incompatible type "dict[str, int | None]"; expected "dict[str, int]" [arg-type]
Check failure on line 132 in vllm/lora/layers/fused_moe.py (GitHub Actions / pre-commit): Incompatible types in assignment (expression has type "dict[str, int]", variable has type "dict[str, int | None]") [assignment]
expand_config = self._normalize_keys(expand_config)

Check failure on line 133 in vllm/lora/layers/fused_moe.py (GitHub Actions / pre-commit): Argument 1 to "_normalize_keys" of "FusedMoEWithLoRA" has incompatible type "dict[str, int | None]"; expected "dict[str, int]" [arg-type]
Check failure on line 133 in vllm/lora/layers/fused_moe.py (GitHub Actions / pre-commit): Incompatible types in assignment (expression has type "dict[str, int]", variable has type "dict[str, int | None]") [assignment]
return shrink_config, expand_config

def _inject_lora_into_fused_moe(self):
@@ -139,7 +145,7 @@
m_fused_moe_fn = self.base_layer.quant_method.moe_kernel
# Don't let the kernel own shared experts so the runner can
# overlap them with routed experts via a separate CUDA stream.
m_fused_moe_fn.shared_experts = None

Check failure on line 148 in vllm/lora/layers/fused_moe.py (GitHub Actions / pre-commit): Item "None" of "FusedMoEKernel | None" has no attribute "shared_experts" [union-attr]
else:
# Create a new modular kernel via select_gemm_impl.
# Don't pass shared_experts to the kernel so the runner can
@@ -152,13 +158,13 @@
),
)

if quant_config.use_mxfp4_w4a16:

Check failure on line 161 in vllm/lora/layers/fused_moe.py (GitHub Actions / pre-commit): Item "None" of "FusedMoEQuantConfig | None" has no attribute "use_mxfp4_w4a16" [union-attr]
assert isinstance(
m_fused_moe_fn.impl.fused_experts,

Check failure on line 163 in vllm/lora/layers/fused_moe.py (GitHub Actions / pre-commit): Item "None" of "FusedMoEKernel | None" has no attribute "impl" [union-attr]
(MarlinExperts, UnfusedOAITritonExperts),
)
else:
assert isinstance(m_fused_moe_fn.impl.fused_experts, TritonExperts)

Check failure on line 167 in vllm/lora/layers/fused_moe.py (GitHub Actions / pre-commit): Item "None" of "FusedMoEKernel | None" has no attribute "impl" [union-attr]

def fwd_decorator(layer, func):
def wrapper(*args, **kwargs):
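The pre-commit failures recorded above are mypy reports: members typed `FusedMoEKernel | None` and `FusedMoEQuantConfig | None` are used without narrowing (`[union-attr]`), and locals annotated `dict[str, int | None]` receive the helper's new `dict[str, int]` return type (`[assignment]`/`[arg-type]`). Purely as a generic illustration of how such reports are usually silenced, and not the fix applied in this diff, narrowing the Optional once and keeping local annotations in sync resolves both kinds of error:

class Kernel:
    shared_experts: object | None = None


class Layer:
    moe_kernel: Kernel | None = None


def configure(layer: Layer) -> dict[str, int]:
    # Narrow the Optional explicitly so mypy stops reporting [union-attr].
    kernel = layer.moe_kernel
    assert kernel is not None, "modular kernel must be created first"
    kernel.shared_experts = None  # ok: kernel is now typed as Kernel

    # Annotate the local to match what the helper actually returns, avoiding
    # the dict[str, int] vs dict[str, int | None] [assignment] mismatch.
    config: dict[str, int] = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64}
    return config


layer = Layer()
layer.moe_kernel = Kernel()
print(configure(layer))  # {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}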
2 changes: 1 addition & 1 deletion vllm/lora/layers/utils.py
@@ -91,7 +91,7 @@ def try_get_optimal_moe_lora_config(
dtype: str | None,
M: int,
block_shape: list[int] | None = None,
) -> dict[str, int | None]:
) -> dict[str, int]:
config = try_get_optimal_moe_config(
w1_shape, w2_shape, top_k, dtype, M, block_shape
).copy()
5 changes: 5 additions & 0 deletions vllm/lora/lora_model.py
@@ -196,6 +196,11 @@ def check_unexpected_modules(modules: dict):
from tensorizer import TensorDeserializer

tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
if tensorizer_config.tensorizer_dir is None:
raise ValueError(
"'PEFTHelper.from_local_dir' expects 'tensorizer_dir' "
"in 'tensorizer_config_dict', but it was not found."
)
lora_tensor_path = os.path.join(
tensorizer_config.tensorizer_dir, "adapter_model.tensors"
)
67 changes: 24 additions & 43 deletions vllm/lora/model_manager.py
@@ -3,7 +3,7 @@

import math
from collections.abc import Callable
from typing import TypeVar
from typing import TypeVar, cast

import regex as re
import torch
@@ -30,8 +30,8 @@
replace_submodule,
)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
from vllm.model_executor.models.interfaces import is_pooling_model
from vllm.model_executor.models import SupportsLoRA, SupportsMultiModal
from vllm.model_executor.models.interfaces import is_pooling_model, supports_multimodal
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import PPMissingLayer
from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -45,6 +45,12 @@
DEFAULT_LANGUAGE_WRAPPER_KEY = "language_model"


class SupportsLoRAModel(nn.Module, SupportsLoRA): ...


class SupportsLoRAMultiModalModel(SupportsLoRAModel, SupportsMultiModal): ...


class AdapterLRUCache(LRUCache[int, T]):
def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
super().__init__(capacity)
@@ -61,13 +67,13 @@ class LoRAModelManager:

def __init__(
self,
model: SupportsLoRA,
model: SupportsLoRAModel,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
device: torch.device,
vllm_config: VllmConfig | None = None,
vllm_config: VllmConfig,
):
"""Create a LoRAModelManager and adapter for a given model.

@@ -80,20 +86,26 @@ def __init__(
vocab_size: the vocab size of the model.
lora_config: the LoRA configuration.
"""
self.model: SupportsLoRA = model
self.model: SupportsLoRAModel = model
self.supported_lora_modules = get_supported_lora_modules(self.model)
assert self.supported_lora_modules, (
f"No supported LoRA modules found in {self.model.__class__.__name__}."
)

self._registered_adapters: dict[int, LoRAModel] = {}
# Dict instead of a set for compatibility with LRUCache.
self._active_adapters: dict[int, None] = {}
assert self.capacity >= self.lora_slots, (
f"The capacity of the manager ({self.capacity=}) must be "
f"greater than or equal to the number of LoRA slots ({self.lora_slots=})."
)
self._registered_adapters = AdapterLRUCache[LoRAModel](
self.capacity, self.deactivate_adapter
)
self._active_adapters = AdapterLRUCache[None](
self.lora_slots, self._deactivate_adapter
)
self.adapter_type = "LoRA"
self.lora_config = lora_config
self.device = device
self.max_num_seqs = max_num_seqs
assert self.capacity >= self.lora_slots
self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
self.lora_index_to_id: list[int | None] = [None] * self.lora_slots
self.vocab_size = vocab_size
@@ -786,40 +798,9 @@ def get_adapter(self, adapter_id: int) -> LoRAModel | None:
return self._registered_adapters.get(adapter_id)


class LoRALRUCache(AdapterLRUCache[LoRAModel]):
def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], bool]):
super().__init__(capacity, deactivate_lora_fn)


class LRUCacheLoRAModelManager(LoRAModelManager):
"""A model manager that manages multiple LoRAs with LRU cache."""

def __init__(
self,
model: nn.Module,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
device: torch.device,
vllm_config: VllmConfig | None = None,
):
super().__init__(
model,
max_num_seqs,
max_num_batched_tokens,
vocab_size,
lora_config,
device,
vllm_config,
)
self._registered_adapters: LoRALRUCache = LoRALRUCache(
self.capacity, self.deactivate_adapter
)
self._active_adapters: LoRALRUCache = LoRALRUCache(
self.lora_slots, self._deactivate_adapter
)

def list_adapters(self) -> dict[int, LoRAModel]:
"""List all registered LoRAModels."""
return dict(self._registered_adapters.cache)
@@ -879,7 +860,7 @@ def _pin_lora_in_gpu_cache(self, lora_id: int):


def create_lora_manager(
model: nn.Module,
model: SupportsLoRAModel,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
@@ -890,7 +871,7 @@
**kwargs,
) -> LoRAModelManager:
"""Create a LoRA adapter for a given model."""
if not isinstance(model, SupportsLoRA):
if not isinstance(model, SupportsLoRAModel):
raise ValueError(f"Model {type(model)} is not supported for LoRA.")
lora_manager = lora_manager_cls(
model=model,
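The removal of `LoRALRUCache` and of `LRUCacheLoRAModelManager.__init__` above works because `LoRAModelManager` itself now keeps `_registered_adapters` and `_active_adapters` in `AdapterLRUCache` instances whose eviction callback deactivates the adapter. A minimal standalone sketch of that callback-on-evict LRU pattern (the class and names below are illustrative):

from collections import OrderedDict
from collections.abc import Callable
from typing import Generic, TypeVar

T = TypeVar("T")


class EvictingLRUCache(Generic[T]):
    """Tiny LRU cache that invokes a callback before evicting the oldest key."""

    def __init__(self, capacity: int, on_evict: Callable[[int], object]):
        self.capacity = capacity
        self.on_evict = on_evict
        self.cache: OrderedDict[int, T] = OrderedDict()

    def put(self, key: int, value: T) -> None:
        if key in self.cache:
            self.cache.move_to_end(key)
        self.cache[key] = value
        if len(self.cache) > self.capacity:
            oldest = next(iter(self.cache))
            self.on_evict(oldest)  # e.g. deactivate_adapter(lora_id)
            self.cache.pop(oldest)

    def get(self, key: int) -> T | None:
        if key in self.cache:
            self.cache.move_to_end(key)
            return self.cache[key]
        return None


# Evicting a "registered adapter" triggers the deactivate callback.
evicted: list[int] = []
cache = EvictingLRUCache[str](capacity=2, on_evict=evicted.append)
for lora_id in (1, 2, 3):
    cache.put(lora_id, f"adapter-{lora_id}")
print(evicted)            # [1]
print(list(cache.cache))  # [2, 3]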
5 changes: 5 additions & 0 deletions vllm/lora/peft_helper.py
@@ -91,6 +91,11 @@ def from_local_dir(
tensorizer_args = tensorizer_config._construct_tensorizer_args()
from tensorizer.stream_io import open_stream

if tensorizer_config.tensorizer_dir is None:
raise ValueError(
"'PEFTHelper.from_local_dir' expects 'tensorizer_dir' "
"in 'tensorizer_config_dict', but it was not found."
)
lora_config_path = os.path.join(
tensorizer_config.tensorizer_dir, "adapter_config.json"
)
19 changes: 12 additions & 7 deletions vllm/lora/utils.py
@@ -168,13 +168,18 @@ def parse_fine_tuned_lora_name(
# LoRA weight qualified name usually starts with `base_model.model.`,
# so we remove the prefix `base_model.model.` to make the following
# mapping correctly.
if name.startswith("base_model.model."):
name = name.replace("base_model.model.", "")
name = weights_mapper._map_name(name) if weights_mapper else name
# recover the prefix `base_model.model.`
name = "base_model.model." + name
else:
name = weights_mapper._map_name(name) if weights_mapper else name
if weights_mapper is not None:
prefix = ""
if name.startswith("base_model.model."):
prefix = "base_model.model."
name = name.removeprefix(prefix)
mapped_name = weights_mapper._map_name(name)
if mapped_name is None:
raise ValueError(f"Cannot map the lora weight name: {name}")
name = mapped_name
if prefix:
# recover the prefix `base_model.model.`
name = prefix + name

# In some situations, we may not start with `base_model.model.`.
# If we don't (e.g., ibm-granite/granite-speech-3.3-8b),
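The rewritten block above strips an optional `base_model.model.` prefix, applies the weights mapper, raises if the mapper returns `None`, and restores the prefix only when it was present. A standalone sketch of that strip-map-restore flow with a stand-in mapper (the function and mapping below are illustrative):

def map_lora_weight_name(name: str, mapper=None) -> str:
    # Stand-in for the mapping step of parse_fine_tuned_lora_name.
    if mapper is None:
        return name

    prefix = ""
    if name.startswith("base_model.model."):
        prefix = "base_model.model."
        name = name.removeprefix(prefix)

    mapped = mapper(name)  # e.g. WeightsMapper._map_name(name)
    if mapped is None:
        raise ValueError(f"Cannot map the lora weight name: {name}")

    # Restore the prefix (a no-op when it was never there).
    return prefix + mapped


rename = lambda n: n.replace("visual.", "vision_tower.")
print(map_lora_weight_name("base_model.model.visual.blocks.0.attn.qkv", rename))
# base_model.model.vision_tower.blocks.0.attn.qkv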
7 changes: 5 additions & 2 deletions vllm/lora/worker_manager.py
@@ -44,6 +44,9 @@ def __init__(
vllm_config.scheduler_config.max_num_batched_tokens
)
self.vocab_size = vllm_config.model_config.get_vocab_size()
assert vllm_config.lora_config is not None, (
"LoRA config must be provided in vLLM config when using LoRA support."
)
self.lora_config = vllm_config.lora_config

# Use get_text_config() in case of multimodal models
@@ -69,7 +72,7 @@ def is_enabled(self) -> bool:
def create_lora_manager(
self,
model: torch.nn.Module,
vllm_config: VllmConfig | None = None,
vllm_config: VllmConfig,
) -> Any:
lora_manager = create_lora_manager(
model,
@@ -222,7 +225,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
def create_lora_manager(
self,
model: torch.nn.Module,
vllm_config: VllmConfig | None = None,
vllm_config: VllmConfig,
) -> Any:
lora_manager = create_lora_manager(
model,
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import abstractmethod
from typing import cast

import torch
