Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,33 @@
Processor class for LLaVa-NeXT-Video.
"""

from typing import TYPE_CHECKING, List, Optional, Union
from typing import List, Union

import numpy as np

from ...feature_extraction_utils import BatchFeature
from ...image_processing_utils import select_best_resolution
from ...image_utils import ImageInput, VideoInput, get_image_size, to_numpy_array
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ...utils import TensorType, logging
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import logging


if TYPE_CHECKING:
pass

logger = logging.get_logger(__name__)


class LlavaNextVideoProcessorKwargs(ProcessingKwargs, total=False):
# see processing_utils.ProcessingKwargs documentation for usage.
_defaults = {
"text_kwargs": {
"padding": False,
},
"common_kwargs": {
"return_tensors": "pt",
},
}


class LlavaNextVideoProcessor(ProcessorMixin):
r"""
Constructs a LLaVa-NeXT-Video processor which wraps a LLaVa-NeXT image processor, LLaVa-NeXT-Video video processor and
Expand Down Expand Up @@ -102,13 +111,11 @@ def __init__(

def __call__(
self,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
images: ImageInput = None,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
audio=None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tiny concern about squishing in audio before the videos, but I don't think any user passes videos as positional arg so maybe we are oke

I don't want to add more complexity by trying to validate the order of audio and video now, so let's leave as is and just take this noted

videos: VideoInput = None,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: int = None,
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
**kwargs: Unpack[LlavaNextVideoProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
Expand All @@ -130,19 +137,6 @@ def __call__(
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding
index) among:
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence if provided).
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
acceptable input length for the model if that argument is not provided.
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
lengths).
max_length (`int`, *optional*):
Maximum length of the returned list and optionally padding length (see above).
truncation (`bool`, *optional*):
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:

Expand All @@ -160,13 +154,21 @@ def __call__(
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
# check if images and text inputs are reversed for BC
images, text = _validate_images_text_input_order(images, text)

output_kwargs = self._merge_kwargs(
LlavaNextVideoProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if images is not None:
image_inputs = self.image_processor(images, return_tensors=return_tensors)
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
else:
image_inputs = {}

if videos is not None:
videos_inputs = self.video_processor(videos, return_tensors=return_tensors)
videos_inputs = self.video_processor(videos, **output_kwargs["videos_kwargs"])
else:
videos_inputs = {}

Expand Down Expand Up @@ -212,13 +214,7 @@ def __call__(
prompt_strings.append(sample)
text = prompt_strings

text_inputs = self.tokenizer(
text,
return_tensors=return_tensors,
padding=padding,
truncation=truncation,
max_length=max_length,
)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})

# Copied from transformers.models.llava_next.processing_llava_next.LlavaNextProcessor._get_number_of_features
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class LlavaOnevisionProcessorKwargs(ProcessingKwargs, total=False):
"padding": False,
},
"image_kwargs": {},
"video_kwargs": {},
"videos_kwargs": {},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!

}


Expand Down
6 changes: 3 additions & 3 deletions src/transformers/pipelines/image_text_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,9 @@ def preprocess(self, inputs=None, timeout=None, continue_final_message=None, pro
# if batched text inputs, we set padding to True unless specified otherwise
if isinstance(text, (list, tuple)) and len(text) > 1:
processing_kwargs.setdefault("padding", True)
model_inputs = self.processor(
images=images, text=text, return_tensors=self.framework, legacy=False, **processing_kwargs
).to(dtype=self.torch_dtype)
model_inputs = self.processor(images=images, text=text, return_tensors=self.framework, **processing_kwargs).to(
dtype=self.torch_dtype
)
Comment on lines +348 to +350
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

am I right that we don't need legacy=False anymore?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ended up keeping it because I had put v5.0.0 as the deprecation version


model_inputs["text"] = inputs_text

Expand Down
166 changes: 166 additions & 0 deletions tests/models/llava_next_video/test_processor_llava_next_video.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import shutil
import tempfile
import unittest

from transformers import AutoProcessor, LlamaTokenizerFast, LlavaNextVideoProcessor
from transformers.testing_utils import require_av, require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available

from ...test_processing_common import ProcessorTesterMixin


if is_vision_available():
from transformers import LlavaNextImageProcessor, LlavaNextVideoImageProcessor

if is_torch_available:
import torch


@require_vision
class LlavaNextVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

afaik ProcessorTesterMixin doesn't test videos_kwargs yet. I think we need to add video tests to make sure that llava-next-video processor works as expected

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed! I added tests for video_kwargs, very similar to those on images_kwargs :)

