fix Glm4v batch videos forward #39172

Merged: 26 commits merged into huggingface:main on Jul 10, 2025

Conversation

Kuangdd01
Contributor

@Kuangdd01 Kuangdd01 commented Jul 2, 2025

What does this PR do?

Fixes issues in video_processing and get_video_features for GLM4V.

Tested with the following script:

from transformers import AutoProcessor, Glm4vForConditionalGeneration
from PIL import Image
import cv2

def prepare_video_metadata(videos):
    video_metadata = []
    for video in videos:
        if isinstance(video, list):
            num_frames = len(video)
        elif hasattr(video, "shape"):
            if len(video.shape) == 4:  # (T, H, W, C)
                num_frames = video.shape[0]
            else:
                num_frames = 1
        else:
            num_frames = 8
            print("Unknown video type; falling back to 8 frames")

        metadata = {
            "fps": 2,
            "duration": num_frames / 2,
            "total_frames": num_frames,
        }
        video_metadata.append(metadata)
    return video_metadata

def test_video_processing(video_path_list, num_frames=4):
    selected_frames = []
    for video_path in video_path_list:
        cap = cv2.VideoCapture(video_path)
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"Total frames: {frame_count}")
        cap.release()

    video_metadata = []
    for video_path in video_path_list:
        temp_frames = []
        cap = cv2.VideoCapture(video_path)
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        step = max(frame_count // num_frames, 1)
        for i in range(0, frame_count, step):
            cap.set(cv2.CAP_PROP_POS_FRAMES, i)
            ret, frame = cap.read()
            if not ret:
                continue
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_img = Image.fromarray(frame_rgb)
            temp_frames.append(pil_img)
        cap.release()
        selected_frames.append(temp_frames)

    video_metadata = prepare_video_metadata(selected_frames)
    video_inputs = processor.video_processor(videos=selected_frames, video_metadata=video_metadata)

    questions = ["What kind of dog is this?", "Describe the background."]

    messages_batch = [
        [
            {
                "role": "user",
                "content": [
                    {"type": "video"},
                    {"type": "text", "text": question},
                ],
            }
        ]
        for question in questions
    ]

    texts = [
        processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
        for msg in messages_batch
    ]

    inputs_batch = processor(text=texts, videos=selected_frames, video_metadata=video_metadata, return_tensors="pt", padding=True)

    print(processor.batch_decode(inputs_batch['input_ids'])[0])
    rope_pos, deltas = model.model.get_rope_index(
        inputs_batch["input_ids"],
        None,
        inputs_batch["video_grid_thw"],
        inputs_batch["attention_mask"]
    )

    print(rope_pos.shape, "\n", deltas)

processor_name = "THUDM/GLM-4.1V-9B-Thinking"
processor = AutoProcessor.from_pretrained(processor_name)
model = Glm4vForConditionalGeneration.from_pretrained(processor_name)

if __name__ == "__main__":
    # image_path = "./data/mllm_demo_data/1.jpg"
    video_path_1 = "./data/mllm_demo_data/1.mp4"
    video_path_2 = "./data/mllm_demo_data/2.avi"

    test_video_processing([video_path_1, video_path_2])

For forward logits checking, @zRzRzRzRzRzRzR

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@zucchini-nlp cc @zRzRzRzRzRzRzR

@Kuangdd01
Contributor Author

CI failed because the change to get_video_features is not consistent with the code generated from the modular file. 😂

total_frames = video_grid_thw[0][0].item()
h = video_grid_thw[0][1].item()
w = video_grid_thw[0][2].item()
video_grid_thw = [[1, h, w] for _ in range(total_frames)]
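
(For illustration, not from the PR: a video_grid_thw of [[4, 12, 12]] would become [[1, 12, 12], [1, 12, 12], [1, 12, 12], [1, 12, 12]], i.e. one entry per frame.)
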
Member

I think we would also need to pad timestamps, as otherwise it will fail when a different number of frames is sampled per video. We've been discussing it internally with @zRzRzRzRzRzRzR, not sure though if he has a PR yet.

Contributor Author

Yes, timestamps are not a good thing to return here. Can we return it the way qwen2_5vl does?

if isinstance(fps, (int, float)):
    second_per_grid_ts = [self.video_processor.temporal_patch_size / fps] * len(video_grid_thw)
elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
    second_per_grid_ts = [self.video_processor.temporal_patch_size / tmp for tmp in fps]
else:
    raise ValueError(
        f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
    )
videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})

Member

Hm, not sure if this is equivalent to what GLM4V does, because in GLM we want to add per-frame timestamps in the prompt. We talked about this internally and decided that padding/unpadding can work, as the timestamps are used in internal processing only. So we can pad on the right and strip off the pad values in processing.py.
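
A minimal sketch of the padding/unpadding idea, assuming timestamps come back as one Python list per video (the helper names are illustrative, not the actual GLM4V processor API):

PAD_VALUE = -1  # sentinel for unused timestamp slots

def pad_timestamps(timestamps_per_video):
    # Right-pad every per-video timestamp list to the longest one so they stack into a batch.
    max_len = max(len(ts) for ts in timestamps_per_video)
    return [ts + [PAD_VALUE] * (max_len - len(ts)) for ts in timestamps_per_video]

def unpad_timestamps(padded):
    # Strip the sentinel values again before building the per-frame prompt text in processing.py.
    return [[t for t in ts if t != PAD_VALUE] for ts in padded]

# Example: two videos with 4 and 6 sampled frames
padded = pad_timestamps([[0, 1, 2, 3], [0, 1, 2, 3, 4, 5]])
assert unpad_timestamps(padded)[0] == [0, 1, 2, 3]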

@Kuangdd01
Contributor Author

😀Do I need to write more unit tests for this change?

@zucchini-nlp
Member

😀Do I need to write more unit tests for this change?

we can add an integration test with batched/unbatched video inference, similar to

@slow
def test_small_model_integration_test_batch(self):
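
A rough sketch of what such a batched-video integration test could look like (the test name, self.processor, self.video_url, and the generation settings are assumptions for illustration, not the actual test added in this PR):

@slow
def test_small_model_integration_test_batch_video(self):
    model = Glm4vForConditionalGeneration.from_pretrained(
        "THUDM/GLM-4.1V-9B-Thinking", torch_dtype=torch.bfloat16, device_map="auto"
    )
    messages = [
        [{"role": "user", "content": [{"type": "video", "url": self.video_url},
                                      {"type": "text", "text": "Describe this video."}]}],
        [{"role": "user", "content": [{"type": "video", "url": self.video_url},
                                      {"type": "text", "text": "What happens at the end?"}]}],
    ]
    inputs = self.processor.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True,
        return_dict=True, return_tensors="pt", padding=True,
    ).to(model.device)
    output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
    decoded = self.processor.batch_decode(output, skip_special_tokens=True)
    self.assertEqual(len(decoded), 2)  # one completion per batched video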

@Kuangdd01
Contributor Author

😀Do I need to write more unit tests for this change?

we can add an integration test with batched/unbatched video inference, similar to

@slow
def test_small_model_integration_test_batch(self):

Thanks, I have tested this test_small_model_integration_test_with_video locally. BTW, should we remove @unittest.skip("Model checkpoint not yet released") in this tester?

@zucchini-nlp
Member

Yeeah, the checkpoints were updated in #39247. Can you rebase?

Member

@zucchini-nlp zucchini-nlp left a comment

A couple tiny comments

@Kuangdd01
Contributor Author

Kuangdd01 commented Jul 8, 2025

The following tests fail after running pytest tests/models/glm4v/test_modeling_glm4v.py locally:

FAILED tests/models/glm4v/test_modeling_glm4v.py::Glm4vModelTest::test_generate_from_inputs_embeds_0_greedy - ValueError: You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one
FAILED tests/models/glm4v/test_modeling_glm4v.py::Glm4vModelTest::test_generate_from_inputs_embeds_1_beam_search - ValueError: You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one
FAILED tests/models/glm4v/test_modeling_glm4v.py::Glm4vModelTest::test_resize_embeddings_untied_with_deepspeed - ModuleNotFoundError: No module named 'mpi4py'
FAILED tests/models/glm4v/test_modeling_glm4v.py::Glm4vModelTest::test_resize_embeddings_untied_with_deepspeed_multi_gpu - ModuleNotFoundError: No module named 'mpi4py'
FAILED tests/models/glm4v/test_modeling_glm4v.py::Glm4vModelTest::test_resize_tokens_embeddings_with_deepspeed - ModuleNotFoundError: No module named 'mpi4py'
FAILED tests/models/glm4v/test_modeling_glm4v.py::Glm4vModelTest::test_resize_tokens_embeddings_with_deepspeed_multi_gpu - ModuleNotFoundError: No module named 'mpi4py'

@Kuangdd01 Kuangdd01 requested a review from zucchini-nlp July 8, 2025 09:37
@zucchini-nlp
Member

run-slow: glm4v

Contributor

github-actions bot commented Jul 9, 2025

This comment contains run-slow, running the specified jobs:

models: ['models/glm4v']
quantizations: [] ...

@Kuangdd01
Contributor Author

No, they don't. I see the last commit changes self.processor; we usually don't change global vars within test cases, as it might cause unexpected errors. The test cases are run in random order AFAIK, so if we change longest_edge, the subsequent tests might fail to match the expected_output.

Yes, I think it fails due to the changes to this global var. After I ran test_with_video, test_with_numpy would fail. But it seems that all tests pass in the end.

@zucchini-nlp
Member

cc @ydshieh , do you know anything about test files sharing the same processor? Happy to keep the fix if you agree, but looks weird to me 👀

@zucchini-nlp zucchini-nlp added the "for patch" label (Tag issues / labels that should be included in the next patch) on Jul 9, 2025
@zucchini-nlp
Member

Not really a regression, but maybe it can go in the patch.

@ydshieh
Collaborator

ydshieh commented Jul 9, 2025

cc @ydshieh , do you know anything about test files sharing the same processor? Happy to keep the fix if you agree, but looks weird to me 👀

Hi, may I have a bit of context here 🙏? I guess you are talking about:

Hmm, don't know why this test failed. Do test_modeling and test_video* share one video_processor?

But at this moment, I don't know what issues you faced. A bit more detail would be very appreciated. Thank you.

@zucchini-nlp
Member

zucchini-nlp commented Jul 9, 2025

Oh yeah, sorry @ydshieh . The issue is that a test failed in slow CI when we ran it last time. The test is tests/models/glm4v/test_video_processing_glm4v.py::Glm4vVideoProcessingTest::test_call_numpy. There are two weird things about it:

  1. It is failing only in slow CI, but not in the fast one. The test isn't marked as slow, so I assumed it would have been triggered already. Maybe it wasn't triggered; then it's fine.

  2. To fix the test, the contributor had to change the test_modeling_*.py file, specifically self.processor in the IntegrationTest. I don't know how self.processor from the modeling test is being re-used in the processing test, and whether this is expected. As per the last comment, it fixes the issue magically. Have you encountered anything similar before?

    Yes, I think it fails due to the changes to this global var. After I ran test_with_video, test_with_numpy would fail. But it seems that all tests pass in the end.

@ydshieh
Collaborator

ydshieh commented Jul 9, 2025

Thanks. I will take a look today.

@ydshieh
Collaborator

ydshieh commented Jul 9, 2025

@zucchini-nlp Those are class attributes, so if an instance changes the values, it will affect the other places!

Do you know why we use class attributes 😭

class Glm4vVideoProcessor(BaseVideoProcessor):
    resample = PILImageResampling.BICUBIC
    size = {"shortest_edge": 112 * 112, "longest_edge": 28 * 28 * 2 * 30000}
    max_image_size = {"longest_edge": 28 * 28 * 2 * 30000}
    image_mean = OPENAI_CLIP_MEAN
    image_std = OPENAI_CLIP_STD
    do_resize = True
    do_rescale = True
    do_normalize = True
    do_convert_rgb = True
    do_sample_frames = True
    patch_size = 14
    temporal_patch_size = 2
    max_duration = 300
    merge_size = 2
    valid_kwargs = Glm4vVideoProcessorInitKwargs
    num_frames = 16
    fps = 2
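
A minimal illustration of the failure mode being described (toy class, not the real processor): mutating a mutable class attribute through one instance changes it for every other instance, because they all resolve to the same class-level object.

class ToyProcessor:
    # class-level default shared by every instance, mirroring the pattern above
    size = {"shortest_edge": 112 * 112, "longest_edge": 28 * 28 * 2 * 30000}

a = ToyProcessor()
b = ToyProcessor()

# In-place mutation through one instance is visible through the other instance
# and through the class itself.
a.size["longest_edge"] = 64
print(b.size["longest_edge"])             # 64, not the original default
print(ToyProcessor.size["longest_edge"])  # 64 as well

# Rebinding the attribute on an instance, by contrast, only shadows the class
# default for that instance.
b.size = {"shortest_edge": 112 * 112, "longest_edge": 128}
print(a.size["longest_edge"])  # still 64, read from the (mutated) class attribute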

@zucchini-nlp
Member

Do you know why we use class attributes 😭

This is copied from image processors, because video inherits from it. I believe it was done to have defaults for kwargs, and keep core processing code in one place. @yonigozlan can say more on that

I didn't see that the new test changes a processor attribute; in the meantime we can fix and merge the PR. We can move the discussion to internal Slack later if needed :)

@zucchini-nlp
Member

run-slow: glm4v

Contributor

This comment contains run-slow, running the specified jobs:

models: ['models/glm4v']
quantizations: [] ...

@Kuangdd01
Contributor Author

Hmm, failed again

FAILED tests/models/glm4v/test_video_processing_glm4v.py::Glm4vVideoProcessingTest::test_call_numpy 
- AssertionError: Lists differ: [64, 1176] != [144, 1176]

@zucchini-nlp
Member

Oh wait, I didn't check the changes. It's because we're still re-assigning the value of max_image_size. I meant to load with the new value, like processor = AutoProcessor.from_pretrained("glm4V", max_image_size="{NEW_SMALL_SIZE}"). It doesn't change the class attribute, I think.
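
A sketch of the distinction being made (illustrative only, not the actual test code; the checkpoint name and size value are placeholders):

# Problematic: changing the attribute on the processor object shared across tests.
self.processor.video_processor.max_image_size = {"longest_edge": 28 * 28 * 2 * 100}

# Suggested instead: load a separate processor for this test and pass the smaller
# value as a from_pretrained kwarg, leaving the shared processor untouched.
small_processor = AutoProcessor.from_pretrained(
    "THUDM/GLM-4.1V-9B-Thinking",
    max_image_size={"longest_edge": 28 * 28 * 2 * 100},
)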

@Kuangdd01
Contributor Author

Sorry about that, I finally understood Yih-Dar's point. I'll fix that dangerous code. 😭

Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: glm4v

@zucchini-nlp
Member

run-slow: glm4v

Contributor

This comment contains run-slow, running the specified jobs:

models: ['models/glm4v']
quantizations: [] ...

@ydshieh
Collaborator

ydshieh commented Jul 10, 2025

Do you know why we use class attributes 😭

This is copied from image processors, because video inherits from it. I believe it was done to have defaults for kwargs, and keep core processing code in one place. @yonigozlan can say more on that

I didn't see that the new test changes processor attribute, in the meanwhile we can fix and merge the PR. We can move discussion later to internal slack if needed :)

Just want to point out that if an attribute makes sense as an instance attribute (i.e. we could have instances that are allowed to have different values for that attribute), then making it a class attribute is not a great idea and will sometimes cause chaos. I remember my nightmare when debugging InstructBlip, which had a keep_fp32_... class attribute back then.

Might be a good time to rethink this, @yonigozlan.

@zucchini-nlp
Member

zucchini-nlp commented Jul 10, 2025

Yaay, test passing, can merge now 🎉

@zucchini-nlp zucchini-nlp merged commit 520b9dc into huggingface:main Jul 10, 2025
21 checks passed
@zucchini-nlp
Member

Just want to point out that if an attribute makes sense as an instance attribute (i.e. we could have instances that are allowed to have different values for that attribute), then making it a class attribute is not a great idea and will sometimes cause chaos. I remember my nightmare when debugging InstructBlip, which had a keep_fp32_... class attribute back then.

And now, on this topic: I have a local branch where I'm trying to use dataclasses to store all processor values, and thus remove iterating over TypedDict.__annotations__ for checking/setting values. I can go back and work on that if we think it makes sense.
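
A rough sketch of that idea, purely illustrative (not code from the mentioned branch): defaults live on a dataclass instance, so overriding and introspecting values goes through dataclasses.fields instead of iterating TypedDict.__annotations__.

from dataclasses import dataclass, field, fields

@dataclass
class VideoProcessorDefaults:
    # one instance per processor instead of mutable class attributes shared by all
    do_resize: bool = True
    patch_size: int = 14
    temporal_patch_size: int = 2
    merge_size: int = 2
    fps: int = 2
    size: dict = field(default_factory=lambda: {"shortest_edge": 112 * 112})

    def override(self, **kwargs):
        # only known fields can be overridden; unknown keys raise instead of
        # silently becoming stray attributes
        known = {f.name for f in fields(self)}
        for key, value in kwargs.items():
            if key not in known:
                raise ValueError(f"Unknown processor kwarg: {key}")
            setattr(self, key, value)

defaults = VideoProcessorDefaults()
defaults.override(fps=4)
print(defaults.fps)                  # 4 on this instance only
print(VideoProcessorDefaults().fps)  # a fresh instance keeps the default 2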

Cyrilvallez pushed a commit that referenced this pull request Jul 11, 2025
* changes for video

* update modular

* change get_video_features

* update video token replacement

* update modular

* add test and fix typo

* lint

* fix order

* lint

* fix

* remove dependency

* lint

* lint

* remove todo

* resize video for test

* lint..

* fix test

* new a processor for video_test

* fix test
rjgleaton pushed a commit to rjgleaton/transformers that referenced this pull request Jul 17, 2025
Labels: for patch (Tag issues / labels that should be included in the next patch)
Projects: none yet
4 participants