-
Notifications
You must be signed in to change notification settings - Fork 269
Added custom model inference. #437
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking very nice, exactly what I had in mind!
No big comments on the main PR, but
- you could add your model class in examples.
- you need to update the doc pages to explain how this works
- it would be good to add a small test to our suite for this feature
I'll try to run it this afternoon and if all goes well and you update the doc, we'll be good to go!
Hahaha please also provide an explicit requirements files :) |
The explicit requirement file is only needed for the google translate example, right? Where should I add that? |
google_translate_model_requirements.txt for now, next to the py file |
Great, fixed the things. @clefourrier ready for review again. |
Hi @JoelNiklaus ! Great PR, howveer, just tried it and it does not seem to work. When running:
deps:
|
Hmm, would you mind trying an environment with the requirements in examples/custom_models/google-translate-requirements-freeze.txt? |
@NathanHB I added another custom model example at examples/custom_models/local_mt_model.py |
@clefourrier @NathanHB Would you mind reviewing again? It should work better now. |
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
Enables the evaluation of any system in the user's control. Fixes [Issue 430](#430). Try with ``` python -m lighteval custom google-translate /path/to/google_translate_model.py "lighteval|wmt20:fr-de|0|0" --max-samples 10 ``` google_translate_model.py ``` import logging from typing import Optional from tqdm import tqdm from transformers import AutoTokenizer from lighteval.data import GenerativeTaskDataset from lighteval.models.abstract_model import LightevalModel, ModelInfo from lighteval.models.model_output import ( GenerativeResponse, LoglikelihoodResponse, LoglikelihoodSingleTokenResponse, ) from lighteval.tasks.requests import ( GreedyUntilRequest, LoglikelihoodRequest, LoglikelihoodRollingRequest, LoglikelihoodSingleTokenRequest, ) logger = logging.getLogger(__name__) class GoogleTranslateClient(LightevalModel): def __init__(self, config, env_config) -> None: self.model = config.model self.model_definition_file_path = config.model_definition_file_path self.model_info = ModelInfo( model_name=config.model, model_sha="", model_dtype=None, model_size="", ) self._tokenizer = AutoTokenizer.from_pretrained("gpt2") # Use a dummy tokenizer for compatibility import httpcore # Needed to fix some googletrans bug # https://stackoverflow.com/questions/72796594/attributeerror-module-httpcore-has-no-attribute-synchttptransport#comment136664963_77334618 setattr(httpcore, 'SyncHTTPTransport', 'AsyncHTTPProxy') from googletrans import Translator self.translator = Translator() def greedy_until( self, requests: list[GreedyUntilRequest], override_bs: Optional[int] = None, ) -> list[GenerativeResponse]: """ Generates responses using a greedy decoding strategy until certain ending conditions are met. Args: requests (list[Request]): list of requests containing the context and ending conditions. disable_tqdm (bool, optional): Whether to disable the progress bar. Defaults to False. override_bs (int, optional): Override the batch size for generation. Defaults to None. Returns: list[GenerativeResponse]: list of generated responses. """ for request in requests: request.tokenized_context = self.tok_encode(request.context) dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS) results = [] for _ in tqdm( dataset.splits_start_end_iterator(), total=dataset.num_dataset_splits, desc="Splits", position=0, disable=False, # self.disable_tqdm, ): for r in tqdm(dataset, desc="Batch", position=1, disable=False): context = r.context.replace("French phrase: ", "") # TODO: Get src and dest from request translation = self.translator.translate(context, src='fr', dest='de') result = translation.text cur_response = GenerativeResponse( result=result, logits=None, generated_tokens=[], input_tokens=[], ) results.append(cur_response) return dataset.get_original_order(results) @Property def tokenizer(self): return self._tokenizer def tok_encode(self, text: str): return self.tokenizer.encode(text) @Property def add_special_tokens(self) -> bool: return False @Property def max_length(self) -> int: """Return the maximum sequence length of the model.""" return 4096 def loglikelihood( self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None ) -> list[LoglikelihoodResponse]: """Tokenize the context and continuation and compute the log likelihood of those tokenized sequences. """ raise NotImplementedError def loglikelihood_rolling( self, requests: list[LoglikelihoodRollingRequest], override_bs: Optional[int] = None ) -> list[LoglikelihoodResponse]: """This function is used to compute the log likelihood of the context for perplexity metrics.""" raise NotImplementedError def loglikelihood_single_token( self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: Optional[int] = None ) -> list[LoglikelihoodSingleTokenResponse]: """Tokenize the context and continuation and compute the log likelihood of those tokenized sequences. """ raise NotImplementedError ```
Enables the evaluation of any system in the user's control. Fixes Issue 430.
Try with
google_translate_model.py