Skip to content

Commit 824fd36

Browse files
authored
[Bugfix][NPU] Add _model_forward for ModelRunner (vllm-project#505)
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
1 parent 6954d5b commit 824fd36

File tree

3 files changed

+33
-14
lines changed

3 files changed

+33
-14
lines changed

vllm_omni/model_executor/stage_configs/npu/qwen3_omni_moe.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ stage_args:
7979
engine_output_type: audio # Final output: audio waveform
8080
gpu_memory_utilization: 0.1
8181
distributed_executor_backend: "mp"
82-
max_num_batched_tokens: 4096
82+
max_num_batched_tokens: 1000000
8383
hf_config_name: thinker_config
8484
engine_input_source: [1]
8585
custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav

vllm_omni/worker/npu/npu_ar_model_runner.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1199,21 +1199,11 @@ def _generate_process_reqs_hidden_states(self, maybe_padded_num_tokens,
11991199
input_ids, positions,
12001200
intermediate_tensors,
12011201
inputs_embeds):
1202-
model_kwargs_extra = self._build_model_kwargs_extra()
1203-
1204-
runtime_info = model_kwargs_extra.get("runtime_additional_information", [])
1205-
if runtime_info:
1206-
for i, info in enumerate(runtime_info):
1207-
if info:
1208-
logger.debug(f"[OMNI] req[{i}] runtime_additional_information keys: {list(info.keys())}")
1209-
1210-
assert self.model is not None
1211-
hidden_states = self.model(input_ids=input_ids,
1202+
hidden_states = self._model_forward(input_ids=input_ids,
12121203
positions=positions,
12131204
intermediate_tensors=intermediate_tensors,
12141205
inputs_embeds=inputs_embeds,
1215-
**self._init_model_kwargs(),
1216-
**model_kwargs_extra)
1206+
**self._init_model_kwargs())
12171207

12181208
forward_context = get_forward_context()
12191209
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL \

vllm_omni/worker/npu/npu_model_runner.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

44
import math
5-
from typing import TYPE_CHECKING, cast
5+
from typing import TYPE_CHECKING, Any, cast
66

77
import numpy as np
88
import torch
@@ -636,3 +636,32 @@ def _collect_additional_information_for_prefill(
636636
)
637637
start_offset = int(self.query_start_loc.cpu[req_index])
638638
self.inputs_embeds[start_offset : start_offset + overlay_len].copy_(src)
639+
640+
def _model_forward(
641+
self,
642+
input_ids: torch.Tensor | None = None,
643+
positions: torch.Tensor | None = None,
644+
intermediate_tensors: IntermediateTensors | None = None,
645+
inputs_embeds: torch.Tensor | None = None,
646+
**model_kwargs: dict[str, Any],
647+
):
648+
"""Inject omni-specific kwargs into forward and cache model output"""
649+
model_kwargs_extra = self._build_model_kwargs_extra()
650+
651+
runtime_info = model_kwargs_extra.get("runtime_additional_information", [])
652+
if runtime_info:
653+
for i, info in enumerate(runtime_info):
654+
if info:
655+
logger.debug(f"[OMNI] req[{i}] runtime_additional_information keys: {list(info.keys())}")
656+
657+
model_output = super()._model_forward(
658+
input_ids=input_ids,
659+
positions=positions,
660+
intermediate_tensors=intermediate_tensors,
661+
inputs_embeds=inputs_embeds,
662+
**model_kwargs,
663+
**model_kwargs_extra,
664+
)
665+
# Cache model output so later sample_tokens can consume multimodal results.
666+
self._omni_last_model_output = model_output
667+
return model_output

0 commit comments

Comments (0)