30 changes: 11 additions & 19 deletions docs/source/models/supported_models.md
@@ -434,8 +434,8 @@ See [this page](#generative-models) for more information on how to use generative models.
* ✅︎
* ✅︎
- * `Qwen2ForCausalLM`
* QwQ, Qwen2
* `Qwen/QwQ-32B-Preview`, `Qwen/Qwen2-7B-Instruct`, `Qwen/Qwen2-7B`, etc.
* Qwen2
* `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.
* ✅︎
* ✅︎
- * `Qwen2MoeForCausalLM`
@@ -665,6 +665,13 @@ On the other hand, modalities separated by `/` are mutually exclusive.

- e.g.: `T / I` means that the model supports text-only and image-only inputs, but not text-with-image inputs.

### ColQwen2VL

- **Model Name**: ColQwen2VL
- **Description**: Implements the ColQwen2 model for efficient document retrieval with vision-language capabilities. It is compatible with the `ColQwen2` class in Hugging Face Transformers and targets multimodal retrieval over combined text and image inputs.
- **Supported Modalities**: Text + Image
- **Example Use Cases**: Document retrieval (text-to-image retrieval) using embedding outputs; see the usage sketch after this list.
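
A minimal offline usage sketch, mirroring the example script added in this PR (the checkpoint name `vidore/colqwen2-v1.0-merged`, the chat-template strings, the illustrative query, and the embedding-output handling are taken from that example and are assumptions about the final API):

```python
from vllm import LLM

# Load the merged ColQwen2 checkpoint with vLLM's embedding task.
llm = LLM(model="vidore/colqwen2-v1.0-merged",
          task="embed",
          trust_remote_code=True)

# Text-only retrieval query; image documents would be passed via multi_modal_data.
prompt = ("<|im_start|>user\nWhat does the revenue chart show?<|im_end|>\n"
          "<|im_start|>assistant\n")
(output,) = llm.encode({"prompt": prompt})

# The pooled embedding is exposed on the pooling output.
print(len(output.outputs.embedding))
```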

See [this page](#multimodal-inputs) on how to pass multi-modal inputs to the model.

:::{important}
Expand Down Expand Up @@ -692,23 +699,8 @@ vllm serve Qwen/Qwen2-VL-7B-Instruct --limit-mm-per-prompt image=4
vLLM currently only supports adding LoRA to the language backbone of multimodal models.
:::

### Generative Models
### Other Models

See [this page](#generative-models) for more information on how to use generative models.

#### Text Generation (`--task generate`)

:::{list-table}
:widths: 25 25 15 20 5 5 5
:header-rows: 1

- * Architecture
* Models
* Inputs
* Example HF Models
* [LoRA](#lora-adapter)
* [PP](#distributed-serving)
* [V1](gh-issue:8779)
- * `AriaForConditionalGeneration`
* Aria
* T + I<sup>+</sup>
@@ -1011,7 +1003,7 @@ _________________

## Model Support Policy

At vLLM, we are committed to facilitating the integration and support of third-party models within our ecosystem. Our approach is designed to balance the need for robustness and the practical limitations of supporting a wide range of models. Here’s how we manage third-party model support:
At vLLM, we are committed to facilitating the integration and support of third-party models within our ecosystem. Our approach is designed to balance the need for robustness and the practical limitations of supporting a wide range of models. Here's how we manage third-party model support:

1. **Community-Driven Support**: We encourage community contributions for adding new models. When a user requests support for a new model, we welcome pull requests (PRs) from the community. These contributions are evaluated primarily on the sensibility of the output they generate, rather than strict consistency with existing implementations such as those in transformers. **Call for contribution:** PRs coming directly from model vendors are greatly appreciated!

33 changes: 33 additions & 0 deletions examples/offline_inference/vision_language_embedding.py
@@ -71,6 +71,38 @@ def run_e5_v(query: Query):
)


def run_colqwen2vlm(query: Query):
    # Build a Qwen2-VL-style chat prompt for either a text or an image query.
    if query["modality"] == "text":
        text = query["text"]
        prompt = f"<|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant\n"
        image = None
    elif query["modality"] == "image":
        text = "Describe the image."
        prompt = (
            "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
            "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
            f"{text}<|im_end|>\n"
            "<|im_start|>assistant\n")
        image = query["image"]
    else:
        modality = query['modality']
        raise ValueError(f"Unsupported query modality: '{modality}'")

    # Run the merged ColQwen2 checkpoint with vLLM's embedding task.
    llm = LLM(
        model="vidore/colqwen2-v1.0-merged",
        # model="vidore/colqwen2-1.0-hf-internal",
        task="embed",
        trust_remote_code=True,
        # dtype=torch.bfloat16,
    )

    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        image=image,
    )


def run_vlm2vec(query: Query):
if query["modality"] == "text":
text = query["text"]
@@ -150,6 +182,7 @@ def main(args: Namespace):
model_example_map = {
"e5_v": run_e5_v,
"vlm2vec": run_vlm2vec,
"colqwen2vlm": run_colqwen2vlm,
}

if __name__ == "__main__":
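For context, here is a hedged sketch of how the `ModelRequestData` returned by `run_colqwen2vlm` might be consumed. The example's `main()` driver is not shown in this diff, so the invocation below is an assumption modeled on the other `run_*` entries:

```python
# Hypothetical driver code; mirrors how the other run_* helpers are consumed.
req = run_colqwen2vlm({"modality": "text",
                       "text": "quarterly revenue table"})

inputs = {"prompt": req.prompt}
if req.image is not None:
    # Image queries attach the image as multimodal data.
    inputs["multi_modal_data"] = {"image": req.image}

(output,) = req.llm.encode(inputs)
print(f"Embedding length: {len(output.outputs.embedding)}")
```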
10 changes: 10 additions & 0 deletions tests/models/embedding/vision_language/test_colqwen2vl.py
@@ -0,0 +1,10 @@
# tests/models/embedding/vision_language/test_colqwen2vl.py

import torch
from vllm.model_executor.models.colqwen2_vl import ColQwen2VL

def test_colqwen2vl_embeddings():
    model = ColQwen2VL()
    dummy_input = torch.rand((1, 3, 224, 224))  # Example input
    embeddings = model(dummy_input)
    assert embeddings.shape == (1, 128), "Embedding size should be 128."
9 changes: 8 additions & 1 deletion tests/models/test_registry.py
@@ -20,7 +20,14 @@
from ..utils import fork_new_process_for_each_test
from .registry import HF_EXAMPLE_MODELS


from vllm.model_executor.models.colqwen2_vl import ColQwen2VL
from vllm.multimodal import MULTIMODAL_REGISTRY

def test_colqwen2vl_registration():
    assert 'ColQwen2VL' in MULTIMODAL_REGISTRY, "ColQwen2VL should be registered."
    model = MULTIMODAL_REGISTRY['ColQwen2VL']()
    assert isinstance(model, ColQwen2VL), "Failed to instantiate ColQwen2VL."

[Check failure — GitHub Actions / pre-commit, Ruff (E501): tests/models/test_registry.py:28:81: Line too long (82 > 80)]

@pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs())
def test_registry_imports(model_arch):
model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
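For reference, a registration check can also be written against `ModelRegistry.get_supported_archs()`, which the surrounding test file already uses; this is a sketch, and the architecture string `ColQwen2VL` is taken from this PR:

```python
from vllm import ModelRegistry


def test_colqwen2vl_arch_is_registered():
    # The architecture name should appear in vLLM's central model registry.
    assert "ColQwen2VL" in ModelRegistry.get_supported_archs()
```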