
Adds multimodal support and MMMU pro #675


Merged: 43 commits, merged May 19, 2025.

The diffs below show changes from 3 of the 43 commits.

Commits
409b0c0
init
NathanHB Apr 15, 2025
ee334c5
init
NathanHB Apr 15, 2025
e988f6f
init
NathanHB Apr 15, 2025
5fddc82
Naive implementation
qubvel Apr 21, 2025
7ce9c97
Fix choices + change metric
qubvel Apr 22, 2025
e08731a
refactor prompt function
qubvel Apr 22, 2025
8d4543b
style
qubvel Apr 22, 2025
05df4b6
FIx typing
qubvel May 6, 2025
16a9e97
Merge branch 'main' into nathan-adds-multimodal
qubvel May 6, 2025
de60add
Update max length
qubvel May 6, 2025
5fd52f5
Remove docs
qubvel May 6, 2025
10b4e0b
Update auto processor
qubvel May 6, 2025
bc7610d
add quantization config, transformers config
qubvel May 6, 2025
49e4986
Update generation size
qubvel May 7, 2025
75c900c
Add batching
qubvel May 7, 2025
4e5fdd3
Style
qubvel May 7, 2025
d1ae8b7
Add images to requests
qubvel May 7, 2025
f855158
nit
qubvel May 7, 2025
641819e
nit
qubvel May 7, 2025
aa0acb7
Clean up a bit
qubvel May 7, 2025
56f962b
nit
qubvel May 7, 2025
8e99388
Fix batch size
qubvel May 7, 2025
418840d
Add images for Doc class
qubvel May 7, 2025
e35db98
clean-up prompt manager
qubvel May 7, 2025
57c18f7
Style
qubvel May 7, 2025
7cd35c2
Style
qubvel May 7, 2025
e13cac9
Clean up prompt manager
qubvel May 7, 2025
fa18ec2
Add dtype
qubvel May 7, 2025
c59e5af
Update prompt function
qubvel May 7, 2025
8f31f1b
Refactor to pass ruff check
qubvel May 7, 2025
3675066
fix the CI
NathanHB May 12, 2025
30e22ab
fix the CI
NathanHB May 12, 2025
924bf13
Fit typing
qubvel May 12, 2025
b909259
Fix system content
qubvel May 12, 2025
665474a
Split to vision and standard tasks
qubvel May 13, 2025
1a73dd0
Data parallel
qubvel May 13, 2025
b618af7
Clean up config docs, tokenizer -> processor
qubvel May 13, 2025
79e222d
Add fast image processor option
qubvel May 13, 2025
bd2c595
Fix style
qubvel May 13, 2025
831f95e
commit
NathanHB May 19, 2025
80568e7
commit
NathanHB May 19, 2025
9fb75a6
commit
NathanHB May 19, 2025
62165a8
commit
NathanHB May 19, 2025
3 changes: 2 additions & 1 deletion src/lighteval/models/model_loader.py
@@ -40,6 +40,7 @@
from lighteval.models.transformers.adapter_model import AdapterModel, AdapterModelConfig
from lighteval.models.transformers.delta_model import DeltaModel, DeltaModelConfig
from lighteval.models.transformers.transformers_model import TransformersModel, TransformersModelConfig
from lighteval.models.transformers.vlm_transformers import VLMTransformersModel
from lighteval.models.vllm.vllm_model import VLLMModel, VLLMModelConfig
from lighteval.utils.imports import (
NO_LITELLM_ERROR_MSG,
@@ -163,7 +164,7 @@ def load_model_with_accelerate_or_default(
model = VLLMModel(config=config)
return model
else:
model = TransformersModel(config=config)
model = VLMTransformersModel(config=config)

return model
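
Note that at this commit the default (non-vLLM) branch above always constructs VLMTransformersModel, so text-only models would take the VLM path too; a later commit in this PR ("Split to vision and standard tasks") separates the two. A sketch of config-based dispatch, using only names that appear in this PR (not the code as merged):

def load_model_with_accelerate_or_default(config):
    # Sketch: route each config type to its backend instead of always
    # building the VLM model.
    if isinstance(config, VLLMModelConfig):
        return VLLMModel(config=config)
    if isinstance(config, VLMTransformersModelConfig):
        return VLMTransformersModel(config=config)
    return TransformersModel(config=config)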

285 changes: 285 additions & 0 deletions src/lighteval/models/transformers/vlm_transformers.py
Author comment on this file: work will mainly happen here; the first step is to get the greedy_until function working.
@@ -0,0 +1,285 @@
# MIT License

# Copyright (c) 2024 The HuggingFace Team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import logging
from typing import Union

import torch
from pydantic import PositiveInt
from transformers import (
AutoModelForVision2Seq,
AutoProcessor,
ProcessorMixin,
)

from lighteval.models.abstract_model import LightevalModel, ModelInfo
from lighteval.models.model_output import (
GenerativeResponse,
LoglikelihoodResponse,
LoglikelihoodSingleTokenResponse,
)
from lighteval.models.utils import ModelConfig, _get_model_sha, _simplify_name
from lighteval.tasks.requests import (
GreedyUntilRequest,
LoglikelihoodRequest,
LoglikelihoodSingleTokenRequest,
)
from lighteval.utils.imports import (
is_accelerate_available,
)


logger = logging.getLogger(__name__)


if is_accelerate_available():
from datetime import timedelta

from accelerate import Accelerator, InitProcessGroupKwargs


class VLMTransformersModelConfig(ModelConfig):
"""
    Configuration class for vision-language (VLM) transformers models.

Attributes:
model_name (str):
HuggingFace Hub model ID name or the path to a pre-trained
model to load. This is effectively the `pretrained_model_name_or_path`
argument of `from_pretrained` in the HuggingFace `transformers` API.
        accelerator (Accelerator): accelerator to use for model evaluation.
tokenizer (Optional[str]): HuggingFace Hub tokenizer ID that will be
used for tokenization.
multichoice_continuations_start_space (Optional[bool]): Whether to add a
space at the start of each continuation in multichoice generation.
For example, context: "What is the capital of France?" and choices: "Paris", "London".
Will be tokenized as: "What is the capital of France? Paris" and "What is the capital of France? London".
True adds a space, False strips a space, None does nothing
        pairwise_tokenization (bool): Whether to tokenize the context and continuation separately or together.
subfolder (Optional[str]): The subfolder within the model repository.
revision (str): The revision of the model.
        batch_size (int): The batch size for model evaluation.
max_gen_toks (Optional[int]): The maximum number of tokens to generate.
max_length (Optional[int]): The maximum length of the generated output.
add_special_tokens (bool, optional, defaults to True): Whether to add special tokens to the input sequences.
If `None`, the default value will be set to `True` for seq2seq models (e.g. T5) and
`False` for causal models.
model_parallel (bool, optional, defaults to None):
True/False: force to use or not the `accelerate` library to load a large
model across multiple devices.
Default: None which corresponds to comparing the number of processes with
the number of GPUs. If it's smaller => model-parallelism, else not.
        dtype (Union[str, torch.dtype], optional, defaults to None):
Converts the model weights to `dtype`, if specified. Strings get
converted to `torch.dtype` objects (e.g. `float16` -> `torch.float16`).
Use `dtype="auto"` to derive the type from the model's weights.
        device (Union[int, str]): device to use for model evaluation.
quantization_config (Optional[BitsAndBytesConfig]): quantization
configuration for the model, manually provided to load a normally floating point
model at a quantized precision. Needed for 4-bit and 8-bit precision.
trust_remote_code (bool): Whether to trust remote code during model
loading.
generation_parameters (GenerationParameters): Range of parameters which will affect the generation.
generation_config (GenerationConfig): GenerationConfig object (only passed during manual creation)

Methods:
get_model_sha(): Retrieves the SHA of the model.

"""

model_name: str
tokenizer: str | None = None
subfolder: str | None = None
revision: str = "main"
batch_size: PositiveInt | None = None
generation_size: PositiveInt = 256
max_length: PositiveInt | None = None
add_special_tokens: bool = True
model_parallel: bool | None = None
dtype: str | None = None
device: Union[int, str] = "cuda"
trust_remote_code: bool = False
use_chat_template: bool = False
compile: bool = False
pairwise_tokenization: bool = False
device_map: str | None = None

def get_model_sha(self):
return _get_model_sha(repo_id=self.model_name, revision=self.revision)
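
# A minimal usage sketch of this config; the values below, including the
# model id, are illustrative assumptions, and only the field names come
# from the class above:
#
#     config = VLMTransformersModelConfig(
#         model_name="HuggingFaceM4/idefics2-8b",
#         dtype="bfloat16",
#         batch_size=8,
#         generation_size=256,
#         use_chat_template=True,
#     )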


class VLMTransformersModel(LightevalModel):
def __init__(
self,
config: VLMTransformersModelConfig,
):
"""Initializes a HuggingFace `AutoModel` and `AutoTokenizer` for evaluation."""
self.config = config
self.accelerator = Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=3000))])
self._device = self.accelerator.device
self.use_chat_template = config.use_chat_template
self.multichoice_continuations_start_space = config.multichoice_continuations_start_space
self._add_special_tokens = config.add_special_tokens or False
self.pairwise_tokenization = config.pairwise_tokenization
self.batch_size = config.batch_size
self.transformers_config = config.get_transformers_config()

self.model_sha = config.get_model_sha()
self._max_length = self._init_max_length()
self._tokenizer = self._create_auto_tokenizer()
self.model = self._create_auto_model()

# We are in DP (and launch the script with `accelerate launch`)
if config.model_parallel is False and self.config.dtype not in ["4bit", "8bit"]:
logger.info(f"Using Data Parallelism, putting model on device {self._device}")
self.model = self.model.to(self._device)
if config.compile:
try:
logger.info("Compiling the model")
self.model.model.compile()
except AttributeError as e:
logger.warning("Could not compile the model because: ", e)

self.model_name = _simplify_name(config.model_name)

self.generation_config_dict = config.generation_parameters.to_transformers_dict()

self.model_info = ModelInfo(
model_name=self.config.model_name,
model_sha=self.model_sha,
model_dtype=config.dtype,
)

@property
def tokenizer(self):
return self._tokenizer

@property
def add_special_tokens(self):
return self._add_special_tokens

@property
def max_length(self) -> int:
return self._max_length

@property
def device(self) -> Union[int, str, torch.device]:
return self._device

@property
def disable_tqdm(self) -> bool:
disable_tqdm = False
if self.accelerator:
disable_tqdm = bool(not self.accelerator.is_main_process)
return disable_tqdm

def _create_auto_model(self) -> AutoModelForVision2Seq:
subfolder = self.config.subfolder
revision = self.config.revision + (f"/{subfolder}" if subfolder is not None else "")

model = AutoModelForVision2Seq.from_pretrained(
self.config.model_name,
revision=revision,
device_map=self.config.device_map,
torch_dtype=self.config.dtype,
trust_remote_code=self.config.trust_remote_code,
)
model.eval()
torch.set_grad_enabled(False)

if self.config.compile:
try:
logger.info("Compiling the model")
model.compile()
except AttributeError as e:
logger.warning("Could not compile the model because: ", e)

return model

def _create_auto_tokenizer(
self,
) -> ProcessorMixin:
"""
        Create a Hugging Face AutoProcessor for the vision-language model.

        Returns:
            transformers.ProcessorMixin: The created processor.
"""
tokenizer_name = self.config.tokenizer or self.config.model_name
subfolder = self.config.subfolder
revision = self.config.revision + (f"/{subfolder}" if subfolder is not None else "")

tokenizer = AutoProcessor.from_pretrained(
tokenizer_name,
revision=revision,
trust_remote_code=self.config.trust_remote_code,
padding_side="left",
truncation_side="left",
)

return tokenizer

def _init_max_length(self) -> int:
"""
Returns:
int: Max length to use depending on the available args and config
"""
raise NotImplementedError()

def greedy_until(
self,
requests: list[GreedyUntilRequest],
) -> list[GenerativeResponse]:
"""
Generates responses using a greedy decoding strategy until certain ending conditions are met.

Args:
            requests (list[GreedyUntilRequest]): list of requests containing the context and ending conditions.

Returns:
list[GenerativeResponse]: list of generated responses.
"""
raise NotImplementedError()

def loglikelihood(
self,
requests: list[LoglikelihoodRequest],
) -> list[LoglikelihoodResponse]:
raise NotImplementedError()

def loglikelihood_single_token(
self, requests: list[LoglikelihoodSingleTokenRequest]
) -> list[LoglikelihoodSingleTokenResponse]:
"""Tokenize the context and continuation and compute the log likelihood of those
tokenized sequences.

        Args:
            requests (list[LoglikelihoodSingleTokenRequest]): list of requests containing the
                context and single-token continuations to score.

        Returns:
            list[LoglikelihoodSingleTokenResponse]: log-likelihood results for each request.
"""
raise NotImplementedError()
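
The generation entry points above (greedy_until, loglikelihood, loglikelihood_single_token) and _init_max_length are still stubs at this commit; per the author's comment, greedy_until is the first target. For orientation, a minimal sketch of greedy generation with an AutoProcessor and AutoModelForVision2Seq; the model id, helper name, and input format are assumptions for illustration, not the lighteval API:

import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

model_id = "HuggingFaceM4/idefics2-8b"  # illustrative VLM checkpoint
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
model = AutoModelForVision2Seq.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)
model.eval()


def greedy_generate(
    prompts: list[str],
    images: list[list[Image.Image]],
    max_new_tokens: int = 256,
) -> list[str]:
    # The processor tokenizes the prompts and prepares pixel values in one call.
    inputs = processor(
        text=prompts, images=images, return_tensors="pt", padding=True
    ).to(model.device)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs, max_new_tokens=max_new_tokens, do_sample=False
        )
    # Decode only the newly generated tokens, not the echoed prompt.
    new_tokens = output_ids[:, inputs["input_ids"].shape[1] :]
    return processor.batch_decode(new_tokens, skip_special_tokens=True)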
57 changes: 57 additions & 0 deletions src/lighteval/tasks/default_prompts.py
@@ -43,6 +43,63 @@
# fmt: on


def mmmu(line, task_name: str = None):
import base64
from io import BytesIO

standard = "Answer with the option letter from the given choices directly."

def replace_images_tokens(input_string):
image_order = [int(num) for num in re.findall(r"<image\s+(\d+)>", input_string)]
input_string = re.sub(r"<image\s+\d+>", "<image>", input_string)
return input_string, image_order

def parse_options(options):
option_letters = [chr(ord("A") + i) for i in range(len(options))]
choices_str = "\n".join(
[f"{option_letter}. {option}" for option_letter, option in zip(option_letters, options)]
)
return choices_str

def construct_prompt(doc):
question = doc["question"]
parsed_options = parse_options(ast.literal_eval(str(doc["options"])))
question = f"{question}\n{parsed_options}\n{standard}"
return question

def origin_mmmu_doc_to_visual(doc, image_order):
visual = []
for idx in image_order:
visual.append(doc[f"image_{idx}"])
return visual

def mmmu_doc_to_text(doc):
question = construct_prompt(doc)
return replace_images_tokens(question)

def encode_pil_image(pil_image):
# Create a byte stream object
buffered = BytesIO()
# Save the PIL image object as a byte stream in PNG format
pil_image.save(buffered, format="PNG")
# Get the byte stream data and perform Base64 encoding
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
return img_str

prompt, image_order = mmmu_doc_to_text(line)
images = origin_mmmu_doc_to_visual(line, image_order)

images = [encode_pil_image(image) for image in images]

return Doc(
task_name=task_name,
query=prompt,
choices=line["options"],
gold_index=string.ascii_uppercase.index(line["answer"]),
specific={"images": images, "id": line["id"]},
)
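
# A worked example of what `mmmu` produces for a toy MMMU-style row; the row
# below is made up, while the field names follow the code above:
#
#     line = {
#         "id": "validation_Art_1",
#         "question": "What style is shown in <image 1>?",
#         "options": ["Baroque", "Cubism", "Impressionism", "Surrealism"],
#         "answer": "C",
#         "image_1": pil_image,  # a PIL.Image.Image
#     }
#     doc = mmmu(line, task_name="mmmu_pro")
#
#     # doc.query:
#     #     What style is shown in <image>?
#     #     A. Baroque
#     #     B. Cubism
#     #     C. Impressionism
#     #     D. Surrealism
#     #     Answer with the option letter from the given choices directly.
#     # doc.gold_index == 2 (option "C"); doc.specific["images"] holds the
#     # images base64-encoded as PNG strings.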


def aime_prompt_fn(line, task_name: str = None):
# Prompt template adapted from
# - simple-evals: https://github.com/openai/simple-evals/blob/6e84f4e2aed6b60f6a0c7b8f06bbbf4bfde72e58/math_eval.py#L17
16 changes: 16 additions & 0 deletions src/lighteval/tasks/default_tasks.py
@@ -24,6 +24,22 @@
from lighteval.tasks.lighteval_task import LightevalTaskConfig


mmmu_pro = LightevalTaskConfig(
name="mmmu_pro",
suite=["lighteval"],
prompt_function=prompt.mmmu,
hf_repo="MMMU/MMMU_pro",
hf_subset="standard (4 options)",
hf_avail_splits=["test"],
evaluation_splits=["test"],
few_shots_split=None,
few_shots_select=None,
generation_size=1,
metric=[Metrics.exact_match],
stop_sequence=None,
trust_dataset=True,
version=0,
)
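
# Registered under the "lighteval" suite, the task is addressed by the spec
# string "lighteval|mmmu_pro|0|0" (suite|task|few-shot count|truncate flag).
# A hedged CLI sketch; the exact flags vary across lighteval versions and
# this command is an assumption, not taken from the PR:
#
#     lighteval accelerate \
#         "model_name=HuggingFaceM4/idefics2-8b" \
#         "lighteval|mmmu_pro|0|0"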
abstract_narrative_understanding_bigbench = LightevalTaskConfig(
name="abstract_narrative_understanding",
suite=["bigbench", "bigbench_json"],