Skip to content

Commit 185c76f

Browse files
authored
[model] add Qwen2.5-Omni model (#7537)
* preserve image_sizes
* init plugin
* support audio-text2text LoRA
* support image/video-text2text and audio-text2text
* remove unused args and lines
* add docs and minor cleanups
* remove some comments
* fix issues and add the LoRA partial-merge script
* add license
1 parent 468eea6 commit 185c76f

10 files changed

Lines changed: 348 additions & 2 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
261261
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
262262
| [Qwen/QwQ (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
263263
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
264+
| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 7B | qwen2_omni |
264265
| [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/32B/72B | qwen2_vl |
265266
| [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 |
266267
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |

README_zh.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
263263
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
264264
| [Qwen/QwQ (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
265265
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
266+
| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 7B | qwen2_omni |
266267
| [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/32B/72B | qwen2_vl |
267268
| [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 |
268269
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |

scripts/lora_part_merge.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
2+
#
3+
# This code is based on the HuggingFace's PEFT library.
4+
# https://github.com/huggingface/peft/blob/v0.10.0/examples/loftq_finetuning/quantize_save_load.py
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
import os
18+
import shutil
19+
20+
import fire
21+
from peft import PeftModel
22+
from transformers import AutoModel, AutoProcessor, AutoTokenizer
23+
24+
25+
def merge_lora(
    base_model_path: str,
    lora_checkpoint_path: str,
    extra_file: str = "spk_dict.pt",
    submodule_name: str = "thinker",
    save_path: str = "./merged_model_checkpoint",
):
    """Merge LoRA weights into one submodule of a base model and save the result.

    Loads the original model, tokenizer, and processor configuration, merges the
    LoRA weights for a specified submodule, and saves the final merged model
    along with its configurations.

    Args:
        base_model_path (str): Path to the original model directory.
        lora_checkpoint_path (str): Path to the directory containing LoRA weights.
        extra_file (str): Name of the extra file to be copied (default: "spk_dict.pt").
        submodule_name (str): Name of the submodule to merge (default: "thinker").
        save_path (str): Directory where the merged model and configurations will be saved.

    Raises:
        AttributeError: If the model has no submodule named ``submodule_name``.
    """
    # 1. Load the original model, tokenizer, and processor
    model = AutoModel.from_pretrained(base_model_path)
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)

    try:
        processor = AutoProcessor.from_pretrained(base_model_path)
    except Exception:  # best-effort: some checkpoints ship without a processor config
        print("Processor configuration not found, skipping processor load.")
        processor = None

    print("Successfully loaded the original model, tokenizer, and processor (if available).")

    # 2. Extract the submodule to be merged (e.g., model.thinker)
    if not hasattr(model, submodule_name):
        raise AttributeError(f"The model does not have a submodule named '{submodule_name}'.")

    base_submodule = getattr(model, submodule_name)
    print(f"Successfully extracted submodule: {submodule_name}.")

    # 3. Load the LoRA weights onto the extracted submodule
    lora_model = PeftModel.from_pretrained(base_submodule, lora_checkpoint_path)
    print("LoRA weights loaded successfully.")

    # 4. Merge the LoRA weights into the submodule and unload the LoRA modules
    merged_submodule = lora_model.merge_and_unload()
    print("LoRA weights merged successfully.")

    # 5. Replace the original submodule with the merged submodule in the model
    setattr(model, submodule_name, merged_submodule)

    # 6. Save the final merged model along with the tokenizer and processor configuration
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)
    if processor is not None:
        processor.save_pretrained(save_path)

    print(f"Merged model and configuration saved to {save_path}.")

    # 7. Copy the extra file if present (e.g., the speaker dictionary used by the
    # Qwen2.5-Omni talker); skip silently-but-loudly when it does not exist.
    source_file = os.path.join(base_model_path, extra_file)
    target_file = os.path.join(save_path, extra_file)
    if os.path.exists(source_file):
        shutil.copy(source_file, target_file)
        print(f"File '{extra_file}' copied from {base_model_path} to {save_path}.")
    else:
        print(f"File '{extra_file}' not found in {base_model_path}, skipping copy.")


if __name__ == "__main__":
    fire.Fire(merge_lora)

src/llamafactory/data/collator.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,10 +190,27 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
190190
"video_grid_thw": mm_inputs.get("video_grid_thw"),
191191
"attention_mask": features["attention_mask"],
192192
}
193-
if "second_per_grid_ts" in mm_inputs:
193+
if "second_per_grid_ts" in mm_inputs: # for qwen2vl
194194
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
195195

196-
features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)
196+
if getattr(self.model.config, "model_type", None) == "qwen2_5_omni": # for qwen2omni
197+
feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
198+
if feature_attention_mask is not None:
199+
audio_feature_lengths = torch.sum(
200+
feature_attention_mask, dim=1
201+
) # FIXME need to get video image lengths
202+
rope_index_kwargs["audio_seqlens"] = audio_feature_lengths # prepare for input
203+
204+
delta0 = (1 - rope_index_kwargs["attention_mask"]).sum(dim=-1).unsqueeze(1)
205+
# avoid conflict
206+
rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid", None)
207+
new_position_ids, rope_deltas = self.model.get_rope_index(**rope_index_kwargs)
208+
features["position_ids"], features["rope_deltas"] = (
209+
new_position_ids.clone(),
210+
rope_deltas - delta0,
211+
) # avoid inplace operation FIXME
212+
else: # for qwen2vl
213+
features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)
197214

198215
if "cross_attention_mask" in mm_inputs: # for mllama inputs when pad_to_multiple_of is enabled
199216
cross_attention_mask = mm_inputs.pop("cross_attention_mask")

src/llamafactory/data/mm_plugin.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,12 @@ def _validate_input(
146146
video_processor: BaseImageProcessor = getattr(
147147
processor, "video_processor", getattr(processor, "image_processor", None)
148148
)
149+
if image_processor is None and video_processor is None: # hack for qwen2_5_omni
150+
image_processor, video_processor = (
151+
getattr(processor, "omni_processor", None),
152+
getattr(processor, "omni_processor", None),
153+
)
154+
149155
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
150156
if len(images) != 0 and self.image_token is None:
151157
raise ValueError(
@@ -1104,6 +1110,186 @@ def get_mm_inputs(
11041110
return self._get_mm_inputs(images, videos, audios, processor)
11051111

11061112

1113+
class Qwen2OmniPlugin(BasePlugin):
    """Multimodal plugin for Qwen2.5-Omni: prepares image/video/audio processor
    inputs and expands media placeholders into the model's special tokens."""

    @override
    def _get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: "MMProcessor",
        imglens: Optional[list[int]] = None,
    ) -> dict[str, "torch.Tensor"]:
        """Run the Omni processors over raw media and collect tensor inputs.

        Returns a dict that may contain image/video pixel inputs and grids,
        `video_second_per_grid` (seconds covered by one temporal grid step,
        used for mRoPE time indices), audio features, and
        `feature_attention_mask` (renamed from `attention_mask` to avoid
        clashing with the text attention mask).
        """
        mm_inputs = {}
        if len(images) != 0:
            # Qwen2.5-Omni exposes a single "omni_processor" for visual inputs.  FIXME
            image_processor: BaseImageProcessor = getattr(processor, "omni_processor", None)
            images = self._regularize_images(
                images,
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )
            if imglens is not None:
                images = _make_batched_images(images, imglens)

            mm_inputs.update(image_processor(images, return_tensors="pt"))

        if len(videos) != 0:
            video_processor: BaseImageProcessor = getattr(
                processor, "video_processor", getattr(processor, "omni_processor", None)
            )
            videos = self._regularize_videos(
                videos,
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
                video_fps=getattr(processor, "video_fps", 2.0),
                video_maxlen=getattr(processor, "video_maxlen", 128),
            )
            if "videos" in inspect.signature(video_processor.preprocess).parameters:  # for qwen2_vl and video_llava
                mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt"))
                fps = [2.0] * len(videos)  # FIXME hardcode
                video_second_per_grid = [fps[i] / video_processor.temporal_patch_size for i in range(len(fps))]
                mm_inputs["video_second_per_grid"] = torch.tensor(video_second_per_grid)
            else:
                raise NotImplementedError

        if len(audios) != 0:
            feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
            audios = self._regularize_audios(
                audios,
                sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
            )
            mm_inputs.update(
                feature_extractor(
                    audios,
                    sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
                    return_attention_mask=True,
                    padding="max_length",
                    return_tensors="pt",
                )
            )
            mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask")  # prevent conflicts

        return mm_inputs

    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        """Replace AUDIO/IMAGE/VIDEO placeholders with the exact number of
        special tokens the model expects, optionally interleaving audio and
        video tokens when ``use_audio_in_video`` is enabled.
        """
        self._validate_input(processor, images, videos, audios)
        messages = deepcopy(messages)
        # NOTE(review): when expand_mm_tokens is False the messages pass through
        # unchanged — confirm this matches the intended non-expanding behavior.
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0
            use_audio_in_video = getattr(processor, "use_audio_in_video", False)

            # get lengths or sizes from mm_inputs; default to None so missing
            # modalities fail the asserts below instead of raising NameError
            audio_lengths = image_grid_thw = video_grid_thw = merge_length = None
            if "feature_attention_mask" in mm_inputs:
                # audio tower downsamples feature frames twice (conv stride 2 each)
                input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
                audio_lengths = (input_lengths - 2) // 2 + 1

            if mm_inputs.get("image_grid_thw", None) is not None:
                image_grid_thw = mm_inputs["image_grid_thw"]
                merge_length = processor.omni_processor.merge_size**2

            if mm_inputs.get("video_grid_thw", None) is not None:
                video_grid_thw = mm_inputs["video_grid_thw"]
                merge_length = processor.omni_processor.merge_size**2

            if use_audio_in_video:
                assert audio_lengths is not None, "audio_lengths should be exist when use_audio_in_video is `True`"
                assert video_grid_thw is not None, "video_grid_thw should be exist when use_audio_in_video is `True`"
                # (fix) removed the dead `positions_list` scan: it indexed the
                # message dict with an int (`message[i]`), appended the None
                # return of `list.sort`, and its result was never read.

            for message in messages:
                content = message["content"]
                while IMAGE_PLACEHOLDER in content:
                    image_token_replace_length = image_grid_thw[num_image_tokens].prod() // merge_length
                    content = content.replace(
                        IMAGE_PLACEHOLDER,
                        f"<|vision_bos|>{self.image_token * image_token_replace_length}<|vision_eos|>",
                        1,
                    )
                    num_image_tokens += 1

                if not use_audio_in_video:
                    while AUDIO_PLACEHOLDER in content:
                        audio_token_replace_length = audio_lengths[num_audio_tokens]
                        content = content.replace(
                            AUDIO_PLACEHOLDER,
                            f"<|audio_bos|>{self.audio_token * audio_token_replace_length}<|audio_eos|>",
                            1,
                        )
                        num_audio_tokens += 1

                    while VIDEO_PLACEHOLDER in content:
                        video_replace_length = video_grid_thw[num_video_tokens].prod() // merge_length
                        content = content.replace(
                            VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_replace_length}<|vision_eos|>", 1
                        )
                        num_video_tokens += 1
                else:  # use the audio of the video: interleave video and audio tokens chunk by chunk
                    while VIDEO_PLACEHOLDER in content:
                        audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
                        video_t_index = (
                            torch.arange(video_grid_thw[num_video_tokens][0])
                            .view(-1, 1, 1)
                            .expand(
                                -1,
                                # (fix) merge_size lives on processor.omni_processor,
                                # not on the plugin (`self` has no omni_processor)
                                video_grid_thw[num_video_tokens][1] // processor.omni_processor.merge_size,
                                video_grid_thw[num_video_tokens][2] // processor.omni_processor.merge_size,
                            )
                            .flatten()
                            * mm_inputs["video_second_per_grid"][num_video_tokens]
                            * 25  # FIXME hardcode of position_id_per_seconds=25
                        ).long()
                        t_ntoken_per_chunk = 50  # FIXME hardcode: [25 * 2]
                        video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
                        # (fix) was `self.get_chunked_index`; the helper is on the processor
                        audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
                        placeholder_string = ""
                        for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))):
                            video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None
                            audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
                            # (fix) accumulate across chunks — the original reassigned
                            # here, so only the last chunk survived the replacement
                            placeholder_string += "<|vision_bos|>" + "<|audio_bos|>"
                            if video_chunk_index is not None:
                                placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0])
                            if audio_chunk_index is not None:
                                placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])
                            placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
                        content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
                        content = content.replace(AUDIO_PLACEHOLDER, "", 1)
                        num_audio_tokens += 1
                        num_video_tokens += 1

                message["content"] = content

            if len(audios) != num_audio_tokens:
                raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")

            if len(images) != num_image_tokens:
                raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

            if len(videos) != num_video_tokens:
                raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

        return messages
