Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/transformers/models/mllama/processing_mllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,12 +301,16 @@ def __call__(
raise ValueError(
"If a batch of text is provided, there should be either no images or at least one image per sample"
)
if sum(n_images_in_images) != sum(n_images_in_text):
if sum(n_images_in_text) > 0 and n_images_in_images != n_images_in_text:
if images is None:
raise ValueError("No image were provided, but there are image tokens in the prompt")
else:
add_message = ""
if sum(n_images_in_images) == sum(n_images_in_text):
add_message = "Make sure to pass your images as a nested list, where each sub-list holds images per batch"
raise ValueError(
f"The number of image token ({sum(n_images_in_text)}) should be the same as in the number of provided images ({sum(n_images_in_images)})"
f"The number of image tokens in each text ({n_images_in_text}) should be the same as the "
f"number of provided images per batch ({n_images_in_images}). {add_message}"
)

if images is not None:
Expand Down
34 changes: 34 additions & 0 deletions tests/models/mllama/test_processor_mllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,11 @@ def test_process_interleaved_images_prompts_image_error(self):
with self.assertRaises(ValueError):
processor(text=text, images=None, padding=True)

# see https://github.com/huggingface/transformers/pull/35934
images = [self.image1, self.image2]
with self.assertRaises(ValueError):
processor(text=text, images=None, padding=True)

# Override as MllamaProcessor needs image tokens in prompts
def prepare_text_inputs(self, batch_size: Optional[int] = None):
if batch_size is None:
Expand All @@ -340,3 +345,32 @@ def prepare_text_inputs(self, batch_size: Optional[int] = None):
return ["lower newer <|image|>", "<|image|> upper older longer string"] + ["<|image|> lower newer"] * (
batch_size - 2
)

def test_unstructured_kwargs_batched(self):
    """Overridden: Mllama expects images as a nested list, one sub-list per
    batch sample. With two flat images the processor cannot infer the correct
    nesting (and is expected to raise), so this override nests them explicitly
    before exercising the unstructured kwargs path.
    """
    if "image_processor" not in self.processor_class.attributes:
        self.skipTest(f"image_processor attribute not present in {self.processor_class}")

    components = self.prepare_components()
    extra_kwargs = self.prepare_processor_dict()
    processor = self.processor_class(**components, **extra_kwargs)
    self.skip_processor_without_typed_kwargs(processor)

    texts = self.prepare_text_inputs(batch_size=2)
    flat_images = self.prepare_image_inputs(batch_size=2)
    # One sub-list per batch sample, as Mllama requires.
    nested_images = [[flat_images[0]], [flat_images[1]]]

    processed = processor(
        text=texts,
        images=nested_images,
        return_tensors="pt",
        do_rescale=True,
        rescale_factor=-1,
        padding="longest",
        max_length=76,
    )

    # rescale_factor=-1 drives pixel values non-positive, confirming the
    # unstructured image-processor kwarg was forwarded and applied.
    self.assertLessEqual(processed[self.images_input_name][0][0].mean(), 0)
    # Both sequences padded to the same length and below max_length,
    # confirming the tokenizer kwargs were applied as well.
    ids_a = processed[self.text_input_name][0]
    ids_b = processed[self.text_input_name][1]
    self.assertTrue(len(ids_a) == len(ids_b) and len(ids_b) < 76)