Skip to content

Commit 696259c

Browse files
[Core] Automatically cast multi-modal input dtype (#18756)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 6b6d496 commit 696259c

16 files changed

Lines changed: 91 additions & 44 deletions

vllm/model_executor/models/deepseek_vl2.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -210,9 +210,7 @@ def _call_hf_processor(
210210
dict(prompt=prompt, **mm_data),
211211
mm_kwargs,
212212
)
213-
target_dtype = self.info.ctx.model_config.dtype
214-
pixel_values = processed_outputs.pop("pixel_values").to(
215-
target_dtype)
213+
pixel_values = processed_outputs["pixel_values"]
216214
# split pixel values into patches corresponding to each image
217215
images_spatial_crop = processed_outputs["images_spatial_crop"]
218216
patches_per_image = [

vllm/model_executor/models/gemma3_mm.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -263,11 +263,6 @@ def _call_hf_processor(
263263
mm_data,
264264
mm_kwargs,
265265
)
266-
if "pixel_values" in processed_outputs:
267-
# Cast pixel values to model dtype already here,
268-
# so we need to transfer less data to the GPU
269-
processed_outputs["pixel_values"] = processed_outputs[
270-
"pixel_values"].to(self.info.ctx.model_config.dtype)
271266

272267
# HF processor pops the `num_crops` kwarg, which is needed by vLLM
273268
if (images := mm_data.get("images")) is not None:

vllm/multimodal/inputs.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -746,11 +746,17 @@ def as_kwargs(
746746
batched_inputs: BatchedTensorInputs,
747747
*,
748748
device: torch.types.Device,
749+
dtype: Optional[torch.dtype] = None,
749750
) -> BatchedTensorInputs:
750751
json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)
751752

753+
def maybe_cast_dtype(x: torch.Tensor):
754+
# This mimics the behavior of transformers.BatchFeature.to: only floating-point tensors are cast to `dtype`; all other tensors are returned unchanged
755+
return x.to(dtype=dtype) if x.is_floating_point() else x
756+
752757
json_mapped = json_map_leaves(
753-
lambda x: x.to(device, non_blocking=True),
758+
# NOTE: Cast the dtype before sending it to device, so that less data needs to be transferred when the target dtype is narrower
759+
lambda x: maybe_cast_dtype(x).to(device=device, non_blocking=True),
754760
json_inputs,
755761
)
756762

vllm/spec_decode/draft_model_runner.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -294,8 +294,11 @@ def execute_model(
294294
inputs_embeds=None,
295295
positions=model_input.input_positions,
296296
intermediate_tensors=intermediate_tensors,
297-
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
298-
device=self.device),
297+
**MultiModalKwargs.as_kwargs(
298+
multi_modal_kwargs,
299+
dtype=self.model_runner.model_config.dtype,
300+
device=self.device,
301+
),
299302
**model_execute_kwargs,
300303
)
301304

vllm/v1/worker/gpu_model_runner.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -929,8 +929,11 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
929929
encoder_outputs = []
930930
for grouped_mm_inputs in grouped_mm_inputs_list:
931931
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
932-
batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
933-
device=self.device)
932+
batched_mm_inputs = MultiModalKwargs.as_kwargs(
933+
batched_mm_inputs,
934+
dtype=self.model_config.dtype,
935+
device=self.device,
936+
)
934937

935938
# Run the encoder.
936939
# `curr_group_outputs` is either of the following:
@@ -1874,7 +1877,10 @@ def profile_run(self) -> None:
18741877
batched_dummy_mm_inputs = MultiModalKwargs.batch(
18751878
[dummy_mm_kwargs] * max_num_mm_items)
18761879
batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
1877-
batched_dummy_mm_inputs, device=self.device)
1880+
batched_dummy_mm_inputs,
1881+
dtype=self.model_config.dtype,
1882+
device=self.device,
1883+
)
18781884

18791885
# Run multimodal encoder.
18801886
dummy_encoder_outputs = self.model.get_multimodal_embeddings(

vllm/v1/worker/tpu_model_runner.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -652,8 +652,11 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
652652
encoder_outputs = []
653653
for grouped_mm_inputs in grouped_mm_inputs_list:
654654
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
655-
batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
656-
device=self.device)
655+
batched_mm_inputs = MultiModalKwargs.as_kwargs(
656+
batched_mm_inputs,
657+
dtype=self.model_config.dtype,
658+
device=self.device,
659+
)
657660

658661
# Run the encoder.
659662
# `curr_group_outputs` is either of the following:
@@ -1435,8 +1438,11 @@ def _get_mm_dummy_batch(self, modality: str,
14351438

14361439
batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
14371440
batch_size)
1438-
return MultiModalKwargs.as_kwargs(batched_dummy_mm_inputs,
1439-
device=self.device)
1441+
return MultiModalKwargs.as_kwargs(
1442+
batched_dummy_mm_inputs,
1443+
dtype=self.model_config.dtype,
1444+
device=self.device,
1445+
)
14401446

14411447

14421448
def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]:

vllm/worker/cpu_enc_dec_model_runner.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,8 +297,11 @@ def execute_model(
297297
model_input.encoder_input_tokens,
298298
"encoder_positions":
299299
model_input.encoder_input_positions,
300-
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
301-
device=self.device),
300+
**MultiModalKwargs.as_kwargs(
301+
model_input.multi_modal_kwargs or {},
302+
dtype=self.model_config.dtype,
303+
device=self.device,
304+
),
302305
"intermediate_tensors":
303306
intermediate_tensors,
304307
}

vllm/worker/cpu_model_runner.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,10 @@ def execute_model(
628628
multimodal_kwargs = {}
629629
if model_input.multi_modal_kwargs is not None:
630630
multimodal_kwargs = MultiModalKwargs.as_kwargs(
631-
model_input.multi_modal_kwargs, device=self.device)
631+
model_input.multi_modal_kwargs,
632+
dtype=self.model_config.dtype,
633+
device=self.device,
634+
)
632635
execute_model_kwargs = {}
633636
if previous_hidden_states is not None:
634637
execute_model_kwargs.update(

vllm/worker/cpu_pooling_model_runner.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,11 @@ def execute_model(
5050
model_input.input_tokens,
5151
"positions":
5252
model_input.input_positions,
53-
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
54-
device=self.device),
53+
**MultiModalKwargs.as_kwargs(
54+
model_input.multi_modal_kwargs or {},
55+
dtype=self.model_config.dtype,
56+
device=self.device,
57+
),
5558
**cross_enc_kwargs,
5659
"intermediate_tensors":
5760
intermediate_tensors,

vllm/worker/enc_dec_model_runner.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -202,9 +202,13 @@ def execute_model(
202202
encoder_input_ids=model_input.encoder_input_tokens,
203203
encoder_positions=model_input.encoder_input_positions,
204204
intermediate_tensors=intermediate_tensors,
205-
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
206-
device=self.device),
207-
**seqlen_agnostic_kwargs)
205+
**MultiModalKwargs.as_kwargs(
206+
multi_modal_kwargs,
207+
dtype=self.model_config.dtype,
208+
device=self.device,
209+
),
210+
**seqlen_agnostic_kwargs,
211+
)
208212

209213
logits = self.model.compute_logits(hidden_or_intermediate_states,
210214
model_input.sampling_metadata)

0 commit comments

Comments
 (0)