diff --git a/src/cleanlab_tlm/utils/chat_completions.py b/src/cleanlab_tlm/utils/chat_completions.py
index 63936e2..0252fd1 100644
--- a/src/cleanlab_tlm/utils/chat_completions.py
+++ b/src/cleanlab_tlm/utils/chat_completions.py
@@ -7,6 +7,8 @@
 
 from typing import TYPE_CHECKING, Any, Optional, cast
 
+import numpy as np
+
 from cleanlab_tlm.internal.base import BaseTLM
 from cleanlab_tlm.internal.constants import (
     _DEFAULT_TLM_QUALITY_PRESET,
@@ -89,7 +91,16 @@
         prompt_text = form_prompt_string(messages, tools)
         response_text = _get_string_response(response)
 
-        return cast(TLMScore, self._tlm.get_trustworthiness_score(prompt_text, response_text))
+        scoring_kwargs = {}
+        # include perplexity in the tlm.get_trustworthiness_score kwargs when it can be computed from the response
+        perplexity = _extract_perplexity(response)
+        if perplexity is not None:
+            scoring_kwargs["perplexity"] = perplexity
+
+        return cast(
+            TLMScore,
+            self._tlm.get_trustworthiness_score(prompt_text, response_text, **scoring_kwargs),
+        )
 
 
 def _get_string_response(response: "ChatCompletion") -> str:
@@ -105,3 +116,13 @@
     if response.choices[0].message.content is None:
         raise ValueError("The OpenAI ChatCompletion object does not contain a message content.")
     return str(response.choices[0].message.content)
+
+
+def _extract_perplexity(response: "ChatCompletion") -> Optional[float]:
+    response_logprobs = response.choices[0].logprobs
+    if response_logprobs is None or response_logprobs.content is None:
+        return None
+
+    logprobs_list = [completion.logprob for completion in response_logprobs.content]
+    perplexity = np.exp(np.mean(logprobs_list))
+    return float(perplexity)
diff --git a/tests/test_chat_completions.py b/tests/test_chat_completions.py
index f859cda..2687ace 100644
--- a/tests/test_chat_completions.py
+++ b/tests/test_chat_completions.py
@@ -3,10 +3,11 @@
     ChatCompletion,
     ChatCompletionMessage,
 )
-from openai.types.chat.chat_completion import Choice
+from openai.types.chat.chat_completion import Choice, ChoiceLogprobs
+from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob
 
 from cleanlab_tlm.internal.types import TLMQualityPreset
-from cleanlab_tlm.utils.chat_completions import TLMChatCompletion
+from cleanlab_tlm.utils.chat_completions import TLMChatCompletion, _extract_perplexity
 from tests.conftest import make_text_unique
 from tests.constants import TEST_PROMPT, TEST_RESPONSE
 from tests.test_get_trustworthiness_score import is_trustworthiness_score_json_format
@@ -69,6 +70,8 @@
 
     assert score is not None
     assert is_trustworthiness_score_json_format(score)
+    assert score["log"]["explanation"] is not None
+    assert score["log"]["perplexity"] is None
 
 
 def test_tlm_chat_completion_score_with_tools() -> None:
@@ -116,6 +119,81 @@
     assert is_trustworthiness_score_json_format(score)
 
 
+def test_tlm_chat_completion_score_with_perplexity() -> None:
+    tlm_chat = TLMChatCompletion(options={"log": ["perplexity"]})
+    openai_kwargs = {
+        "model": "gpt-4.1-mini",
+        "messages": [{"role": "user", "content": test_prompt}],
+    }
+    response = ChatCompletion(
+        id="test",
+        choices=[
+            Choice(
+                index=0,
+                message=ChatCompletionMessage(role="assistant", content=test_response),
+                finish_reason="stop",
+                logprobs=ChoiceLogprobs(
+                    content=[
+                        ChatCompletionTokenLogprob(
+                            token="The",  # noqa: S106
+                            bytes=[84, 104, 101],
+                            logprob=0.0,
+                            top_logprobs=[],
+                        ),
+                        ChatCompletionTokenLogprob(
+                            token=" capital",  # noqa: S106
+                            bytes=[32, 99, 97, 112, 105, 116, 97, 108],
+                            logprob=0.0,
+                            top_logprobs=[],
+                        ),
+                        ChatCompletionTokenLogprob(
+                            token=" of",  # noqa: S106
+                            bytes=[32, 111, 102],
+                            logprob=0.0,
+                            top_logprobs=[],
+                        ),
+                        ChatCompletionTokenLogprob(
+                            token=" France",  # noqa: S106
+                            bytes=[32, 70, 114, 97, 110, 99, 101],
+                            logprob=0.0,
+                            top_logprobs=[],
+                        ),
+                        ChatCompletionTokenLogprob(
+                            token=" is",  # noqa: S106
+                            bytes=[32, 105, 115],
+                            logprob=0.0,
+                            top_logprobs=[],
+                        ),
+                        ChatCompletionTokenLogprob(
+                            token=" Paris",  # noqa: S106
+                            bytes=[32, 80, 97, 114, 105, 115],
+                            logprob=0.0,
+                            top_logprobs=[],
+                        ),
+                        ChatCompletionTokenLogprob(
+                            token=".",  # noqa: S106
+                            bytes=[46],
+                            logprob=-1.9361264946837764e-07,
+                            top_logprobs=[],
+                        ),
+                    ],
+                    refusal=None,
+                ),
+            )
+        ],
+        created=1234567890,
+        model="test-model",
+        object="chat.completion",
+    )
+
+    manually_calculated_perplexity = _extract_perplexity(response)
+
+    score = tlm_chat.score(response=response, **openai_kwargs)
+    returned_perplexity = score["log"]["perplexity"]
+
+    assert manually_calculated_perplexity == returned_perplexity
+
+
 def test_tlm_chat_completion_score_invalid_response() -> None:
     tlm_chat = TLMChatCompletion()
     openai_kwargs = {
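
For context, a minimal sketch (not part of the diff) of what the new `_extract_perplexity` helper computes and how `score()` uses it: it condenses the per-token logprobs on a `ChatCompletion` into exp(mean(logprob)), the geometric-mean token probability, and forwards that as the `perplexity` kwarg. It assumes `openai>=1.x` and `numpy` are installed; the response content and logprob values below are illustrative, not taken from the PR's test fixture.

```python
import numpy as np
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice, ChoiceLogprobs
from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob

from cleanlab_tlm.utils.chat_completions import _extract_perplexity

# A two-token completion with made-up logprobs of -0.1 and -0.3.
response = ChatCompletion(
    id="demo",
    choices=[
        Choice(
            index=0,
            finish_reason="stop",
            message=ChatCompletionMessage(role="assistant", content="Hi!"),
            logprobs=ChoiceLogprobs(
                content=[
                    ChatCompletionTokenLogprob(token="Hi", bytes=[72, 105], logprob=-0.1, top_logprobs=[]),
                    ChatCompletionTokenLogprob(token="!", bytes=[33], logprob=-0.3, top_logprobs=[]),
                ],
                refusal=None,
            ),
        )
    ],
    created=0,
    model="demo-model",
    object="chat.completion",
)

# exp(mean([-0.1, -0.3])) = exp(-0.2) ≈ 0.8187 — the value passed as the
# `perplexity` kwarg to TLM.get_trustworthiness_score().
assert abs(_extract_perplexity(response) - float(np.exp(-0.2))) < 1e-12

# A completion created without logprobs yields None, and score() then
# simply omits the kwarg.
response.choices[0].logprobs = None
assert _extract_perplexity(response) is None
```

Because `scoring_kwargs` stays empty when the response carries no logprobs, callers that score completions created without `logprobs=True` see unchanged behavior.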