
fix Glm4v batch videos forward #39172


Merged: 26 commits, Jul 10, 2025
16 changes: 13 additions & 3 deletions src/transformers/models/glm4v/modeling_glm4v.py
@@ -1052,6 +1052,7 @@ def get_rope_index(
device=input_ids.device,
)
image_index, video_index = 0, 0
video_group_index = 0
attention_mask = attention_mask.to(total_input_ids.device)
for i, input_ids in enumerate(total_input_ids):
input_ids = input_ids[attention_mask[i] == 1]
@@ -1081,7 +1082,6 @@ def get_rope_index(

llm_pos_ids_list = []
video_frame_num = 1

for modality_type, start_idx, end_idx in input_type_group:
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0

@@ -1125,7 +1125,11 @@ def get_rope_index(
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)

video_index += 1
video_group_index += 1

if video_group_index >= video_grid_thw[video_index][0]:
video_index += 1
video_group_index = 0

video_frame_num += 1
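
For reference, the bookkeeping added in this hunk can be exercised in isolation: video_index now advances only after all t frames of the current video have been consumed, so each per-frame group resolves to the correct row of video_grid_thw when a batch contains several multi-frame videos. A minimal standalone sketch with made-up grid values (not taken from the PR):

import torch

# Two videos in the batch: 2 frames and 3 frames respectively (toy values).
video_grid_thw = torch.tensor([[2, 4, 4], [3, 4, 4]])

video_index, video_group_index = 0, 0
frame_to_video = []
for _ in range(int(video_grid_thw[:, 0].sum())):
    frame_to_video.append(video_index)
    video_group_index += 1
    if video_group_index >= video_grid_thw[video_index][0]:
        video_index += 1
        video_group_index = 0

print(frame_to_video)  # [0, 0, 1, 1, 1] -> frames grouped per video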

@@ -1174,7 +1178,13 @@ def get_video_features(
The temporal, height and width of feature shape of each video in LLM.
"""
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
# reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames
temp_frames_hw = []
for t, h, w in video_grid_thw:
repeated_row = torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1)
temp_frames_hw.append(repeated_row)
flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)
video_embeds = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw)
split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
video_embeds = torch.split(video_embeds, split_sizes)
return video_embeds
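
The reshaping introduced above can be checked without the model: every per-video [t, h, w] row is expanded into t rows of [1, h, w] before the visual encoder runs, while the original per-video grid is still what determines how the resulting embeddings are split. A quick standalone check with made-up grid values (spatial_merge_size assumed to be 2):

import torch

# Toy batch: video 0 has 2 frames on a 4x4 patch grid, video 1 has 3 frames on a 6x4 grid.
video_grid_thw = torch.tensor([[2, 4, 4], [3, 6, 4]])

temp_frames_hw = []
for t, h, w in video_grid_thw:
    # Each frame becomes its own [1, h, w] row for the vision tower.
    temp_frames_hw.append(torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(int(t), 1))
flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)
print(flattened_video_grid_thw.shape)  # torch.Size([5, 3]) -> one row per frame

# The split back into per-video embeddings still uses the unflattened grid.
spatial_merge_size = 2
split_sizes = (video_grid_thw.prod(-1) // spatial_merge_size**2).tolist()
print(split_sizes)  # [8, 18] merged tokens per video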
50 changes: 42 additions & 8 deletions src/transformers/models/glm4v/modular_glm4v.py
@@ -1064,6 +1064,7 @@ def get_rope_index(
device=input_ids.device,
)
image_index, video_index = 0, 0
video_group_index = 0
attention_mask = attention_mask.to(total_input_ids.device)
for i, input_ids in enumerate(total_input_ids):
input_ids = input_ids[attention_mask[i] == 1]
@@ -1093,7 +1094,6 @@ def get_rope_index(

llm_pos_ids_list = []
video_frame_num = 1

for modality_type, start_idx, end_idx in input_type_group:
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0

@@ -1137,7 +1137,11 @@ def get_rope_index(
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)

video_index += 1
video_group_index += 1

if video_group_index >= video_grid_thw[video_index][0]:
video_index += 1
video_group_index = 0

video_frame_num += 1

@@ -1173,6 +1177,30 @@ def get_rope_index(

return position_ids, mrope_position_deltas

def get_video_features(
self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
):
"""
Encodes videos into continuous embeddings that can be forwarded to the language model.

Args:
pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input videos.
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
The temporal, height and width of feature shape of each video in LLM.
"""
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
# reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames
temp_frames_hw = []
for t, h, w in video_grid_thw:
repeated_row = torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1)
temp_frames_hw.append(repeated_row)
flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)
video_embeds = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw)
split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
video_embeds = torch.split(video_embeds, split_sizes)
return video_embeds

@auto_docstring
@can_return_tuple
def forward(
@@ -1664,32 +1692,38 @@ def __call__(
video_index = 0
for i in range(len(text)):
while self.video_token in text[i]:
num_frames = len(video_grid_thw)
num_frames = video_grid_thw[video_index][0]
video_structure = ""

if hasattr(timestamps, "tolist"):
timestamps_list = timestamps.tolist()[0]
else:
timestamps_list = timestamps[0] if isinstance(timestamps[0], list) else timestamps

unique_timestamps = []
for idx in range(0, len(timestamps_list)):
unique_timestamps.append(timestamps_list[idx])

selected_timestamps = unique_timestamps[:num_frames]
while len(selected_timestamps) < num_frames:
selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0)

for frame_idx in range(num_frames):
timestamp_sec = selected_timestamps[frame_idx]
frame_structure = f"<|begin_of_image|>{self.image_token}<|end_of_image|>{timestamp_sec}"
video_structure += frame_structure

text[i] = text[i].replace(self.video_token, video_structure, 1)
num_image_tokens = (
video_grid_thw[video_index].prod() // merge_length // video_grid_thw[video_index][0]
)
for frame_idx in range(num_frames):
if self.image_token in text[i]:
text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)

video_index += 1

for frame_idx in range(len(video_grid_thw)):
if self.image_token in text[i]:
num_image_tokens = video_grid_thw[frame_idx].prod() // merge_length
text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
text[i] = text[i].replace("<|placeholder|>", self.image_token)

return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
18 changes: 12 additions & 6 deletions src/transformers/models/glm4v/processing_glm4v.py
@@ -167,32 +167,38 @@ def __call__(
video_index = 0
for i in range(len(text)):
while self.video_token in text[i]:
num_frames = len(video_grid_thw)
num_frames = video_grid_thw[video_index][0]
video_structure = ""

if hasattr(timestamps, "tolist"):
timestamps_list = timestamps.tolist()[0]
else:
timestamps_list = timestamps[0] if isinstance(timestamps[0], list) else timestamps

unique_timestamps = []
for idx in range(0, len(timestamps_list)):
unique_timestamps.append(timestamps_list[idx])

selected_timestamps = unique_timestamps[:num_frames]
while len(selected_timestamps) < num_frames:
selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0)

for frame_idx in range(num_frames):
timestamp_sec = selected_timestamps[frame_idx]
frame_structure = f"<|begin_of_image|>{self.image_token}<|end_of_image|>{timestamp_sec}"
video_structure += frame_structure

text[i] = text[i].replace(self.video_token, video_structure, 1)
num_image_tokens = (
video_grid_thw[video_index].prod() // merge_length // video_grid_thw[video_index][0]
)
for frame_idx in range(num_frames):
if self.image_token in text[i]:
text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)

video_index += 1

for frame_idx in range(len(video_grid_thw)):
if self.image_token in text[i]:
num_image_tokens = video_grid_thw[frame_idx].prod() // merge_length
text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
text[i] = text[i].replace("<|placeholder|>", self.image_token)

return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
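
To make the token accounting above concrete: prod() // merge_length // t reduces to h * w // merge_length, i.e. the number of placeholder tokens per frame, and each video token in the prompt expands into t begin/image/end blocks, each followed by its timestamp. A small worked example with assumed values (merge_length = 4, i.e. a spatial merge size of 2; "<image>" stands in for the processor's real image token):

import torch

video_grid_thw = torch.tensor([[2, 4, 4]])  # one video: 2 frames on a 4x4 patch grid
merge_length = 4                            # spatial_merge_size ** 2
timestamps = [0, 1]

video_index = 0
num_frames = int(video_grid_thw[video_index][0])
num_image_tokens = int(
    video_grid_thw[video_index].prod() // merge_length // video_grid_thw[video_index][0]
)

video_structure = "".join(
    f"<|begin_of_image|>{'<image>' * num_image_tokens}<|end_of_image|>{timestamps[f]}"
    for f in range(num_frames)
)
print(num_image_tokens)                  # 4 image tokens per frame
print(video_structure.count("<image>"))  # 8 image tokens for the whole video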
4 changes: 0 additions & 4 deletions src/transformers/models/glm4v/video_processing_glm4v.py
@@ -249,10 +249,6 @@ def _preprocess(
processed_grids = reorder_videos(processed_grids, grouped_videos_index)
pixel_values_videos = torch.cat(processed_videos, dim=0)
video_grid_thw = torch.tensor(processed_grids)
total_frames = video_grid_thw[0][0].item()
h = video_grid_thw[0][1].item()
w = video_grid_thw[0][2].item()
video_grid_thw = [[1, h, w] for _ in range(total_frames)]
Member: I think we would also need to pad timestamps, as otherwise it will fail when a different number of frames is sampled per video. We've been discussing this internally with @zRzRzRzRzRzRzR, though I'm not sure if he has a PR yet.

Contributor Author: Yes, returning timestamps here isn't good. Can we return them the way qwen2_5vl does?

if isinstance(fps, (int, float)):
second_per_grid_ts = [self.video_processor.temporal_patch_size / fps] * len(video_grid_thw)
elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
second_per_grid_ts = [self.video_processor.temporal_patch_size / tmp for tmp in fps]
else:
raise ValueError(
f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
)
videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})

Member: Hm, I'm not sure this is equivalent to what GLM-4V does, because in GLM we want to add timestamps per frame in the prompt. We talked about this internally and decided that padding/unpadding can work, since the timestamps are only used in internal processing. So we can pad on the right and strip off the pad values in processing.py.
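
For illustration only, a minimal sketch of the right-padding idea described above; the sentinel value, helper names, and call sites are assumptions and not part of this PR or any follow-up:

import torch

PAD_VALUE = -1  # assumed sentinel; anything that cannot be a real timestamp works


def pad_timestamps(timestamps_per_video):
    """Right-pad per-video timestamp lists into a rectangular tensor."""
    max_len = max(len(ts) for ts in timestamps_per_video)
    padded = [ts + [PAD_VALUE] * (max_len - len(ts)) for ts in timestamps_per_video]
    return torch.tensor(padded)


def strip_padding(padded_row):
    """Recover the original timestamps for one video inside the processor."""
    return [int(t) for t in padded_row if int(t) != PAD_VALUE]


batch = pad_timestamps([[0, 1, 2], [0, 1, 2, 3, 4]])
print(batch.shape)              # torch.Size([2, 5])
print(strip_padding(batch[0]))  # [0, 1, 2]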

data = {
"pixel_values_videos": pixel_values_videos,
"video_grid_thw": video_grid_thw,
60 changes: 59 additions & 1 deletion tests/models/glm4v/test_modeling_glm4v.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""Testing suite for the PyTorch GLM-4.1V model."""

import copy
import gc
import unittest

@@ -236,7 +237,26 @@ def test_multi_gpu_data_parallel_forward(self):
def test_generate_from_inputs_embeds_with_static_cache(self):
pass

# RoPE index doesn't match when using embeddings
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()

inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))

input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
del inputs["image_grid_thw"]

wte = model.get_input_embeddings()
inputs["inputs_embeds"] = wte(input_ids)
with torch.no_grad():
model(**inputs)[0]

def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

@@ -350,6 +370,44 @@ def test_small_model_integration_test_batch(self):
EXPECTED_DECODED_TEXT,
)

@slow
def test_small_model_integration_test_with_video(self):
processor = AutoProcessor.from_pretrained("THUDM/GLM-4.1V-9B-Thinking", max_image_size={"longest_edge": 50176})
model = Glm4vForConditionalGeneration.from_pretrained(
"THUDM/GLM-4.1V-9B-Thinking", torch_dtype=torch.float16, device_map="auto"
)
questions = ["Describe this video."] * 2
video_urls = [
"https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4"
] * 2
messages = [
[
{
"role": "user",
"content": [
{
"type": "video",
"video": video_url,
},
{"type": "text", "text": question},
],
}
]
for question, video_url in zip(questions, video_urls)
]
inputs = processor.apply_chat_template(
messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt", padding=True
).to(torch_device)
output = model.generate(**inputs, max_new_tokens=30)
EXPECTED_DECODED_TEXT = [
"\n012345Describe this video.\n<think>Got it, let's analyze the video. First, the scene is a room with a wooden floor, maybe a traditional Japanese room with tatami",
"\n012345Describe this video.\n<think>Got it, let's analyze the video. First, the scene is a room with a wooden floor, maybe a traditional Japanese room with tatami"
] # fmt: skip
self.assertEqual(
processor.batch_decode(output, skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)

@slow
def test_small_model_integration_test_expand(self):
model = Glm4vForConditionalGeneration.from_pretrained(
2 changes: 1 addition & 1 deletion tests/models/glm4v/test_video_processing_glm4v.py
@@ -228,7 +228,7 @@ def test_call_pytorch(self):
expected_output_video_shape = self.video_processor_tester.expected_output_video_shape(video_inputs)
self.assertEqual(list(encoded_videos.shape), expected_output_video_shape)

@unittest.skip("Skip for now, the test needs adjustment fo GLM-4.1V")
@unittest.skip("Skip for now, the test needs adjustment for GLM-4.1V")
def test_call_numpy_4_channels(self):
for video_processing_class in self.video_processor_list:
# Test that can process videos which have an arbitrary number of channels