Skip to content

Commit 0aba11b

Browse files
committed
add only_llm_forward
1 parent 543e0f4 commit 0aba11b

2 files changed

Lines changed: 12 additions & 7 deletions

File tree

xtuner/v1/model/compose/qwen3_vl/modeling_qwen3_vl.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class Qwen3VLForConditionalGeneration(BaseComposeModel):
2828

2929
def __init__(self, config: Qwen3VLBaseConfig):
3030
super().__init__(config) # type: ignore[arg-type]
31+
self.only_llm_forward = config.only_llm_forward
3132

3233
# if type(self.language_model) is Qwen3MoE:
3334
# # TODO(YHC): This is a hack to make the language model compatible with HF
@@ -143,8 +144,9 @@ def forward(
143144
sequence_parallel_mesh = seq_ctx.sequence_parallel_mesh
144145

145146
inputs_embeds = self.language_model.embed_tokens(input_ids) # type: ignore
146-
147+
147148
if pixel_values is not None:
149+
assert self.only_llm_forward is False, "only_llm_forward is True, but pixel_values is not None. Please check your config setting."
148150
assert image_grid_thw is not None
149151
assert input_ids is not None
150152
visual_embeds, deepstack_visual_embeds = self.get_visual_features(pixel_values,
@@ -170,12 +172,13 @@ def forward(
170172
deepstack_visual_embeds = None
171173
visual_pos_masks = None
172174
else:
173-
pixel_values_dump = torch.randn(4, 1536, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
174-
image_grid_thw = torch.tensor([[1, 2, 2]], device=inputs_embeds.device)
175-
viusal_embeds, deepstack_visual_embeds = self.get_visual_features(pixel_values_dump, image_grid_thw)
176-
inputs_embeds = inputs_embeds + viusal_embeds.sum() * 0.0
177-
for deepstack_visual_embed in deepstack_visual_embeds:
178-
inputs_embeds = inputs_embeds + deepstack_visual_embed.sum() * 0.0
175+
if not self.only_llm_forward:
176+
pixel_values_dump = torch.randn(4, 1536, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
177+
image_grid_thw = torch.tensor([[1, 2, 2]], device=inputs_embeds.device)
178+
viusal_embeds, deepstack_visual_embeds = self.get_visual_features(pixel_values_dump, image_grid_thw)
179+
inputs_embeds = inputs_embeds + viusal_embeds.sum() * 0.0
180+
for deepstack_visual_embed in deepstack_visual_embeds:
181+
inputs_embeds = inputs_embeds + deepstack_visual_embed.sum() * 0.0
179182

180183
deepstack_visual_embeds = None
181184
visual_pos_masks = None

xtuner/v1/model/compose/qwen3_vl/qwen3_vl_config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ class Qwen3VLBaseConfig(BaseComposeConfig):
8787
freeze_vision: bool = False
8888
freeze_projector: bool = False
8989
freeze_language: bool = False
90+
# If true, skip the forward of vit+projector. Only enable when the whole training process is pure text task.
91+
only_llm_forward: bool = False
9092

9193
def build(self):
9294
from .modeling_qwen3_vl import Qwen3VLForConditionalGeneration

0 commit comments

Comments
 (0)