1291+
1292+
11071293
@dataclass
11081294
class Qwen2VLPlugin(BasePlugin):
11091295
@override
@@ -1328,6 +1514,7 @@ def process_messages(
13281514
"paligemma": PaliGemmaPlugin,
13291515
"pixtral": PixtralPlugin,
13301516
"qwen2_audio": Qwen2AudioPlugin,
1517+
"qwen2_omni": Qwen2OmniPlugin,
13311518
"qwen2_vl": Qwen2VLPlugin,
13321519
"video_llava": VideoLlavaPlugin,
13331520
}

src/llamafactory/data/template.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1367,6 +1367,24 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
13671367
)
13681368

13691369

1370+
# Chat template for Qwen2.5-Omni; copied from the qwen template.  The qwen2_omni
# mm_plugin supplies the audio/image/video placeholder tokens that
# Qwen2OmniPlugin later expands into model special tokens.
register_template(
    name="qwen2_omni",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
    format_observation=StringFormatter(
        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
    ),
    format_tools=ToolFormatter(tool_format="qwen"),
    default_system="You are a helpful assistant.",
    stop_words=["<|im_end|>"],
    mm_plugin=get_mm_plugin(
        name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
    ),
)
1387+
13701388
# copied from qwen template
13711389
register_template(
13721390
name="qwen2_vl",

src/llamafactory/extras/constants.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2270,6 +2270,18 @@ def register_model_group(
22702270
)
22712271

22722272

2273+
# Qwen2.5-Omni-7B: end-to-end multimodal model (audio/image/video + text),
# paired with the qwen2_omni chat template registered in template.py.
register_model_group(
    models={
        "Qwen2.5-Omni-7B": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B",
            DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B",
        }
    },
    template="qwen2_omni",
    multimodal=True,
)
2283+
2284+
22732285
register_model_group(
22742286
models={
22752287
"Qwen2-VL-2B": {

src/llamafactory/hparams/model_args.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,10 @@ class ProcessorArguments:
222222
default=False,
223223
metadata={"help": "Use pan and scan to process image for gemma3."},
224224
)
225+
use_audio_in_video: bool = field(
226+
default=False,
227+
metadata={"help": "Whether or not to use audio in video inputs."},
228+
)
225229
video_max_pixels: int = field(
226230
default=256 * 256,
227231
metadata={"help": "The maximum number of pixels of video inputs."},

0 commit comments

Comments
 (0)