Use logprobs in TLMChatCompletion #89

Open · wants to merge 1 commit into base: main
23 changes: 22 additions & 1 deletion src/cleanlab_tlm/utils/chat_completions.py
@@ -7,6 +7,8 @@
 
 from typing import TYPE_CHECKING, Any, Optional, cast
 
+import numpy as np
+
 from cleanlab_tlm.internal.base import BaseTLM
 from cleanlab_tlm.internal.constants import (
     _DEFAULT_TLM_QUALITY_PRESET,
@@ -89,7 +91,16 @@ def score(
         prompt_text = form_prompt_string(messages, tools)
         response_text = _get_string_response(response)
 
-        return cast(TLMScore, self._tlm.get_trustworthiness_score(prompt_text, response_text))
+        scoring_kwargs = {}
+        # add perplexity to tlm.get_trustworthiness_score kwargs if it exists
+        perplexity = _extract_perplexity(response)
+        if perplexity is not None:
+            scoring_kwargs["perplexity"] = perplexity
+
+        return cast(
+            TLMScore,
+            self._tlm.get_trustworthiness_score(prompt_text, response_text, **scoring_kwargs),
+        )
 
 
 def _get_string_response(response: "ChatCompletion") -> str:
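
(Not part of the diff.) A minimal sketch of how this change is exercised end to end, assuming a configured OpenAI API key and a Cleanlab TLM API key; the model name and prompt are illustrative:

```python
# Illustrative only: request token logprobs from OpenAI, then score the response.
# With this change, score() derives perplexity from the returned logprobs (when present)
# and forwards it to TLM.get_trustworthiness_score via the `perplexity` keyword argument.
from openai import OpenAI

from cleanlab_tlm.utils.chat_completions import TLMChatCompletion

client = OpenAI()
openai_kwargs = {
    "model": "gpt-4.1-mini",
    "messages": [{"role": "user", "content": "What is the capital of France?"}],
}

# logprobs=True is required for the ChatCompletion to carry per-token logprobs at all.
response = client.chat.completions.create(logprobs=True, **openai_kwargs)

tlm_chat = TLMChatCompletion(options={"log": ["perplexity"]})
score = tlm_chat.score(response=response, **openai_kwargs)
print(score["trustworthiness_score"], score["log"].get("perplexity"))
```
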
@@ -105,3 +116,13 @@ def _get_string_response(response: "ChatCompletion") -> str:
     if response.choices[0].message.content is None:
         raise ValueError("The OpenAI ChatCompletion object does not contain a message content.")
     return str(response.choices[0].message.content)
+
+
+def _extract_perplexity(response: "ChatCompletion") -> Optional[float]:
+    response_logprobs = response.choices[0].logprobs
[Inline review comment from a maintainer on the line above; a sketch of the suggested check follows this diff.]
this won't ever cause an error for any OpenAI LLM?
I'd imagine it's safer to check if the key `logprobs` exists
+    if response_logprobs is None or response_logprobs.content is None:
+        return None
+
+    logprobs_list = [completion.logprob for completion in response_logprobs.content]
+    perplexity = np.mean(np.exp(logprobs_list))
+    return float(perplexity)
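
(Not part of the commit.) A minimal sketch of the safer lookup the reviewer suggests above, using `getattr` so a choice object that lacks a `logprobs` attribute altogether still falls back to `None`; the `_extract_perplexity_safe` name is hypothetical:

```python
from typing import Optional

import numpy as np
from openai.types.chat import ChatCompletion


def _extract_perplexity_safe(response: ChatCompletion) -> Optional[float]:
    # getattr returns None if the attribute is missing entirely, e.g. for
    # non-OpenAI providers whose Chat Completions responses omit logprobs.
    response_logprobs = getattr(response.choices[0], "logprobs", None)
    if response_logprobs is None or getattr(response_logprobs, "content", None) is None:
        return None

    logprobs_list = [token.logprob for token in response_logprobs.content]
    return float(np.mean(np.exp(logprobs_list)))
```
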
82 changes: 80 additions & 2 deletions tests/test_chat_completions.py
@@ -3,10 +3,11 @@
     ChatCompletion,
     ChatCompletionMessage,
 )
-from openai.types.chat.chat_completion import Choice
+from openai.types.chat.chat_completion import Choice, ChoiceLogprobs
+from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob
 
 from cleanlab_tlm.internal.types import TLMQualityPreset
-from cleanlab_tlm.utils.chat_completions import TLMChatCompletion
+from cleanlab_tlm.utils.chat_completions import TLMChatCompletion, _extract_perplexity
 from tests.conftest import make_text_unique
 from tests.constants import TEST_PROMPT, TEST_RESPONSE
 from tests.test_get_trustworthiness_score import is_trustworthiness_score_json_format
@@ -69,6 +70,8 @@ def test_tlm_chat_completion_score_with_options() -> None:
 
     assert score is not None
     assert is_trustworthiness_score_json_format(score)
+    assert score["log"]["explanation"] is not None
+    assert score["log"]["perplexity"] is None
 
 
 def test_tlm_chat_completion_score_with_tools() -> None:
@@ -116,6 +119,81 @@ def test_tlm_chat_completion_score_with_tools() -> None:
     assert is_trustworthiness_score_json_format(score)
 
 
+def test_tlm_chat_completion_score_with_perplexity() -> None:
+    tlm_chat = TLMChatCompletion(options={"log": ["perplexity"]})
+    openai_kwargs = {
+        "model": "gpt-4.1-mini",
+        "messages": [{"role": "user", "content": test_prompt}],
+    }
+    response = ChatCompletion(
+        id="test",
+        choices=[
+            Choice(
+                index=0,
+                message=ChatCompletionMessage(role="assistant", content=test_response),
+                finish_reason="stop",
+                logprobs=ChoiceLogprobs(
+                    content=[
+                        ChatCompletionTokenLogprob(
+                            token="The",  # noqa: S106
+                            bytes=[84, 104, 101],
+                            logprob=0.0,
+                            top_logprobs=[],
+                        ),
+                        ChatCompletionTokenLogprob(
+                            token=" capital",  # noqa: S106
+                            bytes=[32, 99, 97, 112, 105, 116, 97, 108],
+                            logprob=0.0,
+                            top_logprobs=[],
+                        ),
+                        ChatCompletionTokenLogprob(
+                            token=" of",  # noqa: S106
+                            bytes=[32, 111, 102],
+                            logprob=0.0,
+                            top_logprobs=[],
+                        ),
+                        ChatCompletionTokenLogprob(
+                            token=" France",  # noqa: S106
+                            bytes=[32, 70, 114, 97, 110, 99, 101],
+                            logprob=0.0,
+                            top_logprobs=[],
+                        ),
+                        ChatCompletionTokenLogprob(
+                            token=" is",  # noqa: S106
+                            bytes=[32, 105, 115],
+                            logprob=0.0,
+                            top_logprobs=[],
+                        ),
+                        ChatCompletionTokenLogprob(
+                            token=" Paris",  # noqa: S106
+                            bytes=[32, 80, 97, 114, 105, 115],
+                            logprob=0.0,
+                            top_logprobs=[],
+                        ),
+                        ChatCompletionTokenLogprob(
+                            token=".",  # noqa: S106
+                            bytes=[46],
+                            logprob=-1.9361264946837764e-07,
+                            top_logprobs=[],
+                        ),
+                    ],
+                    refusal=None,
+                ),
+            )
+        ],
+        created=1234567890,
+        model="test-model",
+        object="chat.completion",
+    )
+
+    manually_calculated_perplexity = _extract_perplexity(response)
+
+    score = tlm_chat.score(response=response, **openai_kwargs)
+    returned_perplexity = score["log"]["perplexity"]
+
+    assert manually_calculated_perplexity == returned_perplexity
+
+
 def test_tlm_chat_completion_score_invalid_response() -> None:
     tlm_chat = TLMChatCompletion()
     openai_kwargs = {
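
(Not part of the test file.) As a quick sanity check on the fixture above: six tokens with logprob 0.0 and one with logprob of roughly -1.94e-07 give a mean token probability of essentially 1.0, so `manually_calculated_perplexity` and the logged value should both be just under 1:

```python
import numpy as np

# Same arithmetic as _extract_perplexity, applied to the fixture's token logprobs.
logprobs = [0.0] * 6 + [-1.9361264946837764e-07]
print(float(np.mean(np.exp(logprobs))))  # ~0.9999999723
```
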