Commit e728d8b

jsondai authored and copybara-github committed
feat: GenAI SDK client (evals) - add support for third-party model inference via litellm library
This change introduces support for running inference on third-party models, such as those from OpenAI, by integrating the `litellm` library. Users can now call `run_inference` with a model string (e.g., "gpt-4o") to generate responses from external models. The function also supports input datasets in the OpenAI Chat Completion request format.

PiperOrigin-RevId: 775409256
1 parent 6a7a451 commit e728d8b
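
A minimal usage sketch of the new inference path, pieced together from the tests added in this commit (not part of the diff itself). It assumes the public `client.evals` surface forwards to `Evals.run_inference` the way the internal tests do, that litellm provider credentials (e.g. an OpenAI API key) are already configured in the environment, and that the project and location values are placeholders.

# Sketch only: third-party model inference through litellm, based on the
# calls exercised in test_evals.py. "my-project"/"us-central1" and the
# public `client.evals` attribute are assumptions, not taken from the diff.
import pandas as pd
import vertexai

client = vertexai.Client(project="my-project", location="us-central1")

# Simple string-prompt format: one "prompt" column per row.
prompt_df = pd.DataFrame([{"prompt": "What is LiteLLM?"}])
result = client.evals.run_inference(model="gpt-4o", src=prompt_df)
print(result.eval_dataset_df["response"][0])  # serialized chat.completion payload

# Rows may instead carry a full OpenAI Chat Completion request body.
chat_df = pd.DataFrame(
    [
        {
            "model": "gpt-4o",
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Hello!"},
            ],
        }
    ]
)
result = client.evals.run_inference(model="gpt-4o", src=chat_df)

Per the tests, each row of the returned dataset's `eval_dataset_df` stores the full chat.completion response as JSON in a `response` column, so the generated text can be read at `["choices"][0]["message"]["content"]` after parsing.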

File tree

4 files changed: +432 -96 lines changed

setup.py

Lines changed: 1 addition & 0 deletions
@@ -172,6 +172,7 @@
     "jsonschema",
     "ruamel.yaml",
     "pyyaml",
+    "litellm >= 1.72.4",
 ]

 langchain_extra_require = [

tests/unit/vertexai/genai/test_evals.py

Lines changed: 252 additions & 79 deletions
@@ -17,6 +17,7 @@
 import os
 import statistics
 from unittest import mock
+import google.auth.credentials
 import warnings

 from google.cloud import aiplatform
@@ -44,6 +45,84 @@
 pytestmark = pytest.mark.usefixtures("google_auth_mock")


+@pytest.fixture
+def mock_api_client_fixture():
+    mock_client = mock.Mock(spec=client.Client)
+    mock_client.project = _TEST_PROJECT
+    mock_client.location = _TEST_LOCATION
+    mock_client._credentials = mock.create_autospec(
+        google.auth.credentials.Credentials, instance=True
+    )
+    mock_client._credentials.universe_domain = "googleapis.com"
+    mock_client._evals_client = mock.Mock(spec=evals.Evals)
+    return mock_client
+
+
+@pytest.fixture
+def mock_eval_dependencies(mock_api_client_fixture):
+    with mock.patch("google.cloud.storage.Client") as mock_storage_client, mock.patch(
+        "google.cloud.bigquery.Client"
+    ) as mock_bq_client, mock.patch(
+        "vertexai._genai.evals.Evals.evaluate_instances"
+    ) as mock_evaluate_instances, mock.patch(
+        "vertexai._genai._evals_utils.GcsUtils.upload_json_to_prefix"
+    ) as mock_upload_to_gcs, mock.patch(
+        "vertexai._genai._evals_utils.LazyLoadedPrebuiltMetric._fetch_and_parse"
+    ) as mock_fetch_prebuilt_metric:
+
+        def mock_evaluate_instances_side_effect(*args, **kwargs):
+            metric_config = kwargs.get("metric_config", {})
+            if "exact_match_input" in metric_config:
+                return vertexai_genai_types.EvaluateInstancesResponse(
+                    exact_match_results=vertexai_genai_types.ExactMatchResults(
+                        exact_match_metric_values=[
+                            vertexai_genai_types.ExactMatchMetricValue(score=1.0)
+                        ]
+                    )
+                )
+            elif "rouge_input" in metric_config:
+                return vertexai_genai_types.EvaluateInstancesResponse(
+                    rouge_results=vertexai_genai_types.RougeResults(
+                        rouge_metric_values=[
+                            vertexai_genai_types.RougeMetricValue(score=0.8)
+                        ]
+                    )
+                )
+            elif "pointwise_metric_input" in metric_config:
+                return vertexai_genai_types.EvaluateInstancesResponse(
+                    pointwise_metric_result=vertexai_genai_types.PointwiseMetricResult(
+                        score=0.9, explanation="Mocked LLM explanation"
+                    )
+                )
+            elif "comet_input" in metric_config:
+                return vertexai_genai_types.EvaluateInstancesResponse(
+                    comet_result=vertexai_genai_types.CometResult(score=0.75)
+                )
+            return vertexai_genai_types.EvaluateInstancesResponse()
+
+        mock_evaluate_instances.side_effect = mock_evaluate_instances_side_effect
+        mock_upload_to_gcs.return_value = (
+            "gs://mock-bucket/mock_path/evaluation_result_timestamp.json"
+        )
+        mock_prebuilt_safety_metric = vertexai_genai_types.LLMMetric(
+            name="safety", prompt_template="Is this safe? {response}"
+        )
+        mock_prebuilt_safety_metric._is_predefined = True
+        mock_prebuilt_safety_metric._config_source = "gs://mock-metrics/safety/v1.yaml"
+        mock_prebuilt_safety_metric._version = "v1"
+
+        mock_fetch_prebuilt_metric.return_value = mock_prebuilt_safety_metric
+
+        yield {
+            "mock_storage_client": mock_storage_client,
+            "mock_bq_client": mock_bq_client,
+            "mock_evaluate_instances": mock_evaluate_instances,
+            "mock_upload_to_gcs": mock_upload_to_gcs,
+            "mock_fetch_prebuilt_metric": mock_fetch_prebuilt_metric,
+            "mock_prebuilt_safety_metric": mock_prebuilt_safety_metric,
+        }
+
+
 class TestEvals:
     """Unit tests for the GenAI client."""

@@ -716,26 +795,38 @@ def mock_generate_content_logic(*args, **kwargs):
             mock.call(
                 model="gemini-pro",
                 contents=[
-                    {"parts": [{"text": "Placeholder prompt 1"}], "role": "user"}
+                    {
+                        "parts": [{"text": "Placeholder prompt 1"}],
+                        "role": "user",
+                    }
                 ],
                 config=genai_types.GenerateContentConfig(),
             ),
             mock.call(
                 model="gemini-pro",
                 contents=[
-                    {"parts": [{"text": "Placeholder prompt 2.1"}], "role": "user"},
+                    {
+                        "parts": [{"text": "Placeholder prompt 2.1"}],
+                        "role": "user",
+                    },
                     {
                         "parts": [{"text": "Placeholder model response 2.1"}],
                         "role": "model",
                     },
-                    {"parts": [{"text": "Placeholder prompt 2.2"}], "role": "user"},
+                    {
+                        "parts": [{"text": "Placeholder prompt 2.2"}],
+                        "role": "user",
+                    },
                 ],
                 config=genai_types.GenerateContentConfig(temperature=0.7, top_k=5),
             ),
             mock.call(
                 model="gemini-pro",
                 contents=[
-                    {"parts": [{"text": "Placeholder prompt 3"}], "role": "user"}
+                    {
+                        "parts": [{"text": "Placeholder prompt 3"}],
+                        "role": "user",
+                    }
                 ],
                 config=genai_types.GenerateContentConfig(),
             ),
@@ -858,6 +949,163 @@ def test_inference_with_multimodal_content(
         assert inference_result.candidate_name == "gemini-pro"
         assert inference_result.gcs_source is None

+    def test_run_inference_with_litellm_string_prompt_format(
+        self,
+        mock_api_client_fixture,
+    ):
+        """Tests inference with LiteLLM using a simple prompt string."""
+        with mock.patch(
+            "vertexai._genai._evals_common.litellm"
+        ) as mock_litellm, mock.patch(
+            "vertexai._genai._evals_common._call_litellm_completion"
+        ) as mock_call_litellm_completion:
+            mock_litellm.utils.get_valid_models.return_value = ["gpt-4o"]
+            prompt_df = pd.DataFrame([{"prompt": "What is LiteLLM?"}])
+            expected_messages = [{"role": "user", "content": "What is LiteLLM?"}]
+
+            mock_response_dict = {
+                "id": "test",
+                "created": 123456,
+                "model": "gpt-4o",
+                "object": "chat.completion",
+                "system_fingerprint": "123456",
+                "choices": [
+                    {
+                        "finish_reason": "stop",
+                        "index": 0,
+                        "message": {
+                            "content": "LiteLLM is a library...",
+                            "role": "assistant",
+                            "annotations": [],
+                        },
+                        "provider_specific_fields": {},
+                    }
+                ],
+                "usage": {
+                    "completion_tokens": 114,
+                    "prompt_tokens": 13,
+                    "total_tokens": 127,
+                },
+                "service_tier": "default",
+            }
+            mock_call_litellm_completion.return_value = mock_response_dict
+            evals_module = evals.Evals(api_client_=mock_api_client_fixture)
+
+            result_dataset = evals_module.run_inference(
+                model="gpt-4o",
+                src=prompt_df,
+            )
+
+            mock_call_litellm_completion.assert_called_once()
+            _, call_kwargs = mock_call_litellm_completion.call_args
+
+            assert call_kwargs["model"] == "gpt-4o"
+            assert call_kwargs["messages"] == expected_messages
+            assert "response" in result_dataset.eval_dataset_df.columns
+            response_content = json.loads(result_dataset.eval_dataset_df["response"][0])
+            assert (
+                response_content["choices"][0]["message"]["content"]
+                == "LiteLLM is a library..."
+            )
+
+    def test_run_inference_with_litellm_openai_request_format(
+        self,
+        mock_api_client_fixture,
+    ):
+        """Tests inference with LiteLLM where the row contains a chat completion request body."""
+        with mock.patch(
+            "vertexai._genai._evals_common.litellm"
+        ) as mock_litellm, mock.patch(
+            "vertexai._genai._evals_common._call_litellm_completion"
+        ) as mock_call_litellm_completion:
+            mock_litellm.utils.get_valid_models.return_value = ["gpt-4o"]
+            prompt_df = pd.DataFrame(
+                [
+                    {
+                        "model": "gpt-4o",
+                        "messages": [
+                            {
+                                "role": "system",
+                                "content": "You are a helpful assistant.",
+                            },
+                            {"role": "user", "content": "Hello!"},
+                        ],
+                    }
+                ]
+            )
+            expected_messages = [
+                {"role": "system", "content": "You are a helpful assistant."},
+                {"role": "user", "content": "Hello!"},
+            ]
+
+            mock_response_dict = {
+                "id": "test",
+                "created": 123456,
+                "model": "gpt-4o",
+                "object": "chat.completion",
+                "system_fingerprint": "123456",
+                "choices": [
+                    {
+                        "finish_reason": "stop",
+                        "index": 0,
+                        "message": {
+                            "content": "Hello there",
+                            "role": "assistant",
+                            "annotations": [],
+                        },
+                        "provider_specific_fields": {},
+                    }
+                ],
+                "usage": {
+                    "completion_tokens": 114,
+                    "prompt_tokens": 13,
+                    "total_tokens": 127,
+                },
+                "service_tier": "default",
+            }
+            mock_call_litellm_completion.return_value = mock_response_dict
+            evals_module = evals.Evals(api_client_=mock_api_client_fixture)
+
+            result_dataset = evals_module.run_inference(
+                model="gpt-4o",
+                src=prompt_df,
+            )
+
+            mock_call_litellm_completion.assert_called_once()
+            _, call_kwargs = mock_call_litellm_completion.call_args
+
+            assert call_kwargs["model"] == "gpt-4o"
+            assert call_kwargs["messages"] == expected_messages
+            assert "response" in result_dataset.eval_dataset_df.columns
+            response_content = json.loads(result_dataset.eval_dataset_df["response"][0])
+            assert response_content["choices"][0]["message"]["content"] == "Hello there"
+
+    def test_run_inference_with_unsupported_model_string(
+        self,
+        mock_api_client_fixture,
+    ):
+        with mock.patch(
+            "vertexai._genai._evals_common.litellm"
+        ) as mock_litellm_package:
+            mock_litellm_package.utils.get_valid_models.return_value = []
+            evals_module = evals.Evals(api_client_=mock_api_client_fixture)
+            prompt_df = pd.DataFrame([{"prompt": "test"}])
+
+            with pytest.raises(TypeError, match="Unsupported string model name"):
+                evals_module.run_inference(
+                    model="some-random-model/name", src=prompt_df
+                )
+
+    @mock.patch("vertexai._genai._evals_common.litellm", None)
+    def test_run_inference_with_litellm_import_error(self, mock_api_client_fixture):
+        evals_module = evals.Evals(api_client_=mock_api_client_fixture)
+        prompt_df = pd.DataFrame([{"prompt": "test"}])
+        with pytest.raises(
+            ImportError,
+            match="The 'litellm' library is required to use third-party models",
+        ):
+            evals_module.run_inference(model="gpt-4o", src=prompt_df)
+

 class TestMetricPromptBuilder:
     """Unit tests for the MetricPromptBuilder class."""
@@ -2738,81 +2986,6 @@ def test_auto_detect_empty_dataset(self):
         )


-@pytest.fixture
-def mock_api_client_fixture():
-    mock_client = mock.Mock(spec=client.Client)
-    mock_client.project = _TEST_PROJECT
-    mock_client.location = _TEST_LOCATION
-    mock_client._credentials = mock.Mock()
-    mock_client._evals_client = mock.Mock(spec=evals.Evals)
-    return mock_client
-
-
-@pytest.fixture
-def mock_eval_dependencies(mock_api_client_fixture):
-    with mock.patch("google.cloud.storage.Client") as mock_storage_client, mock.patch(
-        "google.cloud.bigquery.Client"
-    ) as mock_bq_client, mock.patch(
-        "vertexai._genai.evals.Evals.evaluate_instances"
-    ) as mock_evaluate_instances, mock.patch(
-        "vertexai._genai._evals_utils.GcsUtils.upload_json_to_prefix"
-    ) as mock_upload_to_gcs, mock.patch(
-        "vertexai._genai._evals_utils.LazyLoadedPrebuiltMetric._fetch_and_parse"
-    ) as mock_fetch_prebuilt_metric:
-
-        def mock_evaluate_instances_side_effect(*args, **kwargs):
-            metric_config = kwargs.get("metric_config", {})
-            if "exact_match_input" in metric_config:
-                return vertexai_genai_types.EvaluateInstancesResponse(
-                    exact_match_results=vertexai_genai_types.ExactMatchResults(
-                        exact_match_metric_values=[
-                            vertexai_genai_types.ExactMatchMetricValue(score=1.0)
-                        ]
-                    )
-                )
-            elif "rouge_input" in metric_config:
-                return vertexai_genai_types.EvaluateInstancesResponse(
-                    rouge_results=vertexai_genai_types.RougeResults(
-                        rouge_metric_values=[
-                            vertexai_genai_types.RougeMetricValue(score=0.8)
-                        ]
-                    )
-                )
-            elif "pointwise_metric_input" in metric_config:
-                return vertexai_genai_types.EvaluateInstancesResponse(
-                    pointwise_metric_result=vertexai_genai_types.PointwiseMetricResult(
-                        score=0.9, explanation="Mocked LLM explanation"
-                    )
-                )
-            elif "comet_input" in metric_config:
-                return vertexai_genai_types.EvaluateInstancesResponse(
-                    comet_result=vertexai_genai_types.CometResult(score=0.75)
-                )
-            return vertexai_genai_types.EvaluateInstancesResponse()
-
-        mock_evaluate_instances.side_effect = mock_evaluate_instances_side_effect
-        mock_upload_to_gcs.return_value = (
-            "gs://mock-bucket/mock_path/evaluation_result_timestamp.json"
-        )
-        mock_prebuilt_safety_metric = vertexai_genai_types.LLMMetric(
-            name="safety", prompt_template="Is this safe? {response}"
-        )
-        mock_prebuilt_safety_metric._is_predefined = True
-        mock_prebuilt_safety_metric._config_source = "gs://mock-metrics/safety/v1.yaml"
-        mock_prebuilt_safety_metric._version = "v1"
-
-        mock_fetch_prebuilt_metric.return_value = mock_prebuilt_safety_metric
-
-        yield {
-            "mock_storage_client": mock_storage_client,
-            "mock_bq_client": mock_bq_client,
-            "mock_evaluate_instances": mock_evaluate_instances,
-            "mock_upload_to_gcs": mock_upload_to_gcs,
-            "mock_fetch_prebuilt_metric": mock_fetch_prebuilt_metric,
-            "mock_prebuilt_safety_metric": mock_prebuilt_safety_metric,
-        }
-
-
 class TestEvalsRunEvaluation:
     """Unit tests for the evaluate method."""
