 from typing import Iterator, TypeAlias
 from unittest.mock import patch

+import torch
 import docker
 import docker.errors
 import pytest
 import requests
 from huggingface_hub import ChatCompletionInputMessage
+from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, PreTrainedTokenizerFast

 from lighteval.logging.evaluation_tracker import EvaluationTracker
 from lighteval.metrics.metrics import Metrics
 from lighteval.models.tgi_model import ModelClient as TGIModel
+from lighteval.models.base_model import BaseModel
 from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters
 from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig, create_requests_from_tasks
 from lighteval.tasks.requests import (
+    LoglikelihoodRequest,
+    LoglikelihoodRollingRequest,
     Doc,
     Request,
     RequestType,
 )
-from lighteval.utils.utils import EnvConfig
-
-
-TOKEN = os.environ.get("HF_TOKEN")
-CACHE_PATH = os.getenv("HF_HOME", ".")


 @pytest.fixture(scope="module")
@@ -83,12 +83,19 @@ def tgi_model() -> Iterator[TGIModel]:
         raise RuntimeError("Couldn't setup TGI server.")
     model = TGIModel(address)
     yield model
-    container.stop()
-    container.wait()
-    container.remove()
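+    # Container teardown is skipped here; the TGI container is left running after the test session.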
+    # container.stop()
+    # container.wait()
+    # container.remove()
     model.cleanup()


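+# Reference implementation: a tiny random Llama checkpoint run locally through transformers,
+# used to check the outputs returned by the TGI client.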
+@pytest.fixture(scope="module")
+def reference_model_tokenizer() -> tuple[LlamaForCausalLM, PreTrainedTokenizerFast]:
+    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
+    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
+    return model, tokenizer
+
+
 RequestDict: TypeAlias = dict[RequestType, list[Request]]


@@ -150,28 +157,62 @@ def zero_shot_request_dict(self, task: LightevalTask) -> RequestDict:
                 result[req_type].extend(doc_result[req_type])
         return result

-    def test_greedy_until(self, zero_shot_request_dict: RequestDict, tgi_model: TGIModel):
-        returns = tgi_model.greedy_until(zero_shot_request_dict[RequestType.GREEDY_UNTIL])
+    def test_greedy_until(self, reference_model_tokenizer: tuple[LlamaForCausalLM, PreTrainedTokenizerFast], zero_shot_request_dict: RequestDict, tgi_model: TGIModel):
+        requests = zero_shot_request_dict[RequestType.GREEDY_UNTIL]
+        returns = tgi_model.greedy_until(requests)
+        model, tokenizer = reference_model_tokenizer
         assert len(returns) == 2
-        assert all(r.result is not None for r in returns)
-
-    def test_loglikelihood(self, zero_shot_request_dict: RequestDict, tgi_model: TGIModel):
-        returns = tgi_model.loglikelihood(zero_shot_request_dict[RequestType.LOGLIKELIHOOD])
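+        # Re-generate each continuation with the reference model (same stop sequences and generation size) and compare it with the text returned by TGI.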
+        for req, res in zip(requests, returns):
+            is_chat = not isinstance(req.context, str)
+            tokenized_context = tokenizer.apply_chat_template(req.context, return_tensors="pt") if is_chat else tokenizer(req.context, return_tensors="pt")["input_ids"]
+            ref_context_continuation = model.generate(tokenized_context, tokenizer=tokenizer, stop_strings=req.stop_sequence, max_new_tokens=req.generation_size)[0].tolist()
+            continuation = tokenizer.decode(ref_context_continuation)[len(tokenizer.decode(tokenized_context[0].tolist())):]
+            assert continuation == res.result
+
+    def test_loglikelihood(self, reference_model_tokenizer: tuple[LlamaForCausalLM, PreTrainedTokenizerFast], zero_shot_request_dict: RequestDict, tgi_model: TGIModel):
+        requests: list[LoglikelihoodRequest] = zero_shot_request_dict[RequestType.LOGLIKELIHOOD]
+        returns = tgi_model.loglikelihood(requests)
+        model, tokenizer = reference_model_tokenizer
         assert len(returns) == 4
-        assert all(r.result is not None for r in returns)
-
+        for req, res in zip(requests, returns):
+            is_chat = not isinstance(req.context, str)
+            sequence = req.context + [ChatCompletionInputMessage(role="assistant", content=req.choice)] if is_chat else req.context + req.choice
+            tokenized_sequence = tokenizer.apply_chat_template(sequence, return_tensors="pt") if is_chat else tokenizer(sequence, return_tensors="pt")["input_ids"]
+
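+            # Score the full context+choice sequence with the reference model: log-softmax the lm_head logits
+            # from the prompt forward pass, then gather each token's logprob given its prefix (hence the
+            # one-token shift) and keep only the positions that belong to the choice.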
+            output = model.generate(tokenized_sequence, max_new_tokens=1, return_dict_in_generate=True, output_hidden_states=True)
+            with torch.no_grad():
+                logprobs = torch.log_softmax(model.lm_head(output.hidden_states[0][-1]), dim=-1)
+                logprobs = logprobs.gather(dim=-1, index=tokenized_sequence[:, 1:].unsqueeze(-1))
+                context_length = len(tokenizer.apply_chat_template(req.context)) if is_chat else len(tokenizer.encode(req.context))
+                continuation_logprob = logprobs[:, context_length - 1:].sum()
+
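+            # Everything after the context is the choice: TGI should report these tokens as input_tokens and their summed logprob as the result.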
+            tokenized_choice = tokenized_sequence[:, context_length:]
+            assert tokenized_choice[0].tolist() == res.input_tokens
+            assert torch.allclose(torch.tensor(res.result[0]), continuation_logprob)
+
+    def test_loglikelihood_rolling(self, reference_model_tokenizer: tuple[LlamaForCausalLM, PreTrainedTokenizerFast], zero_shot_request_dict: RequestDict, tgi_model: TGIModel):
+        model, tokenizer = reference_model_tokenizer
+        requests: list[LoglikelihoodRollingRequest] = zero_shot_request_dict[RequestType.LOGLIKELIHOOD_ROLLING]
         returns = tgi_model.loglikelihood_rolling(zero_shot_request_dict[RequestType.LOGLIKELIHOOD_ROLLING])
         assert len(returns) == 2
-        assert all(r.result is not None for r in returns)
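+        # The rolling loglikelihood sums the logprob of every context token given its prefix; recompute it with the reference model and compare.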
+        for req, res in zip(requests, returns):
+            is_chat = not isinstance(req.context, str)
+            tokenized_context = tokenizer.apply_chat_template(req.context, return_tensors="pt") if is_chat else tokenizer(req.context, return_tensors="pt")["input_ids"]
+            output = model.generate(tokenized_context, max_new_tokens=1, return_dict_in_generate=True, output_hidden_states=True)
+            with torch.no_grad():
+                logprobs = torch.log_softmax(model.lm_head(output.hidden_states[0][-1]), dim=-1)
+                logprob = logprobs.gather(dim=-1, index=tokenized_context[:, 1:].unsqueeze(-1)).sum()
+
+            assert tokenized_context[0, 1:].tolist() == res.input_tokens
+            assert torch.allclose(torch.tensor(res.result), logprob)

     @pytest.mark.parametrize("num_fewshot", [0, 2])
     @pytest.mark.parametrize("use_chat_template", [False, True])
-    def test_integration(self, task: LightevalTask, tgi_model: TGIModel, num_fewshot: int, use_chat_template: bool):
-        env_config = EnvConfig(token=TOKEN, cache_dir=CACHE_PATH)
+    def test_integration(self, task: LightevalTask, base_model: BaseModel, tgi_model: TGIModel, num_fewshot: int, use_chat_template: bool):
+        # TODO
         evaluation_tracker = EvaluationTracker()
         pipeline_params = PipelineParameters(
             launcher_type=ParallelismManager.NONE,
-            env_config=env_config,
             use_chat_template=use_chat_template,
         )
