
Commit 8c0018e

Improve endpoint tests and bug fix in endpoint model
1 parent b291871 commit 8c0018e

File tree

6 files changed: +181 -36 lines changed

src/lighteval/models/endpoint_model.py

Lines changed: 4 additions & 3 deletions

@@ -205,11 +205,10 @@ def _process_generate_response(self, response: EndpointOutput, request: GreedyUn
     def _process_logprob_response(
         self, response: TextGenerationOutput, request: LoglikelihoodRequest | LoglikelihoodRollingRequest
     ) -> LoglikelihoodResponse:
-        len_choice = len(request.tokenized_continuation)
-        logits = sum([t.logprob for t in response.details.prefill[1:][-len_choice:]])
+        logits = sum([t.logprob for t in response.details.prefill[len(request.tokenized_context):]])
         return LoglikelihoodResponse(
             result=(logits, True) if isinstance(request, LoglikelihoodRequest) else logits,
-            input_tokens=[t.id for t in response.details.prefill[:-len_choice]],
+            input_tokens=[t.id for t in response.details.prefill[len(request.tokenized_context):]],
             generated_tokens=-1,
             truncated_tokens_count=-1,
             padded_tokens_count=-1,
@@ -255,6 +254,7 @@ def _prepare_request(self, request: Request) -> EndpointInput:
             context = request.context + [ChatCompletionInputMessage(role="assistant", content=request.choice)]
             if not isinstance(context, str):
                 context = self.tokenizer.apply_chat_template(context, tokenize=False)
+                context = context.split(self.tokenizer.bos_token, 1)[-1]
 
         if isinstance(context, str):
             prepared_request = TextGenerationInput(
@@ -290,6 +290,7 @@ def greedy_until(
         override_bs: Optional[int] = None,
     ) -> List[GenerativeResponse]:
         for request in requests:
+            # Why don't we set context to empty list here?
             request.tokenized_context = self.tok_encode(request.context)
             request.stop_sequence = as_list(request.stop_sequence) + [self.tokenizer.eos_token]
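Note: the logprob fix above slices the prefill by the tokenized context length instead of by the locally tokenized continuation length, so the continuation score no longer depends on those two tokenizations lining up. A minimal sketch of the new slicing, using a hypothetical PrefillToken stand-in for TGI's prefill entries (not lighteval code):

from dataclasses import dataclass

@dataclass
class PrefillToken:
    # Hypothetical stand-in for a TGI prefill entry (token id + logprob).
    id: int
    logprob: float

def continuation_logprob(prefill: list[PrefillToken], context_len: int) -> float:
    # Everything after the tokenized context belongs to the continuation;
    # summing its logprobs gives the loglikelihood of the choice.
    return sum(t.logprob for t in prefill[context_len:])

# Example: three context tokens followed by a two-token continuation.
prefill = [PrefillToken(1, 0.0), PrefillToken(42, -0.5), PrefillToken(7, -1.0),
           PrefillToken(99, -0.25), PrefillToken(13, -0.25)]
print(continuation_logprob(prefill, context_len=3))  # -0.5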

src/lighteval/models/tgi_model.py

Lines changed: 1 addition & 1 deletion

@@ -53,7 +53,7 @@ def __init__(self, address, auth_token=None, model_id=None) -> None:
         self.model_info = ModelInfo(
             model_name=model_id or self.name,
             model_sha=info["model_sha"],
-            model_dtype=info["model_dtype"] or "default",
+            model_dtype=info["model_dtype"] if "model_dtype" in info else "default",
             model_size=-1,
         )
         self._tokenizer = AutoTokenizer.from_pretrained(self.model_info.model_name)
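Note: the old expression only handled a falsy model_dtype; it still raised KeyError when TGI's /info payload omitted the key entirely, which the membership test now covers. An equivalent, more compact spelling (just a sketch, not what the commit uses) is dict.get with a default:

model_dtype = info.get("model_dtype", "default")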

tests/conftest.py

Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
+from typing import Iterator
+import pytest
+
+from lighteval.models.model_config import BaseModelConfig
+from lighteval.models.abstract_model import EnvConfig
+from lighteval.models.base_model import BaseModel
+
+
+@pytest.fixture(scope="module")
+def base_model() -> Iterator[BaseModel]:
+    config = BaseModelConfig("hf-internal-testing/tiny-random-LlamaForCausalLM")
+    return BaseModel(config, EnvConfig())
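Note: with the fixture living in tests/conftest.py, pytest injects base_model into any test module in the directory without an import, and scope="module" keeps one instance per module. A hypothetical usage sketch (not part of the commit):

# tests/test_something.py (hypothetical)
def test_model_loads(base_model):
    # pytest resolves `base_model` from conftest.py automatically.
    assert base_model is not None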

tests/test_base_model.py

Lines changed: 0 additions & 12 deletions

@@ -41,16 +41,6 @@
 )
 
 
-TOKEN = os.environ.get("HF_TOKEN")
-CACHE_PATH = os.getenv("HF_HOME", ".")
-
-
-@pytest.fixture(scope="module")
-def base_model() -> Iterator[BaseModel]:
-    config = BaseModelConfig("hf-internal-testing/tiny-random-LlamaForCausalLM")
-    return BaseModel(config, EnvConfig(CACHE_PATH, TOKEN))
-
-
 RequestDict: TypeAlias = dict[RequestType, list[Request]]
 
 
@@ -122,11 +112,9 @@ def task(self) -> LightevalTask:
     def test_integration(self, task: LightevalTask, base_model: BaseModel, num_fewshot: int, use_chat_template: bool):
         base_model.use_chat_template = use_chat_template
 
-        env_config = EnvConfig(token=TOKEN, cache_dir=CACHE_PATH)
         evaluation_tracker = EvaluationTracker()
         pipeline_params = PipelineParameters(
             launcher_type=ParallelismManager.NONE,
-            env_config=env_config,
             use_chat_template=use_chat_template,
             override_batch_size=1,
         )

tests/test_endpoint_model.py

Lines changed: 61 additions & 20 deletions

@@ -27,27 +27,27 @@
 from typing import Iterator, TypeAlias
 from unittest.mock import patch
 
+import torch
 import docker
 import docker.errors
 import pytest
 import requests
 from huggingface_hub import ChatCompletionInputMessage
+from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, PreTrainedTokenizerFast
 
 from lighteval.logging.evaluation_tracker import EvaluationTracker
 from lighteval.metrics.metrics import Metrics
 from lighteval.models.tgi_model import ModelClient as TGIModel
+from lighteval.models.base_model import BaseModel
 from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters
 from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig, create_requests_from_tasks
 from lighteval.tasks.requests import (
+    LoglikelihoodRequest,
+    LoglikelihoodRollingRequest,
     Doc,
     Request,
     RequestType,
 )
-from lighteval.utils.utils import EnvConfig
-
-
-TOKEN = os.environ.get("HF_TOKEN")
-CACHE_PATH = os.getenv("HF_HOME", ".")
 
 
 @pytest.fixture(scope="module")
@@ -83,12 +83,19 @@ def tgi_model() -> Iterator[TGIModel]:
         raise RuntimeError("Couldn't setup TGI server.")
     model = TGIModel(address)
     yield model
-    container.stop()
-    container.wait()
-    container.remove()
+    # container.stop()
+    # container.wait()
+    # container.remove()
     model.cleanup()
 
 
+@pytest.fixture(scope="module")
+def reference_model_tokenizer() -> tuple[LlamaForCausalLM, PreTrainedTokenizerFast]:
+    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
+    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
+    return model, tokenizer
+
+
 RequestDict: TypeAlias = dict[RequestType, list[Request]]
 
 
@@ -150,28 +157,62 @@ def zero_shot_request_dict(self, task: LightevalTask) -> RequestDict:
             result[req_type].extend(doc_result[req_type])
         return result
 
-    def test_greedy_until(self, zero_shot_request_dict: RequestDict, tgi_model: TGIModel):
-        returns = tgi_model.greedy_until(zero_shot_request_dict[RequestType.GREEDY_UNTIL])
+    def test_greedy_until(self, reference_model_tokenizer: tuple[LlamaForCausalLM, PreTrainedTokenizerFast], zero_shot_request_dict: RequestDict, tgi_model: TGIModel):
+        requests = zero_shot_request_dict[RequestType.GREEDY_UNTIL]
+        returns = tgi_model.greedy_until(requests)
+        model, tokenizer = reference_model_tokenizer
         assert len(returns) == 2
-        assert all(r.result is not None for r in returns)
-
-    def test_loglikelihood(self, zero_shot_request_dict: RequestDict, tgi_model: TGIModel):
-        returns = tgi_model.loglikelihood(zero_shot_request_dict[RequestType.LOGLIKELIHOOD])
+        for req, res in zip(requests, returns):
+            is_chat = not isinstance(req.context, str)
+            tokenized_context = tokenizer.apply_chat_template(req.context, return_tensors='pt') if is_chat else tokenizer(req.context, return_tensors='pt')['input_ids']
+            ref_context_continuation = model.generate(tokenized_context, tokenizer=tokenizer, stop_strings=req.stop_sequence, max_new_tokens=req.generation_size)[0].tolist()
+            continuation = tokenizer.decode(ref_context_continuation)[len(tokenizer.decode(tokenized_context[0].tolist())):]
+            assert continuation == res.result
+
+    def test_loglikelihood(self, reference_model_tokenizer: tuple[LlamaForCausalLM, PreTrainedTokenizerFast], zero_shot_request_dict: RequestDict, tgi_model: TGIModel):
+        requests: list[LoglikelihoodRequest] = zero_shot_request_dict[RequestType.LOGLIKELIHOOD]
+        returns = tgi_model.loglikelihood(requests)
+        model, tokenizer = reference_model_tokenizer
         assert len(returns) == 4
-        assert all(r.result is not None for r in returns)
-
+        for req, res in zip(requests, returns):
+            is_chat = not isinstance(req.context, str)
+            sequence = req.context + [ChatCompletionInputMessage(role='assistant', content=req.choice)] if is_chat else req.context + req.choice
+            tokenized_sequence = tokenizer.apply_chat_template(sequence, return_tensors='pt') if is_chat else tokenizer(sequence, return_tensors='pt')['input_ids']
+
+            output = model.generate(tokenized_sequence, max_new_tokens=1, return_dict_in_generate=True, output_hidden_states=True)
+            with torch.no_grad():
+                logprobs = torch.log_softmax(model.lm_head(output.hidden_states[0][-1]), dim=-1)
+                logprobs = logprobs.gather(dim=-1, index=tokenized_sequence[:, 1:].unsqueeze(-1))
+            context_length = len(tokenizer.apply_chat_template(req.context)) if is_chat else len(tokenizer.encode(req.context))
+            continuation_logprob = logprobs[:, context_length - 1:].sum()
+
+            tokenized_choice = tokenized_sequence[:, context_length:]
+            assert tokenized_choice[0].tolist() == res.input_tokens
+            assert torch.allclose(torch.tensor(res.result[0]), continuation_logprob)
+
+    def test_loglikelihood_rolling(self, reference_model_tokenizer: tuple[LlamaForCausalLM, PreTrainedTokenizerFast], zero_shot_request_dict: RequestDict, tgi_model: TGIModel):
+        model, tokenizer = reference_model_tokenizer
+        requests: list[LoglikelihoodRollingRequest] = zero_shot_request_dict[RequestType.LOGLIKELIHOOD_ROLLING]
         returns = tgi_model.loglikelihood_rolling(zero_shot_request_dict[RequestType.LOGLIKELIHOOD_ROLLING])
         assert len(returns) == 2
-        assert all(r.result is not None for r in returns)
+        for req, res in zip(requests, returns):
+            is_chat = not isinstance(req.context, str)
+            tokenized_context = tokenizer.apply_chat_template(req.context, return_tensors='pt') if is_chat else tokenizer(req.context, return_tensors='pt')['input_ids']
+            output = model.generate(tokenized_context, max_new_tokens=1, return_dict_in_generate=True, output_hidden_states=True)
+            with torch.no_grad():
+                logprobs = torch.log_softmax(model.lm_head(output.hidden_states[0][-1]), dim=-1)
+            logprob = logprobs.gather(dim=-1, index=tokenized_context[:, 1:].unsqueeze(-1)).sum()
+
+            assert tokenized_context[0, 1:].tolist() == res.input_tokens
+            assert torch.allclose(torch.tensor(res.result), logprob)
 
     @pytest.mark.parametrize("num_fewshot", [0, 2])
     @pytest.mark.parametrize("use_chat_template", [False, True])
-    def test_integration(self, task: LightevalTask, tgi_model: TGIModel, num_fewshot: int, use_chat_template: bool):
-        env_config = EnvConfig(token=TOKEN, cache_dir=CACHE_PATH)
+    def test_integration(self, task: LightevalTask, base_model: BaseModel, tgi_model: TGIModel, num_fewshot: int, use_chat_template: bool):
+        # TODO
         evaluation_tracker = EvaluationTracker()
         pipeline_params = PipelineParameters(
             launcher_type=ParallelismManager.NONE,
-            env_config=env_config,
             use_chat_template=use_chat_template,
         )
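Note: the reference scores in these tests are recomputed by pushing prefill hidden states through lm_head. A plain forward pass yields the same per-token logprobs and may be easier to follow when reproducing the check; a minimal sketch (not part of the commit), assuming the same tiny Llama checkpoint:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")

def reference_logprobs(text: str) -> torch.Tensor:
    # Log-probability of each token given the tokens before it (teacher forcing).
    input_ids = tokenizer(text, return_tensors="pt")["input_ids"]
    with torch.no_grad():
        logits = model(input_ids).logits  # shape: (1, seq_len, vocab_size)
    logprobs = torch.log_softmax(logits[:, :-1], dim=-1)
    # Position i predicts token i + 1, so gather the shifted targets.
    return logprobs.gather(dim=-1, index=input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)

scores = reference_logprobs("Tell me:\n\nHow are you?Fine, thanks!")
print(scores.sum())  # total loglikelihood of everything after the first token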

tests/test_test.py

Lines changed: 103 additions & 0 deletions

@@ -0,0 +1,103 @@
+import time
+import random
+import asyncio
+from typing import Iterator
+
+import pytest
+import docker
+import requests
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM
+from huggingface_hub import (
+    InferenceClient,
+    AsyncInferenceClient,
+    TextGenerationOutput,
+)
+
+
+@pytest.fixture(params=["sync", "async"])
+def tgi_client(request) -> Iterator[InferenceClient | AsyncInferenceClient]:
+    client = docker.from_env()
+
+    try:
+        container = client.containers.get("lighteval-tgi-model-test")
+        port = container.ports["80/tcp"][0]["HostPort"]
+    except docker.errors.NotFound:
+        port = random.randint(8000, 9000)
+        container = client.containers.run(
+            "ghcr.io/huggingface/text-generation-inference:2.2.0",
+            command=[
+                "--model-id",
+                "hf-internal-testing/tiny-random-LlamaForCausalLM",
+                "--dtype",
+                "float16",
+            ],
+            detach=True,
+            name="lighteval-tgi-model-test",
+            auto_remove=False,
+            ports={"80/tcp": port},
+        )
+    address = f"http://localhost:{port}"
+    for _ in range(40):
+        try:
+            if requests.get(f"{address}/health"):
+                break
+        except Exception:
+            time.sleep(1)
+    else:
+        raise RuntimeError("Couldn't setup TGI server.")
+
+    if request.param == "async":
+        yield AsyncInferenceClient(base_url=address)
+    elif request.param == "sync":
+        yield InferenceClient(base_url=address)
+    else:
+        raise RuntimeError()
+
+
+def test_logprobs(tgi_client: InferenceClient | AsyncInferenceClient):
+    model: LlamaForCausalLM = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
+    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
+
+    # It raises an error in the async setting unless the size of `prompts` is < 3
+    prompts = [
+        "Tell me:\n\nHow are you?Fine, thanks!",
+        "Tell me:\n\nHow are you?Not bad!",
+        "Tell me:\n\nComment vas-tu?Comme ci, comme ça",
+        "Tell me:\n\nComment vas-tu?Ca va! Merci!",
+    ]
+    responses = []
+    for prompt in prompts:
+        responses.append(tgi_client.text_generation(
+            prompt,
+            details=True,
+            decoder_input_details=True,
+            max_new_tokens=1,
+            stop_sequences=None,
+            do_sample=False,
+            return_full_text=False,
+            seed=42,
+        ))
+    if isinstance(tgi_client, AsyncInferenceClient):
+        loop = asyncio.get_event_loop()
+        responses: list[TextGenerationOutput] = loop.run_until_complete(asyncio.gather(*responses))
+
+    error = False
+    for prompt, response in zip(prompts, responses):
+
+        tgi_logprobs = torch.tensor([t.logprob for t in response.details.prefill[1:]])  # Skipping <s> whose logprob is None
+
+        tokenized_sequence = tokenizer(prompt, return_tensors='pt')['input_ids']
+        output = model.generate(tokenized_sequence, max_new_tokens=1, return_dict_in_generate=True, output_hidden_states=True)
+        with torch.no_grad():
+            logprobs = torch.log_softmax(model.lm_head(output.hidden_states[0][-1]), dim=-1)
+        logprobs = logprobs.gather(dim=-1, index=tokenized_sequence[:, 1:].unsqueeze(-1)).squeeze()
+
+        if not torch.allclose(logprobs.sum(), tgi_logprobs.sum()):
+            print(f"====== prompt: {repr(prompt)} ======")
+            print("TGI logprobs:", tgi_logprobs.tolist())
+            print("TGI tokens:", [t.id for t in response.details.prefill])
+            print("Ref. logprobs:", logprobs.tolist())
+            print("Ref. tokens:", tokenized_sequence[0].tolist())
+            error = True
+    assert not error
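Note: asyncio.get_event_loop() outside a running loop is deprecated on recent Python versions. If that becomes a problem, the pending AsyncInferenceClient calls could be awaited with asyncio.run instead; a hypothetical sketch, not part of the test:

import asyncio

async def gather_responses(pending):
    # Await every pending text_generation coroutine concurrently.
    return await asyncio.gather(*pending)

# In the async branch above, this could replace the get_event_loop()/run_until_complete pair:
# responses = asyncio.run(gather_responses(responses))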
