Commit 520b9dc

fix Glm4v batch videos forward (#39172)
* changes for video
* update modular
* change get_video_features
* update video token replacement
* update modular
* add test and fix typo
* lint
* fix order
* lint
* fix
* remove dependency
* lint
* lint
* remove todo
* resize video for test
* lint..
* fix test
* new a processor for video_test
* fix test
1 parent bc161d5 commit 520b9dc

File tree (6 files changed: +127 -23 lines changed)

* src/transformers/models/glm4v/modeling_glm4v.py
* src/transformers/models/glm4v/modular_glm4v.py
* src/transformers/models/glm4v/processing_glm4v.py
* src/transformers/models/glm4v/video_processing_glm4v.py
* tests/models/glm4v/test_modeling_glm4v.py
* tests/models/glm4v/test_video_processing_glm4v.py

src/transformers/models/glm4v/modeling_glm4v.py

Lines changed: 13 additions & 3 deletions

@@ -1052,6 +1052,7 @@ def get_rope_index(
             device=input_ids.device,
         )
         image_index, video_index = 0, 0
+        video_group_index = 0
         attention_mask = attention_mask.to(total_input_ids.device)
         for i, input_ids in enumerate(total_input_ids):
             input_ids = input_ids[attention_mask[i] == 1]
@@ -1081,7 +1082,6 @@ def get_rope_index(
 
             llm_pos_ids_list = []
             video_frame_num = 1
-
             for modality_type, start_idx, end_idx in input_type_group:
                 st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
 
@@ -1125,7 +1125,11 @@ def get_rope_index(
                     w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()
                     llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)
 
-                    video_index += 1
+                    video_group_index += 1
+
+                    if video_group_index >= video_grid_thw[video_index][0]:
+                        video_index += 1
+                        video_group_index = 0
 
                     video_frame_num += 1

@@ -1174,7 +1178,13 @@ def get_video_features(
                 The temporal, height and width of feature shape of each video in LLM.
         """
         pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
-        video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+        # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames
+        temp_frames_hw = []
+        for t, h, w in video_grid_thw:
+            repeated_row = torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1)
+            temp_frames_hw.append(repeated_row)
+        flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)
+        video_embeds = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw)
         split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
         video_embeds = torch.split(video_embeds, split_sizes)
         return video_embeds
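Note: a small standalone sketch of the grid flattening in get_video_features (illustrative shapes only, assuming a plain torch tensor input). Each per-video [t, h, w] row becomes t per-frame [1, h, w] rows before the vision tower is called:

import torch

# Hypothetical batch: a 4-frame video and a 2-frame video, both with a 6x8 patch grid.
video_grid_thw = torch.tensor([[4, 6, 8], [2, 6, 8]])

temp_frames_hw = []
for t, h, w in video_grid_thw:
    # Each video contributes t rows of [1, h, w], one per frame.
    temp_frames_hw.append(torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1))
flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)

print(flattened_video_grid_thw.shape)  # torch.Size([6, 3]) -- one row per frame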

src/transformers/models/glm4v/modular_glm4v.py

Lines changed: 42 additions & 8 deletions

@@ -1064,6 +1064,7 @@ def get_rope_index(
             device=input_ids.device,
         )
         image_index, video_index = 0, 0
+        video_group_index = 0
         attention_mask = attention_mask.to(total_input_ids.device)
         for i, input_ids in enumerate(total_input_ids):
             input_ids = input_ids[attention_mask[i] == 1]
@@ -1093,7 +1094,6 @@ def get_rope_index(
 
             llm_pos_ids_list = []
             video_frame_num = 1
-
            for modality_type, start_idx, end_idx in input_type_group:
                 st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
 
@@ -1137,7 +1137,11 @@ def get_rope_index(
                     w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()
                     llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)
 
-                    video_index += 1
+                    video_group_index += 1
+
+                    if video_group_index >= video_grid_thw[video_index][0]:
+                        video_index += 1
+                        video_group_index = 0
 
                     video_frame_num += 1
 
@@ -1173,6 +1177,30 @@ def get_rope_index(
 
         return position_ids, mrope_position_deltas
 
+    def get_video_features(
+        self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
+    ):
+        """
+        Encodes videos into continuous embeddings that can be forwarded to the language model.
+
+        Args:
+            pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+                The tensors corresponding to the input videos.
+            video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+                The temporal, height and width of feature shape of each video in LLM.
+        """
+        pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
+        # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames
+        temp_frames_hw = []
+        for t, h, w in video_grid_thw:
+            repeated_row = torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1)
+            temp_frames_hw.append(repeated_row)
+        flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)
+        video_embeds = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw)
+        split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
+        video_embeds = torch.split(video_embeds, split_sizes)
+        return video_embeds
+
     @auto_docstring
     @can_return_tuple
     def forward(
@@ -1664,32 +1692,38 @@ def __call__(
         video_index = 0
         for i in range(len(text)):
             while self.video_token in text[i]:
-                num_frames = len(video_grid_thw)
+                num_frames = video_grid_thw[video_index][0]
                 video_structure = ""
 
                 if hasattr(timestamps, "tolist"):
                     timestamps_list = timestamps.tolist()[0]
                 else:
                     timestamps_list = timestamps[0] if isinstance(timestamps[0], list) else timestamps
+
                 unique_timestamps = []
                 for idx in range(0, len(timestamps_list)):
                     unique_timestamps.append(timestamps_list[idx])
+
                 selected_timestamps = unique_timestamps[:num_frames]
                 while len(selected_timestamps) < num_frames:
                     selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0)
+
                 for frame_idx in range(num_frames):
                     timestamp_sec = selected_timestamps[frame_idx]
                     frame_structure = f"<|begin_of_image|>{self.image_token}<|end_of_image|>{timestamp_sec}"
                     video_structure += frame_structure
+
                 text[i] = text[i].replace(self.video_token, video_structure, 1)
+                num_image_tokens = (
+                    video_grid_thw[video_index].prod() // merge_length // video_grid_thw[video_index][0]
+                )
+                for frame_idx in range(num_frames):
+                    if self.image_token in text[i]:
+                        text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
+
                 video_index += 1
 
-            for frame_idx in range(len(video_grid_thw)):
-                if self.image_token in text[i]:
-                    num_image_tokens = video_grid_thw[frame_idx].prod() // merge_length
-                    text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
             text[i] = text[i].replace("<|placeholder|>", self.image_token)
-
         return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
         text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
         self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
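Note: a rough sketch of the updated placeholder expansion in the processor (numbers are illustrative, not from the commit). The frame count now comes from the current video's grid, and the per-frame token count divides the total patch count by both the merge length and the frame count:

import torch

# Hypothetical video grid: 6 frames of a 12x16 patch grid, spatial merge size 2.
video_grid_thw = torch.tensor([[6, 12, 16]])
merge_length = 2**2

video_index = 0
num_frames = int(video_grid_thw[video_index][0])
num_image_tokens = int(video_grid_thw[video_index].prod() // merge_length // video_grid_thw[video_index][0])

print(num_frames, num_image_tokens)  # -> 6 48, i.e. 48 image tokens per frame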

src/transformers/models/glm4v/processing_glm4v.py

Lines changed: 12 additions & 6 deletions

@@ -167,32 +167,38 @@ def __call__(
         video_index = 0
         for i in range(len(text)):
             while self.video_token in text[i]:
-                num_frames = len(video_grid_thw)
+                num_frames = video_grid_thw[video_index][0]
                 video_structure = ""
 
                 if hasattr(timestamps, "tolist"):
                     timestamps_list = timestamps.tolist()[0]
                 else:
                     timestamps_list = timestamps[0] if isinstance(timestamps[0], list) else timestamps
+
                 unique_timestamps = []
                 for idx in range(0, len(timestamps_list)):
                     unique_timestamps.append(timestamps_list[idx])
+
                 selected_timestamps = unique_timestamps[:num_frames]
                 while len(selected_timestamps) < num_frames:
                     selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0)
+
                 for frame_idx in range(num_frames):
                     timestamp_sec = selected_timestamps[frame_idx]
                     frame_structure = f"<|begin_of_image|>{self.image_token}<|end_of_image|>{timestamp_sec}"
                     video_structure += frame_structure
+
                 text[i] = text[i].replace(self.video_token, video_structure, 1)
+                num_image_tokens = (
+                    video_grid_thw[video_index].prod() // merge_length // video_grid_thw[video_index][0]
+                )
+                for frame_idx in range(num_frames):
+                    if self.image_token in text[i]:
+                        text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
+
                 video_index += 1
 
-            for frame_idx in range(len(video_grid_thw)):
-                if self.image_token in text[i]:
-                    num_image_tokens = video_grid_thw[frame_idx].prod() // merge_length
-                    text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
             text[i] = text[i].replace("<|placeholder|>", self.image_token)
-
         return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
         text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
         self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])

src/transformers/models/glm4v/video_processing_glm4v.py

Lines changed: 0 additions & 4 deletions

@@ -249,10 +249,6 @@ def _preprocess(
         processed_grids = reorder_videos(processed_grids, grouped_videos_index)
         pixel_values_videos = torch.cat(processed_videos, dim=0)
         video_grid_thw = torch.tensor(processed_grids)
-        total_frames = video_grid_thw[0][0].item()
-        h = video_grid_thw[0][1].item()
-        w = video_grid_thw[0][2].item()
-        video_grid_thw = [[1, h, w] for _ in range(total_frames)]
         data = {
             "pixel_values_videos": pixel_values_videos,
             "video_grid_thw": video_grid_thw,

tests/models/glm4v/test_modeling_glm4v.py

Lines changed: 59 additions & 1 deletion

@@ -13,6 +13,7 @@
 # limitations under the License.
 """Testing suite for the PyTorch GLM-4.1V model."""
 
+import copy
 import gc
 import unittest
 
@@ -236,7 +237,26 @@ def test_multi_gpu_data_parallel_forward(self):
     def test_generate_from_inputs_embeds_with_static_cache(self):
         pass
 
-    # RoPE index doesn't match when using embeddings
+    def test_inputs_embeds(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
+
+            input_ids = inputs["input_ids"]
+            del inputs["input_ids"]
+            del inputs["pixel_values"]
+            del inputs["image_grid_thw"]
+
+            wte = model.get_input_embeddings()
+            inputs["inputs_embeds"] = wte(input_ids)
+            with torch.no_grad():
+                model(**inputs)[0]
+
     def test_inputs_embeds_matches_input_ids(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -350,6 +370,44 @@ def test_small_model_integration_test_batch(self):
             EXPECTED_DECODED_TEXT,
         )
 
+    @slow
+    def test_small_model_integration_test_with_video(self):
+        processor = AutoProcessor.from_pretrained("THUDM/GLM-4.1V-9B-Thinking", max_image_size={"longest_edge": 50176})
+        model = Glm4vForConditionalGeneration.from_pretrained(
+            "THUDM/GLM-4.1V-9B-Thinking", torch_dtype=torch.float16, device_map="auto"
+        )
+        questions = ["Describe this video."] * 2
+        video_urls = [
+            "https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4"
+        ] * 2
+        messages = [
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "video",
+                            "video": video_url,
+                        },
+                        {"type": "text", "text": question},
+                    ],
+                }
+            ]
+            for question, video_url in zip(questions, video_urls)
+        ]
+        inputs = processor.apply_chat_template(
+            messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt", padding=True
+        ).to(torch_device)
+        output = model.generate(**inputs, max_new_tokens=30)
+        EXPECTED_DECODED_TEXT = [
+            "\n012345Describe this video.\n<think>Got it, let's analyze the video. First, the scene is a room with a wooden floor, maybe a traditional Japanese room with tatami",
+            "\n012345Describe this video.\n<think>Got it, let's analyze the video. First, the scene is a room with a wooden floor, maybe a traditional Japanese room with tatami"
+        ]  # fmt: skip
+        self.assertEqual(
+            processor.batch_decode(output, skip_special_tokens=True),
+            EXPECTED_DECODED_TEXT,
+        )
+
     @slow
     def test_small_model_integration_test_expand(self):
         model = Glm4vForConditionalGeneration.from_pretrained(

tests/models/glm4v/test_video_processing_glm4v.py

Lines changed: 1 addition & 1 deletion

@@ -228,7 +228,7 @@ def test_call_pytorch(self):
             expected_output_video_shape = self.video_processor_tester.expected_output_video_shape(video_inputs)
             self.assertEqual(list(encoded_videos.shape), expected_output_video_shape)
 
-    @unittest.skip("Skip for now, the test needs adjustment fo GLM-4.1V")
+    @unittest.skip("Skip for now, the test needs adjustment for GLM-4.1V")
     def test_call_numpy_4_channels(self):
         for video_processing_class in self.video_processor_list:
             # Test that can process videos which have an arbitrary number of channels
