@@ -1602,34 +1602,30 @@ def process_messages(
16021602 processor : Optional ["MMProcessor" ],
16031603 ) -> list [dict [str , str ]]:
16041604 self ._validate_input (processor , images , videos , audios )
1605+ num_image_tokens , num_video_tokens , num_audio_tokens = 0 , 0 , 0
16051606 messages = deepcopy (messages )
1607+ image_processor : BaseImageProcessor = getattr (processor , "image_processor" , None )
1608+
1609+ merge_length = processor .image_processor .merge_size ** 2
1610+ use_audio_in_video = getattr (processor , "use_audio_in_video" , False )
16061611 if self .expand_mm_tokens :
16071612 mm_inputs = self ._get_mm_inputs (images , videos , audios , processor )
1613+ image_grid_thw = mm_inputs .get ("image_grid_thw" , [])
1614+ video_grid_thw = mm_inputs .get ("video_grid_thw" , [])
1615+ if "feature_attention_mask" in mm_inputs :
1616+ input_lengths = (mm_inputs ["feature_attention_mask" ].sum (- 1 ).numpy () - 1 ) // 2 + 1
1617+ audio_lengths = (input_lengths - 2 ) // 2 + 1
16081618 else :
16091619 mm_inputs = {}
1620+ image_grid_thw = [None ] * len (images )
1621+ video_grid_thw = [None ] * len (videos )
1622+ audio_lengths = [None ] * len (audios )
16101623
1611- image_processor : BaseImageProcessor = getattr (processor , "image_processor" , None )
1612- num_audio_tokens , num_image_tokens , num_video_tokens = 0 , 0 , 0
1613- use_audio_in_video = getattr (processor , "use_audio_in_video" , False )
1614-
1615- # get length or size from mm_inputs
1616- if "feature_attention_mask" in mm_inputs :
1617- input_lengths = (mm_inputs ["feature_attention_mask" ].sum (- 1 ).numpy () - 1 ) // 2 + 1
1618- audio_lengths = (input_lengths - 2 ) // 2 + 1
1619-
1620- if mm_inputs .get ("image_grid_thw" , None ) is not None :
1621- image_grid_thw = mm_inputs ["image_grid_thw" ]
1622- merge_length = processor .image_processor .merge_size ** 2
1623-
1624- if mm_inputs .get ("video_grid_thw" , None ) is not None :
1625- video_grid_thw = mm_inputs ["video_grid_thw" ]
1626- merge_length = processor .image_processor .merge_size ** 2
1627-
1628- if use_audio_in_video :
1629- if audio_lengths is None :
1624+ if self .expand_mm_tokens and use_audio_in_video :
1625+ if "feature_attention_mask" not in mm_inputs :
16301626 raise ValueError ("audio_lengths should exist when use_audio_in_video is `True`." )
16311627
1632- if mm_inputs . get ( "video_grid_thw" , None ) is None :
1628+ if "video_grid_thw" not in mm_inputs :
16331629 raise ValueError ("video_grid_thw should exist when use_audio_in_video is `True`." )
16341630
16351631 positions_list = []
@@ -1653,11 +1649,9 @@ def process_messages(
16531649 if num_image_tokens >= len (images ):
16541650 raise ValueError (f"`len(images)` is less than the number of { IMAGE_PLACEHOLDER } tokens." )
16551651
1656- image_token_replace_length = image_grid_thw [num_image_tokens ].prod () // merge_length
1652+ image_seqlen = image_grid_thw [num_image_tokens ].prod () // merge_length if self . expand_mm_tokens else 1
16571653 content = content .replace (
1658- IMAGE_PLACEHOLDER ,
1659- f"<|vision_bos|>{ self .image_token * image_token_replace_length } <|vision_eos|>" ,
1660- 1 ,
1654+ IMAGE_PLACEHOLDER , f"<|vision_bos|>{ self .image_token * image_seqlen } <|vision_eos|>" , 1
16611655 )
16621656 num_image_tokens += 1
16631657
@@ -1666,11 +1660,9 @@ def process_messages(
16661660 if num_audio_tokens >= len (audios ):
16671661 raise ValueError (f"`len(audios)` is less than the number of { AUDIO_PLACEHOLDER } tokens." )
16681662
1669- audio_token_replace_length = audio_lengths [num_audio_tokens ]
1663+ audio_seqlen = audio_lengths [num_audio_tokens ] if self . expand_mm_tokens else 1
16701664 content = content .replace (
1671- AUDIO_PLACEHOLDER ,
1672- f"<|audio_bos|>{ self .audio_token * audio_token_replace_length } <|audio_eos|>" ,
1673- 1 ,
1665+ AUDIO_PLACEHOLDER , f"<|audio_bos|>{ self .audio_token * audio_seqlen } <|audio_eos|>" , 1
16741666 )
16751667 num_audio_tokens += 1
16761668
@@ -1679,9 +1671,9 @@ def process_messages(
16791671 if num_video_tokens >= len (videos ):
16801672 raise ValueError (f"`len(videos)` is less than the number of { VIDEO_PLACEHOLDER } tokens." )
16811673
1682- video_replace_length = video_grid_thw [num_video_tokens ].prod () // merge_length
1674+ video_seqlen = video_grid_thw [num_video_tokens ].prod () // merge_length if self . expand_mm_tokens else 1
16831675 content = content .replace (
1684- VIDEO_PLACEHOLDER , f"<|vision_bos|>{ self .video_token * video_replace_length } <|vision_eos|>" , 1
1676+ VIDEO_PLACEHOLDER , f"<|vision_bos|>{ self .video_token * video_seqlen } <|vision_eos|>" , 1
16851677 )
16861678 num_video_tokens += 1
16871679
0 commit comments