18 changes: 9 additions & 9 deletions examples/offline_inference/audio_language.py
@@ -199,13 +199,6 @@ def main(args):
     engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
     llm = LLM(**engine_args)
 
-    # To maintain code compatibility in this script, we add LoRA here.
-    # You can also add LoRA using:
-    # llm.generate(prompts, lora_request=lora_request,...)
-    if req_data.lora_requests:
-        for lora_request in req_data.lora_requests:
-            llm.llm_engine.add_lora(lora_request=lora_request)
-
     # We set temperature to 0.2 so that outputs can be different
     # even when all prompts are identical when running batch inference.
     sampling_params = SamplingParams(temperature=0.2,
@@ -226,8 +219,15 @@ def main(args):
     if args.num_prompts > 1:
         # Batch inference
         inputs = [inputs] * args.num_prompts
 
-    outputs = llm.generate(inputs, sampling_params=sampling_params)
+    # Add LoRA request if applicable
+    lora_request = (req_data.lora_requests *
+                    args.num_prompts if req_data.lora_requests else None)
+
+    outputs = llm.generate(
+        inputs,
+        sampling_params=sampling_params,
+        lora_request=lora_request,
+    )
 
     for o in outputs:
         generated_text = o.outputs[0].text
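Reviewer note: this PR moves LoRA activation from the engine-level llm.llm_engine.add_lora(...) call to the lora_request argument of llm.generate(...). A minimal self-contained sketch of the per-call pattern, assuming a base model served with LoRA enabled ("BASE_MODEL" and "ADAPTER_PATH" are hypothetical placeholders, not paths from this PR):

# Sketch: per-call LoRA with vLLM's offline API.
# "BASE_MODEL" and "ADAPTER_PATH" are hypothetical placeholders.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(model="BASE_MODEL", enable_lora=True)
sampling_params = SamplingParams(temperature=0.2, max_tokens=64)

# LoRARequest(name, unique_id, path); passing it to generate()
# applies the adapter for just this call, so no add_lora() is needed.
lora_request = LoRARequest("speech", 1, "ADAPTER_PATH")
outputs = llm.generate(["Hello"],
                       sampling_params=sampling_params,
                       lora_request=lora_request)
print(outputs[0].outputs[0].text)

Passing the request per call keeps the adapter scoped to that generate() invocation instead of mutating engine state up front.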
34 changes: 24 additions & 10 deletions examples/offline_inference/vision_language.py
@@ -8,6 +8,7 @@
 """
 import os
 import random
+from contextlib import contextmanager
 from dataclasses import asdict
 from typing import NamedTuple, Optional

@@ -1055,6 +1056,20 @@ def apply_image_repeat(image_repeat_prob, num_prompts, data,
     return inputs
 
 
+@contextmanager
+def time_counter(enable: bool):
+    if enable:
+        import time
+        start_time = time.time()
+        yield
+        elapsed_time = time.time() - start_time
+        print("-" * 50)
+        print("-- generate time = {}".format(elapsed_time))
+        print("-" * 50)
+    else:
+        yield
+
+
 def main(args):
     model = args.model_type
     if model not in model_example_map:
@@ -1113,17 +1128,16 @@ def main(args):
         },
     } for i in range(args.num_prompts)]
 
-    if args.time_generate:
-        import time
-        start_time = time.time()
-        outputs = llm.generate(inputs, sampling_params=sampling_params)
-        elapsed_time = time.time() - start_time
-        print("-" * 50)
-        print("-- generate time = {}".format(elapsed_time))
-        print("-" * 50)
-
-    else:
-        outputs = llm.generate(inputs, sampling_params=sampling_params)
+    # Add LoRA request if applicable
+    lora_request = (req_data.lora_requests *
+                    args.num_prompts if req_data.lora_requests else None)
+
+    with time_counter(args.time_generate):
+        outputs = llm.generate(
+            inputs,
+            sampling_params=sampling_params,
+            lora_request=lora_request,
+        )
 
     print("-" * 50)
     for o in outputs:
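Reviewer note: the new time_counter helper collapses the duplicated timed/untimed llm.generate branches into one call site. A standalone usage sketch of the same pattern, using only the standard library:

# Usage sketch for a time_counter-style context manager.
from contextlib import contextmanager
import time

@contextmanager
def time_counter(enable: bool):
    if enable:
        start_time = time.time()
        yield
        print("-- generate time = {}".format(time.time() - start_time))
    else:
        yield

with time_counter(True):
    sum(range(1_000_000))  # stand-in for llm.generate(...)

Because the disabled branch still yields, the body runs exactly once either way; only the timing side effects are conditional.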
12 changes: 4 additions & 8 deletions examples/offline_inference/vision_language_multi_image.py
@@ -661,13 +661,6 @@ def run_generate(model, question: str, image_urls: list[str],
     engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
     llm = LLM(**engine_args)
 
-    # To maintain code compatibility in this script, we add LoRA here.
-    # You can also add LoRA using:
-    # llm.generate(prompts, lora_request=lora_request,...)
-    if req_data.lora_requests:
-        for lora_request in req_data.lora_requests:
-            llm.llm_engine.add_lora(lora_request=lora_request)
-
     sampling_params = SamplingParams(temperature=0.0,
                                      max_tokens=256,
                                      stop_token_ids=req_data.stop_token_ids)
@@ -679,7 +672,9 @@ def run_generate(model, question: str, image_urls: list[str],
                 "image": req_data.image_data
             },
         },
-        sampling_params=sampling_params)
+        sampling_params=sampling_params,
+        lora_request=req_data.lora_requests,
+    )
 
     print("-" * 50)
     for o in outputs:
@@ -724,6 +719,7 @@ def run_chat(model: str, question: str, image_urls: list[str],
         }],
         sampling_params=sampling_params,
         chat_template=req_data.chat_template,
+        lora_request=req_data.lora_requests,
     )
 
     print("-" * 50)
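Reviewer note: llm.chat(...) accepts lora_request the same way llm.generate(...) does, which is what the run_chat change relies on. A short sketch of that call shape; "BASE_MODEL" and "ADAPTER_PATH" are hypothetical placeholders:

# Sketch: passing a LoRA adapter through llm.chat().
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(model="BASE_MODEL", enable_lora=True)
messages = [{"role": "user", "content": "Describe the images."}]
outputs = llm.chat(
    messages,
    sampling_params=SamplingParams(temperature=0.0, max_tokens=256),
    lora_request=LoRARequest("vision", 1, "ADAPTER_PATH"),
)
print(outputs[0].outputs[0].text)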
4 changes: 2 additions & 2 deletions tests/models/decoder_only/vision_language/test_models.py
@@ -433,8 +433,8 @@
         max_model_len=4096,
         max_num_seqs=2,
         task="generate",
-        # use eager mode for hf runner since phi3v didn't work with flash_attn
-        hf_model_kwargs={"_attn_implementation": "eager"},
+        # use sdpa mode for hf runner since phi3v didn't work with flash_attn
+        hf_model_kwargs={"_attn_implementation": "sdpa"},
         use_tokenizer_eos=True,
         vllm_output_post_proc=model_utils.phi3v_vllm_to_hf_output,
         num_logprobs=10,
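Reviewer note: the _attn_implementation kwarg here is forwarded to the HF model constructor; switching from "eager" to "sdpa" selects PyTorch's scaled_dot_product_attention kernel instead of the pure-Python attention math. A roughly equivalent direct transformers call, with "MODEL_ID" as a placeholder:

# Sketch: selecting the attention backend in transformers.
# "MODEL_ID" is a hypothetical placeholder for a HF model id.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "MODEL_ID",
    attn_implementation="sdpa",  # alternatives: "eager", "flash_attention_2"
)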
97 changes: 80 additions & 17 deletions tests/models/decoder_only/vision_language/test_phi4mm.py
@@ -2,18 +2,22 @@
 
 import os
 import re
+from collections.abc import Sequence
+from typing import Optional
 
+import librosa
 import pytest
 from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer
 
 from vllm.assets.image import ImageAsset
 from vllm.lora.request import LoRARequest
 from vllm.multimodal.image import rescale_image_size
 from vllm.platforms import current_platform
 from vllm.sequence import SampleLogprobs
 
-from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
+from ....conftest import (IMAGE_ASSETS, HfRunner, PromptAudioInput,
+                          PromptImageInput, VllmRunner)
 from ....utils import large_gpu_test
 from ...utils import check_logprobs_close
@@ -29,6 +33,8 @@
 # Since the vision-lora and speech-lora co-exist with the base model,
 # we have to manually specify the path of the lora weights.
 vision_lora_path = os.path.join(model_path, "vision-lora")
+speech_question = os.path.join(model_path, "examples",
+                               "what_is_shown_in_this_image.wav")
 models = [model_path]


@@ -64,7 +70,8 @@ def vllm_to_hf_output(vllm_output: tuple[list[int], str,
 def run_test(
     hf_runner: type[HfRunner],
     vllm_runner: type[VllmRunner],
-    inputs: list[tuple[list[str], PromptImageInput]],
+    inputs: Sequence[tuple[list[str], PromptImageInput,
+                           Optional[PromptAudioInput]]],
     model: str,
     *,
     max_model_len: int,
@@ -104,28 +111,49 @@ def run_test(
         enforce_eager=True,
     ) as vllm_model:
         lora_request = LoRARequest("vision", 1, vision_lora_path)
-        vllm_model.model.llm_engine.add_lora(lora_request=lora_request)
         vllm_outputs_per_case = [
             vllm_model.generate_greedy_logprobs(prompts,
                                                 max_tokens,
                                                 num_logprobs=num_logprobs,
-                                                images=images)
-            for prompts, images in inputs
+                                                images=images,
+                                                audios=audios,
+                                                lora_request=lora_request)
+            for prompts, images, audios in inputs
         ]
 
-    # use eager mode for hf runner, since phi3_v didn't work with flash_attn
-    hf_model_kwargs = {"_attn_implementation": "eager"}
+    hf_model_kwargs = {"_attn_implementation": "sdpa"}
     with hf_runner(model, dtype=dtype,
                    model_kwargs=hf_model_kwargs) as hf_model:
-        eos_token_id = hf_model.processor.tokenizer.eos_token_id
+
+        hf_processor = hf_model.processor
+        eos_token_id = hf_processor.tokenizer.eos_token_id
+
+        def patch_hf_processor(*args,
+                               text="",
+                               images=None,
+                               audio=None,
+                               sampling_rate=None,
+                               **kwargs):
+            audios = None
+            if audio is not None and sampling_rate is not None:
+                audios = [(audio, sampling_rate)]
+            return hf_processor(*args,
+                                text=text,
+                                images=images,
+                                audios=audios,
+                                **kwargs)
+
+        hf_model.processor = patch_hf_processor
+
         hf_outputs_per_case = [
             hf_model.generate_greedy_logprobs_limit(prompts,
                                                     max_tokens,
                                                     num_logprobs=num_logprobs,
                                                     images=images,
+                                                    audios=audios,
                                                     eos_token_id=eos_token_id,
                                                     num_logits_to_keep=0)
-            for prompts, images in inputs
+            for prompts, images, audios in inputs
         ]

for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
@@ -138,8 +166,6 @@ def run_test(
         )
 
 
-# Since we use _attn_implementation="eager" for hf_runner, there is more
-# significant numerical difference. The basic `logprobs=5` fails to pass.
 @pytest.mark.parametrize("model", models)
 @pytest.mark.parametrize(
     "size_factors",
@@ -151,7 +177,7 @@ def run_test(
         # Single-scale, batched
         [1.0, 1.0, 1.0],
         # Multi-scale
-        [0.7, 0.75, 1.0],
+        [0.25, 0.5, 1.0],
     ],
 )
 @pytest.mark.parametrize("dtype", [target_dtype])
@@ -166,6 +192,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
     inputs_per_image = [(
         [prompt for _ in size_factors],
         [rescale_image_size(image, factor) for factor in size_factors],
+        None,
     ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
 
     run_test(
@@ -201,17 +228,18 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
 @pytest.mark.parametrize("max_model_len", [10000])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [10])
-@pytest.mark.xfail(
-    reason="Phi-4-MM multi-image inference is divergent with hf model.")
 def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
                              size_factors, dtype: str, max_model_len: int,
                              max_tokens: int, num_logprobs: int) -> None:
     images = [asset.pil_image for asset in image_assets]
 
     inputs_per_case = [
-        ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
-         [[rescale_image_size(image, factor) for image in images]
-          for factor in size_factors])
+        (
+            [HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
+            [[rescale_image_size(image, factor) for image in images]
+             for factor in size_factors],
+            None,
+        ),
     ]
 
     run_test(
@@ -226,3 +254,38 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
         mm_limit=2,
         tensor_parallel_size=1,
     )
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("dtype", [target_dtype])
+@pytest.mark.parametrize("max_model_len", [10000])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [10])
+def test_vision_speech_models(hf_runner, vllm_runner, model, dtype: str,
+                              max_model_len: int, max_tokens: int,
+                              num_logprobs: int) -> None:
+
+    # use the example speech question so that the model outputs are reasonable
+    audio = librosa.load(speech_question, sr=None)
+    image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
+
+    inputs_vision_speech = [
+        (
+            ["<|user|><|image_1|><|audio_1|><|end|><|assistant|>"],
+            [image],
+            [audio],
+        ),
+    ]
+
+    run_test(
+        hf_runner,
+        vllm_runner,
+        inputs_vision_speech,
+        model,
+        dtype=dtype,
+        max_model_len=max_model_len,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        mm_limit=1,
+        tensor_parallel_size=1,
+    )
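Reviewer note: in test_vision_speech_models, librosa.load(speech_question, sr=None) returns a (samples, sampling_rate) tuple, which matches the audio format vLLM expects under multi_modal_data. A rough sketch of the equivalent direct offline call; the model id and "WAV_PATH" are assumptions, not pinned by this diff:

# Sketch: one vision+speech prompt via vLLM's multimodal inputs.
import librosa
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

# "WAV_PATH" is a hypothetical local speech file; sr=None keeps the native rate.
audio = librosa.load("WAV_PATH", sr=None)  # -> (samples, sampling_rate)
image = ImageAsset("cherry_blossom").pil_image.convert("RGB")

llm = LLM(model="microsoft/Phi-4-multimodal-instruct",  # assumed model id
          trust_remote_code=True,
          max_model_len=10000,
          limit_mm_per_prompt={"image": 1, "audio": 1})

outputs = llm.generate(
    {
        "prompt": "<|user|><|image_1|><|audio_1|><|end|><|assistant|>",
        "multi_modal_data": {"image": image, "audio": audio},
    },
    sampling_params=SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)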