processor_class = LlavaNextVideoProcessor

def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
image_processor = LlavaNextImageProcessor()
video_processor = LlavaNextVideoImageProcessor()
tokenizer = LlamaTokenizerFast.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")
processor_kwargs = self.prepare_processor_dict()

processor = LlavaNextVideoProcessor(
video_processor=video_processor, image_processor=image_processor, tokenizer=tokenizer, **processor_kwargs
)
processor.save_pretrained(self.tmpdirname)

def get_tokenizer(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer

def get_image_processor(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor

def get_video_processor(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).video_processor

def prepare_processor_dict(self):
return {
"chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + ' '}}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '<image>' }}{% endfor %}{# Render all video then #}{% for content in message['content'] | selectattr('type', 'equalto', 'video') %}{{ '<video>' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ '\n' + content['text'] }}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ '\n' + content['text'] }}{% endgeneration %}{% endfor %}{% endif %}{{'<|im_end|>'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"num_additional_image_tokens": 6,
"patch_size": 4,
"vision_feature_select_strategy": "default",
}

def test_processor_to_json_string(self):
processor = self.get_processor()
obj = json.loads(processor.to_json_string())
for key, value in self.prepare_processor_dict().items():
# chat_tempalate are tested as a separate test because they are saved in separate files
if key != "chat_template":
self.assertEqual(obj[key], value)
self.assertEqual(getattr(processor, key, None), value)

# Copied from tests.models.llava.test_processor_llava.LlavaProcessorTest.test_chat_template_is_saved
def test_chat_template_is_saved(self):
processor_loaded = self.processor_class.from_pretrained(self.tmpdirname)
processor_dict_loaded = json.loads(processor_loaded.to_json_string())
# chat templates aren't serialized to json in processors
self.assertFalse("chat_template" in processor_dict_loaded.keys())

# they have to be saved as separate file and loaded back from that file
# so we check if the same template is loaded
processor_dict = self.prepare_processor_dict()
self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))

def tearDown(self):
shutil.rmtree(self.tmpdirname)

def test_chat_template(self):
processor = AutoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")
expected_prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"

messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]

formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
self.assertEqual(expected_prompt, formatted_prompt)

@require_av
def test_chat_template_dict(self):
processor = AutoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")
messages = [
{
"role": "user",
"content": [
{"type": "video"},
{"type": "text", "text": "What is shown in this video?"},
],
},
]

formatted_prompt_tokenized = processor.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True, return_tensors=None
)
expected_output = [[1, 3148, 1001, 29901, 29871, 32000, 13, 5618, 338, 4318, 297, 445, 4863, 29973, 319, 1799, 9047, 13566, 29901]] # fmt: skip
self.assertListEqual(expected_output, formatted_prompt_tokenized)

out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])

# add image URL for return dict
messages[0]["content"][0] = {
"type": "video",
"url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
}
out_dict_with_video = processor.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True, return_dict=True
)
self.assertListEqual(list(out_dict_with_video.keys()), ["input_ids", "attention_mask", "pixel_values_videos"])

@require_torch
@require_av
def test_chat_template_dict_torch(self):
processor = AutoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")
messages = [
{
"role": "user",
"content": [
{
"type": "video",
"url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
},
{"type": "text", "text": "What is shown in this video?"},
],
},
]

out_dict_tensors = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
)
self.assertListEqual(list(out_dict_tensors.keys()), ["input_ids", "attention_mask", "pixel_values_videos"])
self.assertTrue(isinstance(out_dict_tensors["input_ids"], torch.Tensor))
Loading