Skip to content

Commit 00b5c05

Browse files
authored
[data] fix qwen2 omni plugin (#7875)
1 parent 1bd319d commit 00b5c05

1 file changed

Lines changed: 22 additions & 30 deletions

File tree

src/llamafactory/data/mm_plugin.py

Lines changed: 22 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -1602,34 +1602,30 @@ def process_messages(
16021602
processor: Optional["MMProcessor"],
16031603
) -> list[dict[str, str]]:
16041604
self._validate_input(processor, images, videos, audios)
1605+
num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
16051606
messages = deepcopy(messages)
1607+
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
1608+
1609+
merge_length = processor.image_processor.merge_size**2
1610+
use_audio_in_video = getattr(processor, "use_audio_in_video", False)
16061611
if self.expand_mm_tokens:
16071612
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
1613+
image_grid_thw = mm_inputs.get("image_grid_thw", [])
1614+
video_grid_thw = mm_inputs.get("video_grid_thw", [])
1615+
if "feature_attention_mask" in mm_inputs:
1616+
input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
1617+
audio_lengths = (input_lengths - 2) // 2 + 1
16081618
else:
16091619
mm_inputs = {}
1620+
image_grid_thw = [None] * len(images)
1621+
video_grid_thw = [None] * len(videos)
1622+
audio_lengths = [None] * len(audios)
16101623

1611-
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
1612-
num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0
1613-
use_audio_in_video = getattr(processor, "use_audio_in_video", False)
1614-
1615-
# get length or size from mm_inputs
1616-
if "feature_attention_mask" in mm_inputs:
1617-
input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
1618-
audio_lengths = (input_lengths - 2) // 2 + 1
1619-
1620-
if mm_inputs.get("image_grid_thw", None) is not None:
1621-
image_grid_thw = mm_inputs["image_grid_thw"]
1622-
merge_length = processor.image_processor.merge_size**2
1623-
1624-
if mm_inputs.get("video_grid_thw", None) is not None:
1625-
video_grid_thw = mm_inputs["video_grid_thw"]
1626-
merge_length = processor.image_processor.merge_size**2
1627-
1628-
if use_audio_in_video:
1629-
if audio_lengths is None:
1624+
if self.expand_mm_tokens and use_audio_in_video:
1625+
if "feature_attention_mask" not in mm_inputs:
16301626
raise ValueError("audio_lengths should exist when use_audio_in_video is `True`.")
16311627

1632-
if mm_inputs.get("video_grid_thw", None) is None:
1628+
if "video_grid_thw" not in mm_inputs:
16331629
raise ValueError("video_grid_thw should exist when use_audio_in_video is `True`.")
16341630

16351631
positions_list = []
@@ -1653,11 +1649,9 @@ def process_messages(
16531649
if num_image_tokens >= len(images):
16541650
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
16551651

1656-
image_token_replace_length = image_grid_thw[num_image_tokens].prod() // merge_length
1652+
image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
16571653
content = content.replace(
1658-
IMAGE_PLACEHOLDER,
1659-
f"<|vision_bos|>{self.image_token * image_token_replace_length}<|vision_eos|>",
1660-
1,
1654+
IMAGE_PLACEHOLDER, f"<|vision_bos|>{self.image_token * image_seqlen}<|vision_eos|>", 1
16611655
)
16621656
num_image_tokens += 1
16631657

@@ -1666,11 +1660,9 @@ def process_messages(
16661660
if num_audio_tokens >= len(audios):
16671661
raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
16681662

1669-
audio_token_replace_length = audio_lengths[num_audio_tokens]
1663+
audio_seqlen = audio_lengths[num_audio_tokens] if self.expand_mm_tokens else 1
16701664
content = content.replace(
1671-
AUDIO_PLACEHOLDER,
1672-
f"<|audio_bos|>{self.audio_token * audio_token_replace_length}<|audio_eos|>",
1673-
1,
1665+
AUDIO_PLACEHOLDER, f"<|audio_bos|>{self.audio_token * audio_seqlen}<|audio_eos|>", 1
16741666
)
16751667
num_audio_tokens += 1
16761668

@@ -1679,9 +1671,9 @@ def process_messages(
16791671
if num_video_tokens >= len(videos):
16801672
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
16811673

1682-
video_replace_length = video_grid_thw[num_video_tokens].prod() // merge_length
1674+
video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
16831675
content = content.replace(
1684-
VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_replace_length}<|vision_eos|>", 1
1676+
VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_seqlen}<|vision_eos|>", 1
16851677
)
16861678
num_video_tokens += 1
16871679

0 commit comments

Comments (0)