diff --git a/src/lighteval/models/litellm_model.py b/src/lighteval/models/litellm_model.py
index 840061788..6c37c0e68 100644
--- a/src/lighteval/models/litellm_model.py
+++ b/src/lighteval/models/litellm_model.py
@@ -22,6 +22,7 @@
 
 import logging
 import os
+import re
 import time
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
@@ -93,12 +94,17 @@ def __init__(self, config, env_config) -> None:
         litellm.drop_params = True
         litellm.set_verbose = False
 
+    def is_reasoning_model(self):
+        return "o1" in self.model or "o3" in self.model or "R1" in self.model
+
     def _prepare_stop_sequence(self, stop_sequence):
         """Prepare and validate stop sequence."""
         if self.provider == "anthropic":
             # Filter out whitespace-only stop sequences
             if stop_sequence:
                 stop_sequence = [s for s in stop_sequence if s and s.strip()]
+            if not stop_sequence:  # If empty after filtering
+                stop_sequence = ["\n"]
         return stop_sequence
 
     def _prepare_max_new_tokens(self, max_new_tokens):
@@ -106,7 +112,7 @@ def _prepare_max_new_tokens(self, max_new_tokens):
         if not max_new_tokens or max_new_tokens <= 0:
             return None
 
-        if "o1" in self.model:
+        if self.is_reasoning_model():
             # We need to allow more tokens to include reasoning tokens
             max_new_tokens = min(max_new_tokens * 10, 32000)
         return max_new_tokens
@@ -132,8 +138,8 @@ def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_se
                 "n": num_samples,
                 "caching": True,
             }
-            if "o1" in self.model:
-                logger.warning("O1 models do not support temperature, top_p, stop sequence. Disabling.")
+            if self.is_reasoning_model():
+                logger.warning("Reasoning models do not support temperature, top_p, stop sequence. Disabling.")
             else:
                 kwargs["temperature"] = self.TEMPERATURE
                 kwargs["top_p"] = self.TOP_P
@@ -142,10 +148,17 @@ def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_se
                 response = litellm.completion(**kwargs)
 
                 # If response is empty, retry without caching (maybe the error is recoverable and solved with a retry)
-                if response.choices[0].message.content is None:
+                content = response.choices[0].message.content
+                if not content:
                     kwargs["caching"] = False
                     logger.info("Response is empty, retrying without caching")
                     response = litellm.completion(**kwargs)
+
+                if content is not None and "</think>" in content:
+                    logger.debug(f"Removing <think> tags from response: {content}")
+                    response.choices[0].message.content = re.sub(
+                        r"<think>.*?</think>", "", content, flags=re.DOTALL
+                    ).strip()
                 return response
             except litellm.BadRequestError as e:
                 if "message" in e.__dict__:
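
For reference, a minimal sketch (not part of the patch) of what the added <think>-stripping in __call_api amounts to; the sample completion string below is made up:

import re

# Hypothetical completion from an R1-style reasoning model: the chain of thought
# is wrapped in <think>...</think>, followed by the user-facing answer.
content = "<think>Let me reason about this privately...</think>The answer is 42."

# Same substitution as in the patch: drop the reasoning block, keep the answer.
cleaned = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()
print(cleaned)  # -> "The answer is 42."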