@@ -146,6 +146,12 @@ def _validate_input(
146146 video_processor : BaseImageProcessor = getattr (
147147 processor , "video_processor" , getattr (processor , "image_processor" , None )
148148 )
149+ if image_processor is None and video_processor is None : # hack for qwen2_5_omni
150+ image_processor , video_processor = (
151+ getattr (processor , "omni_processor" , None ),
152+ getattr (processor , "omni_processor" , None ),
153+ )
154+
149155 feature_extractor : SequenceFeatureExtractor = getattr (processor , "feature_extractor" , None )
150156 if len (images ) != 0 and self .image_token is None :
151157 raise ValueError (
@@ -1104,6 +1110,186 @@ def get_mm_inputs(
11041110 return self ._get_mm_inputs (images , videos , audios , processor )
11051111
11061112
class Qwen2OmniPlugin(BasePlugin):
    r"""Multimodal plugin for Qwen2.5-Omni.

    Expands ``IMAGE_PLACEHOLDER`` / ``VIDEO_PLACEHOLDER`` / ``AUDIO_PLACEHOLDER`` markers in
    chat messages into the model's special-token spans, sized from the processor outputs.
    When ``processor.use_audio_in_video`` is true, each video placeholder is expanded into
    video tokens interleaved chunk-by-chunk with the tokens of its audio track, mirroring
    the upstream Qwen2.5-Omni processor.
    """

    @override
    def _get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: "MMProcessor",
        imglens: Optional[list[int]] = None,
    ) -> dict[str, "torch.Tensor"]:
        r"""Preprocess raw media into model input tensors.

        Args:
            images: images to preprocess (may be empty).
            videos: videos to preprocess (may be empty).
            audios: audio clips to preprocess (may be empty).
            processor: the omni processor bundle; expected to expose ``omni_processor``
                (image/video) and ``feature_extractor`` (audio).
            imglens: optional number of images per sample, used to re-batch images.

        Returns:
            Dict of tensors (pixel values, ``image_grid_thw``/``video_grid_thw``,
            audio features, ``feature_attention_mask``, ``video_second_per_grid``, ...).
        """
        mm_inputs = {}
        if len(images) != 0:
            # FIXME: qwen2.5-omni exposes a single `omni_processor` for both image and video.
            image_processor: BaseImageProcessor = getattr(processor, "omni_processor", None)
            images = self._regularize_images(
                images,
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )
            if imglens is not None:
                images = _make_batched_images(images, imglens)

            mm_inputs.update(image_processor(images, return_tensors="pt"))

        if len(videos) != 0:
            video_processor: BaseImageProcessor = getattr(
                processor, "video_processor", getattr(processor, "omni_processor", None)
            )
            videos = self._regularize_videos(
                videos,
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
                video_fps=getattr(processor, "video_fps", 2.0),
                video_maxlen=getattr(processor, "video_maxlen", 128),
            )
            if "videos" in inspect.signature(video_processor.preprocess).parameters:  # for qwen2_vl and video_llava
                mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt"))
                fps = [2.0] * len(videos)  # FIXME: hardcoded sampling fps
                # seconds of video covered by one temporal grid step; consumed below
                # when interleaving audio with video tokens.
                video_second_per_grid = [fps[i] / video_processor.temporal_patch_size for i in range(len(fps))]
                mm_inputs["video_second_per_grid"] = torch.tensor(video_second_per_grid)
            else:
                raise NotImplementedError

        if len(audios) != 0:
            feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
            audios = self._regularize_audios(
                audios,
                sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
            )
            mm_inputs.update(
                feature_extractor(
                    audios,
                    sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
                    return_attention_mask=True,
                    padding="max_length",
                    return_tensors="pt",
                )
            )
            mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask")  # prevent conflicts

        return mm_inputs

    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        r"""Replace media placeholders in ``messages`` with expanded special-token spans.

        Raises:
            ValueError: if the number of media objects does not match the number of
                placeholders consumed.
        """
        self._validate_input(processor, images, videos, audios)
        messages = deepcopy(messages)
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0
            use_audio_in_video = getattr(processor, "use_audio_in_video", False)
            # Predefine so that a missing modality trips the asserts / raises cleanly
            # instead of a NameError further down.
            audio_lengths = image_grid_thw = video_grid_thw = merge_length = None

            # derive per-item token lengths / grid sizes from the processor outputs
            if "feature_attention_mask" in mm_inputs:
                # two successive stride-2 subsamplings of the audio feature length
                input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
                audio_lengths = (input_lengths - 2) // 2 + 1

            if mm_inputs.get("image_grid_thw", None) is not None:
                image_grid_thw = mm_inputs["image_grid_thw"]
                merge_length = processor.omni_processor.merge_size**2

            if mm_inputs.get("video_grid_thw", None) is not None:
                video_grid_thw = mm_inputs["video_grid_thw"]
                merge_length = processor.omni_processor.merge_size**2

            if use_audio_in_video:
                assert audio_lengths is not None, "audio_lengths should exist when use_audio_in_video is `True`"
                assert video_grid_thw is not None, "video_grid_thw should exist when use_audio_in_video is `True`"

            for message in messages:
                content = message["content"]
                while IMAGE_PLACEHOLDER in content:
                    image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length
                    content = content.replace(
                        IMAGE_PLACEHOLDER,
                        f"<|vision_bos|>{self.image_token * image_seqlen}<|vision_eos|>",
                        1,
                    )
                    num_image_tokens += 1

                if not use_audio_in_video:
                    while AUDIO_PLACEHOLDER in content:
                        audio_seqlen = audio_lengths[num_audio_tokens]
                        content = content.replace(
                            AUDIO_PLACEHOLDER,
                            f"<|audio_bos|>{self.audio_token * audio_seqlen}<|audio_eos|>",
                            1,
                        )
                        num_audio_tokens += 1

                    while VIDEO_PLACEHOLDER in content:
                        video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length
                        content = content.replace(
                            VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_seqlen}<|vision_eos|>", 1
                        )
                        num_video_tokens += 1
                else:  # if use the audio of video, interleave video tokens with its audio tokens
                    while VIDEO_PLACEHOLDER in content:
                        audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
                        # temporal position (in position-id units) of every video token
                        video_t_index = (
                            torch.arange(video_grid_thw[num_video_tokens][0])
                            .view(-1, 1, 1)
                            .expand(
                                -1,
                                video_grid_thw[num_video_tokens][1] // processor.omni_processor.merge_size,
                                video_grid_thw[num_video_tokens][2] // processor.omni_processor.merge_size,
                            )
                            .flatten()
                            * mm_inputs["video_second_per_grid"][num_video_tokens]
                            * 25  # FIXME hardcode of position_id_per_seconds=25
                        ).long()
                        t_ntoken_per_chunk = 50  # FIXME hardcode: [25 * 2]
                        video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
                        audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
                        # bos/eos bracket the whole interleaved sequence once; chunks accumulate in between
                        placeholder_string = "<|vision_bos|>" + "<|audio_bos|>"
                        for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))):
                            video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None
                            audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
                            if video_chunk_index is not None:
                                placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0])
                            if audio_chunk_index is not None:
                                placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])

                        placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
                        content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
                        content = content.replace(AUDIO_PLACEHOLDER, "", 1)  # audio is consumed by the video span
                        num_audio_tokens += 1
                        num_video_tokens += 1

                message["content"] = content

            if len(audios) != num_audio_tokens:
                raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
            if len(images) != num_image_tokens:
                raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
            if len(videos) != num_video_tokens:
                raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

        return messages
1292+
11071293@dataclass
11081294class Qwen2VLPlugin (BasePlugin ):
11091295 @override
@@ -1328,6 +1514,7 @@ def process_messages(
13281514 "paligemma" : PaliGemmaPlugin ,
13291515 "pixtral" : PixtralPlugin ,
13301516 "qwen2_audio" : Qwen2AudioPlugin ,
1517+ "qwen2_omni" : Qwen2OmniPlugin ,
13311518 "qwen2_vl" : Qwen2VLPlugin ,
13321519 "video_llava" : VideoLlavaPlugin ,
13331520}
0 commit comments