@@ -28,6 +28,7 @@ class Qwen3VLForConditionalGeneration(BaseComposeModel):
2828
2929 def __init__ (self , config : Qwen3VLBaseConfig ):
3030 super ().__init__ (config ) # type: ignore[arg-type]
31+ self .only_llm_forward = config .only_llm_forward
3132
3233 # if type(self.language_model) is Qwen3MoE:
3334 # # TODO(YHC): This is a hack to make the language model compatible with HF
@@ -143,8 +144,9 @@ def forward(
143144 sequence_parallel_mesh = seq_ctx .sequence_parallel_mesh
144145
145146 inputs_embeds = self .language_model .embed_tokens (input_ids ) # type: ignore
146-
147+
147148 if pixel_values is not None :
149+ assert self .only_llm_forward is False , "only_llm_forward is True, but pixel_values is not None. Please check your config setting."
148150 assert image_grid_thw is not None
149151 assert input_ids is not None
150152 visual_embeds , deepstack_visual_embeds = self .get_visual_features (pixel_values ,
@@ -170,12 +172,13 @@ def forward(
170172 deepstack_visual_embeds = None
171173 visual_pos_masks = None
172174 else :
173- pixel_values_dump = torch .randn (4 , 1536 , device = inputs_embeds .device , dtype = inputs_embeds .dtype )
174- image_grid_thw = torch .tensor ([[1 , 2 , 2 ]], device = inputs_embeds .device )
175- viusal_embeds , deepstack_visual_embeds = self .get_visual_features (pixel_values_dump , image_grid_thw )
176- inputs_embeds = inputs_embeds + viusal_embeds .sum () * 0.0
177- for deepstack_visual_embed in deepstack_visual_embeds :
178- inputs_embeds = inputs_embeds + deepstack_visual_embed .sum () * 0.0
175+ if not self .only_llm_forward :
176+ pixel_values_dump = torch .randn (4 , 1536 , device = inputs_embeds .device , dtype = inputs_embeds .dtype )
177+ image_grid_thw = torch .tensor ([[1 , 2 , 2 ]], device = inputs_embeds .device )
178+ viusal_embeds , deepstack_visual_embeds = self .get_visual_features (pixel_values_dump , image_grid_thw )
179+ inputs_embeds = inputs_embeds + viusal_embeds .sum () * 0.0
180+ for deepstack_visual_embed in deepstack_visual_embeds :
181+ inputs_embeds = inputs_embeds + deepstack_visual_embed .sum () * 0.0
179182
180183 deepstack_visual_embeds = None
181184 visual_pos_masks = None
0 commit comments