17 | 17 | import os
18 | 18 | import statistics
19 | 19 | from unittest import mock
| 20 | +import google.auth.credentials
20 | 21 | import warnings
21 | 22 |
22 | 23 | from google.cloud import aiplatform

44 | 45 | pytestmark = pytest.mark.usefixtures("google_auth_mock")
45 | 46 |
46 | 47 |
| 48 | +@pytest.fixture
| 49 | +def mock_api_client_fixture():
| 50 | +    mock_client = mock.Mock(spec=client.Client)
| 51 | +    mock_client.project = _TEST_PROJECT
| 52 | +    mock_client.location = _TEST_LOCATION
| 53 | +    mock_client._credentials = mock.create_autospec(
| 54 | +        google.auth.credentials.Credentials, instance=True
| 55 | +    )
| 56 | +    mock_client._credentials.universe_domain = "googleapis.com"
| 57 | +    mock_client._evals_client = mock.Mock(spec=evals.Evals)
| 58 | +    return mock_client
| 59 | +
| 60 | +
| 61 | +@pytest.fixture
| 62 | +def mock_eval_dependencies(mock_api_client_fixture):
| 63 | +    with mock.patch("google.cloud.storage.Client") as mock_storage_client, mock.patch(
| 64 | +        "google.cloud.bigquery.Client"
| 65 | +    ) as mock_bq_client, mock.patch(
| 66 | +        "vertexai._genai.evals.Evals.evaluate_instances"
| 67 | +    ) as mock_evaluate_instances, mock.patch(
| 68 | +        "vertexai._genai._evals_utils.GcsUtils.upload_json_to_prefix"
| 69 | +    ) as mock_upload_to_gcs, mock.patch(
| 70 | +        "vertexai._genai._evals_utils.LazyLoadedPrebuiltMetric._fetch_and_parse"
| 71 | +    ) as mock_fetch_prebuilt_metric:
| 72 | +
| 73 | +        def mock_evaluate_instances_side_effect(*args, **kwargs):
| 74 | +            metric_config = kwargs.get("metric_config", {})
| 75 | +            if "exact_match_input" in metric_config:
| 76 | +                return vertexai_genai_types.EvaluateInstancesResponse(
| 77 | +                    exact_match_results=vertexai_genai_types.ExactMatchResults(
| 78 | +                        exact_match_metric_values=[
| 79 | +                            vertexai_genai_types.ExactMatchMetricValue(score=1.0)
| 80 | +                        ]
| 81 | +                    )
| 82 | +                )
| 83 | +            elif "rouge_input" in metric_config:
| 84 | +                return vertexai_genai_types.EvaluateInstancesResponse(
| 85 | +                    rouge_results=vertexai_genai_types.RougeResults(
| 86 | +                        rouge_metric_values=[
| 87 | +                            vertexai_genai_types.RougeMetricValue(score=0.8)
| 88 | +                        ]
| 89 | +                    )
| 90 | +                )
| 91 | +            elif "pointwise_metric_input" in metric_config:
| 92 | +                return vertexai_genai_types.EvaluateInstancesResponse(
| 93 | +                    pointwise_metric_result=vertexai_genai_types.PointwiseMetricResult(
| 94 | +                        score=0.9, explanation="Mocked LLM explanation"
| 95 | +                    )
| 96 | +                )
| 97 | +            elif "comet_input" in metric_config:
| 98 | +                return vertexai_genai_types.EvaluateInstancesResponse(
| 99 | +                    comet_result=vertexai_genai_types.CometResult(score=0.75)
| 100 | +                )
| 101 | +            return vertexai_genai_types.EvaluateInstancesResponse()
| 102 | +
| 103 | +        mock_evaluate_instances.side_effect = mock_evaluate_instances_side_effect
| 104 | +        mock_upload_to_gcs.return_value = (
| 105 | +            "gs://mock-bucket/mock_path/evaluation_result_timestamp.json"
| 106 | +        )
| 107 | +        mock_prebuilt_safety_metric = vertexai_genai_types.LLMMetric(
| 108 | +            name="safety", prompt_template="Is this safe? {response}"
| 109 | +        )
| 110 | +        mock_prebuilt_safety_metric._is_predefined = True
| 111 | +        mock_prebuilt_safety_metric._config_source = "gs://mock-metrics/safety/v1.yaml"
| 112 | +        mock_prebuilt_safety_metric._version = "v1"
| 113 | +
| 114 | +        mock_fetch_prebuilt_metric.return_value = mock_prebuilt_safety_metric
| 115 | +
| 116 | +        yield {
| 117 | +            "mock_storage_client": mock_storage_client,
| 118 | +            "mock_bq_client": mock_bq_client,
| 119 | +            "mock_evaluate_instances": mock_evaluate_instances,
| 120 | +            "mock_upload_to_gcs": mock_upload_to_gcs,
| 121 | +            "mock_fetch_prebuilt_metric": mock_fetch_prebuilt_metric,
| 122 | +            "mock_prebuilt_safety_metric": mock_prebuilt_safety_metric,
| 123 | +        }
| 124 | +
| 125 | +
47 | 126 | class TestEvals:
48 | 127 |     """Unit tests for the GenAI client."""
49 | 128 |
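Reviewer note (not part of the diff): the fixtures moved above are shared by the inference and evaluation tests later in this file. As a minimal sketch of how the metric_config routing in mock_eval_dependencies behaves — using only names defined in the fixture; the test name itself is invented — a consumer in this module could assert:

def test_metric_config_routing_sketch(mock_eval_dependencies):
    mock_evaluate = mock_eval_dependencies["mock_evaluate_instances"]
    # The installed side_effect keys off the metric_config kwarg and returns a
    # canned EvaluateInstancesResponse for the matching metric family.
    exact_match = mock_evaluate(metric_config={"exact_match_input": {}})
    assert exact_match.exact_match_results.exact_match_metric_values[0].score == 1.0
    rouge = mock_evaluate(metric_config={"rouge_input": {}})
    assert rouge.rouge_results.rouge_metric_values[0].score == 0.8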
@@ -716,26 +795,38 @@ def mock_generate_content_logic(*args, **kwargs):
716 | 795 |             mock.call(
717 | 796 |                 model="gemini-pro",
718 | 797 |                 contents=[
719 | | -                    {"parts": [{"text": "Placeholder prompt 1"}], "role": "user"}
| 798 | +                    {
| 799 | +                        "parts": [{"text": "Placeholder prompt 1"}],
| 800 | +                        "role": "user",
| 801 | +                    }
720 | 802 |                 ],
721 | 803 |                 config=genai_types.GenerateContentConfig(),
722 | 804 |             ),
723 | 805 |             mock.call(
724 | 806 |                 model="gemini-pro",
725 | 807 |                 contents=[
726 | | -                    {"parts": [{"text": "Placeholder prompt 2.1"}], "role": "user"},
| 808 | +                    {
| 809 | +                        "parts": [{"text": "Placeholder prompt 2.1"}],
| 810 | +                        "role": "user",
| 811 | +                    },
727 | 812 |                     {
728 | 813 |                         "parts": [{"text": "Placeholder model response 2.1"}],
729 | 814 |                         "role": "model",
730 | 815 |                     },
731 | | -                    {"parts": [{"text": "Placeholder prompt 2.2"}], "role": "user"},
| 816 | +                    {
| 817 | +                        "parts": [{"text": "Placeholder prompt 2.2"}],
| 818 | +                        "role": "user",
| 819 | +                    },
732 | 820 |                 ],
733 | 821 |                 config=genai_types.GenerateContentConfig(temperature=0.7, top_k=5),
734 | 822 |             ),
735 | 823 |             mock.call(
736 | 824 |                 model="gemini-pro",
737 | 825 |                 contents=[
738 | | -                    {"parts": [{"text": "Placeholder prompt 3"}], "role": "user"}
| 826 | +                    {
| 827 | +                        "parts": [{"text": "Placeholder prompt 3"}],
| 828 | +                        "role": "user",
| 829 | +                    }
739 | 830 |                 ],
740 | 831 |                 config=genai_types.GenerateContentConfig(),
741 | 832 |             ),
@@ -858,6 +949,163 @@ def test_inference_with_multimodal_content(
858 | 949 |         assert inference_result.candidate_name == "gemini-pro"
859 | 950 |         assert inference_result.gcs_source is None
860 | 951 |
| 952 | +    def test_run_inference_with_litellm_string_prompt_format(
| 953 | +        self,
| 954 | +        mock_api_client_fixture,
| 955 | +    ):
| 956 | +        """Tests inference with LiteLLM using a simple prompt string."""
| 957 | +        with mock.patch(
| 958 | +            "vertexai._genai._evals_common.litellm"
| 959 | +        ) as mock_litellm, mock.patch(
| 960 | +            "vertexai._genai._evals_common._call_litellm_completion"
| 961 | +        ) as mock_call_litellm_completion:
| 962 | +            mock_litellm.utils.get_valid_models.return_value = ["gpt-4o"]
| 963 | +            prompt_df = pd.DataFrame([{"prompt": "What is LiteLLM?"}])
| 964 | +            expected_messages = [{"role": "user", "content": "What is LiteLLM?"}]
| 965 | +
| 966 | +            mock_response_dict = {
| 967 | +                "id": "test",
| 968 | +                "created": 123456,
| 969 | +                "model": "gpt-4o",
| 970 | +                "object": "chat.completion",
| 971 | +                "system_fingerprint": "123456",
| 972 | +                "choices": [
| 973 | +                    {
| 974 | +                        "finish_reason": "stop",
| 975 | +                        "index": 0,
| 976 | +                        "message": {
| 977 | +                            "content": "LiteLLM is a library...",
| 978 | +                            "role": "assistant",
| 979 | +                            "annotations": [],
| 980 | +                        },
| 981 | +                        "provider_specific_fields": {},
| 982 | +                    }
| 983 | +                ],
| 984 | +                "usage": {
| 985 | +                    "completion_tokens": 114,
| 986 | +                    "prompt_tokens": 13,
| 987 | +                    "total_tokens": 127,
| 988 | +                },
| 989 | +                "service_tier": "default",
| 990 | +            }
| 991 | +            mock_call_litellm_completion.return_value = mock_response_dict
| 992 | +            evals_module = evals.Evals(api_client_=mock_api_client_fixture)
| 993 | +
| 994 | +            result_dataset = evals_module.run_inference(
| 995 | +                model="gpt-4o",
| 996 | +                src=prompt_df,
| 997 | +            )
| 998 | +
| 999 | +            mock_call_litellm_completion.assert_called_once()
| 1000 | +            _, call_kwargs = mock_call_litellm_completion.call_args
| 1001 | +
| 1002 | +            assert call_kwargs["model"] == "gpt-4o"
| 1003 | +            assert call_kwargs["messages"] == expected_messages
| 1004 | +            assert "response" in result_dataset.eval_dataset_df.columns
| 1005 | +            response_content = json.loads(result_dataset.eval_dataset_df["response"][0])
| 1006 | +            assert (
| 1007 | +                response_content["choices"][0]["message"]["content"]
| 1008 | +                == "LiteLLM is a library..."
| 1009 | +            )
| 1010 | +
| 1011 | +    def test_run_inference_with_litellm_openai_request_format(
| 1012 | +        self,
| 1013 | +        mock_api_client_fixture,
| 1014 | +    ):
| 1015 | +        """Tests inference with LiteLLM where the row contains a chat completion request body."""
| 1016 | +        with mock.patch(
| 1017 | +            "vertexai._genai._evals_common.litellm"
| 1018 | +        ) as mock_litellm, mock.patch(
| 1019 | +            "vertexai._genai._evals_common._call_litellm_completion"
| 1020 | +        ) as mock_call_litellm_completion:
| 1021 | +            mock_litellm.utils.get_valid_models.return_value = ["gpt-4o"]
| 1022 | +            prompt_df = pd.DataFrame(
| 1023 | +                [
| 1024 | +                    {
| 1025 | +                        "model": "gpt-4o",
| 1026 | +                        "messages": [
| 1027 | +                            {
| 1028 | +                                "role": "system",
| 1029 | +                                "content": "You are a helpful assistant.",
| 1030 | +                            },
| 1031 | +                            {"role": "user", "content": "Hello!"},
| 1032 | +                        ],
| 1033 | +                    }
| 1034 | +                ]
| 1035 | +            )
| 1036 | +            expected_messages = [
| 1037 | +                {"role": "system", "content": "You are a helpful assistant."},
| 1038 | +                {"role": "user", "content": "Hello!"},
| 1039 | +            ]
| 1040 | +
| 1041 | +            mock_response_dict = {
| 1042 | +                "id": "test",
| 1043 | +                "created": 123456,
| 1044 | +                "model": "gpt-4o",
| 1045 | +                "object": "chat.completion",
| 1046 | +                "system_fingerprint": "123456",
| 1047 | +                "choices": [
| 1048 | +                    {
| 1049 | +                        "finish_reason": "stop",
| 1050 | +                        "index": 0,
| 1051 | +                        "message": {
| 1052 | +                            "content": "Hello there",
| 1053 | +                            "role": "assistant",
| 1054 | +                            "annotations": [],
| 1055 | +                        },
| 1056 | +                        "provider_specific_fields": {},
| 1057 | +                    }
| 1058 | +                ],
| 1059 | +                "usage": {
| 1060 | +                    "completion_tokens": 114,
| 1061 | +                    "prompt_tokens": 13,
| 1062 | +                    "total_tokens": 127,
| 1063 | +                },
| 1064 | +                "service_tier": "default",
| 1065 | +            }
| 1066 | +            mock_call_litellm_completion.return_value = mock_response_dict
| 1067 | +            evals_module = evals.Evals(api_client_=mock_api_client_fixture)
| 1068 | +
| 1069 | +            result_dataset = evals_module.run_inference(
| 1070 | +                model="gpt-4o",
| 1071 | +                src=prompt_df,
| 1072 | +            )
| 1073 | +
| 1074 | +            mock_call_litellm_completion.assert_called_once()
| 1075 | +            _, call_kwargs = mock_call_litellm_completion.call_args
| 1076 | +
| 1077 | +            assert call_kwargs["model"] == "gpt-4o"
| 1078 | +            assert call_kwargs["messages"] == expected_messages
| 1079 | +            assert "response" in result_dataset.eval_dataset_df.columns
| 1080 | +            response_content = json.loads(result_dataset.eval_dataset_df["response"][0])
| 1081 | +            assert response_content["choices"][0]["message"]["content"] == "Hello there"
| 1082 | +
| 1083 | +    def test_run_inference_with_unsupported_model_string(
| 1084 | +        self,
| 1085 | +        mock_api_client_fixture,
| 1086 | +    ):
| 1087 | +        with mock.patch(
| 1088 | +            "vertexai._genai._evals_common.litellm"
| 1089 | +        ) as mock_litellm_package:
| 1090 | +            mock_litellm_package.utils.get_valid_models.return_value = []
| 1091 | +            evals_module = evals.Evals(api_client_=mock_api_client_fixture)
| 1092 | +            prompt_df = pd.DataFrame([{"prompt": "test"}])
| 1093 | +
| 1094 | +            with pytest.raises(TypeError, match="Unsupported string model name"):
| 1095 | +                evals_module.run_inference(
| 1096 | +                    model="some-random-model/name", src=prompt_df
| 1097 | +                )
| 1098 | +
| 1099 | +    @mock.patch("vertexai._genai._evals_common.litellm", None)
| 1100 | +    def test_run_inference_with_litellm_import_error(self, mock_api_client_fixture):
| 1101 | +        evals_module = evals.Evals(api_client_=mock_api_client_fixture)
| 1102 | +        prompt_df = pd.DataFrame([{"prompt": "test"}])
| 1103 | +        with pytest.raises(
| 1104 | +            ImportError,
| 1105 | +            match="The 'litellm' library is required to use third-party models",
| 1106 | +        ):
| 1107 | +            evals_module.run_inference(model="gpt-4o", src=prompt_df)
| 1108 | +
861 | 1109 |
862 | 1110 | class TestMetricPromptBuilder:
863 | 1111 |     """Unit tests for the MetricPromptBuilder class."""
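Note on the two new LiteLLM inference tests: together they pin down how run_inference is expected to build the messages payload — a bare "prompt" column becomes a single user message, while a row that already carries an OpenAI-style request body is forwarded unchanged. A hedged illustration of that mapping (the helper name is invented for this sketch and is not a vertexai function):

def to_litellm_messages(row: dict) -> list:
    # Row already shaped like a chat completion request body: pass through.
    if "messages" in row:
        return row["messages"]
    # Plain prompt column: wrap it into one user message.
    return [{"role": "user", "content": row["prompt"]}]


assert to_litellm_messages({"prompt": "What is LiteLLM?"}) == [
    {"role": "user", "content": "What is LiteLLM?"}
]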
@@ -2738,81 +2986,6 @@ def test_auto_detect_empty_dataset(self):
2738 | 2986 |         )
2739 | 2987 |
2740 | 2988 |
2741 | | -@pytest.fixture
2742 | | -def mock_api_client_fixture():
2743 | | -    mock_client = mock.Mock(spec=client.Client)
2744 | | -    mock_client.project = _TEST_PROJECT
2745 | | -    mock_client.location = _TEST_LOCATION
2746 | | -    mock_client._credentials = mock.Mock()
2747 | | -    mock_client._evals_client = mock.Mock(spec=evals.Evals)
2748 | | -    return mock_client
2749 | | -
2750 | | -
2751 | | -@pytest.fixture
2752 | | -def mock_eval_dependencies(mock_api_client_fixture):
2753 | | -    with mock.patch("google.cloud.storage.Client") as mock_storage_client, mock.patch(
2754 | | -        "google.cloud.bigquery.Client"
2755 | | -    ) as mock_bq_client, mock.patch(
2756 | | -        "vertexai._genai.evals.Evals.evaluate_instances"
2757 | | -    ) as mock_evaluate_instances, mock.patch(
2758 | | -        "vertexai._genai._evals_utils.GcsUtils.upload_json_to_prefix"
2759 | | -    ) as mock_upload_to_gcs, mock.patch(
2760 | | -        "vertexai._genai._evals_utils.LazyLoadedPrebuiltMetric._fetch_and_parse"
2761 | | -    ) as mock_fetch_prebuilt_metric:
2762 | | -
2763 | | -        def mock_evaluate_instances_side_effect(*args, **kwargs):
2764 | | -            metric_config = kwargs.get("metric_config", {})
2765 | | -            if "exact_match_input" in metric_config:
2766 | | -                return vertexai_genai_types.EvaluateInstancesResponse(
2767 | | -                    exact_match_results=vertexai_genai_types.ExactMatchResults(
2768 | | -                        exact_match_metric_values=[
2769 | | -                            vertexai_genai_types.ExactMatchMetricValue(score=1.0)
2770 | | -                        ]
2771 | | -                    )
2772 | | -                )
2773 | | -            elif "rouge_input" in metric_config:
2774 | | -                return vertexai_genai_types.EvaluateInstancesResponse(
2775 | | -                    rouge_results=vertexai_genai_types.RougeResults(
2776 | | -                        rouge_metric_values=[
2777 | | -                            vertexai_genai_types.RougeMetricValue(score=0.8)
2778 | | -                        ]
2779 | | -                    )
2780 | | -                )
2781 | | -            elif "pointwise_metric_input" in metric_config:
2782 | | -                return vertexai_genai_types.EvaluateInstancesResponse(
2783 | | -                    pointwise_metric_result=vertexai_genai_types.PointwiseMetricResult(
2784 | | -                        score=0.9, explanation="Mocked LLM explanation"
2785 | | -                    )
2786 | | -                )
2787 | | -            elif "comet_input" in metric_config:
2788 | | -                return vertexai_genai_types.EvaluateInstancesResponse(
2789 | | -                    comet_result=vertexai_genai_types.CometResult(score=0.75)
2790 | | -                )
2791 | | -            return vertexai_genai_types.EvaluateInstancesResponse()
2792 | | -
2793 | | -        mock_evaluate_instances.side_effect = mock_evaluate_instances_side_effect
2794 | | -        mock_upload_to_gcs.return_value = (
2795 | | -            "gs://mock-bucket/mock_path/evaluation_result_timestamp.json"
2796 | | -        )
2797 | | -        mock_prebuilt_safety_metric = vertexai_genai_types.LLMMetric(
2798 | | -            name="safety", prompt_template="Is this safe? {response}"
2799 | | -        )
2800 | | -        mock_prebuilt_safety_metric._is_predefined = True
2801 | | -        mock_prebuilt_safety_metric._config_source = "gs://mock-metrics/safety/v1.yaml"
2802 | | -        mock_prebuilt_safety_metric._version = "v1"
2803 | | -
2804 | | -        mock_fetch_prebuilt_metric.return_value = mock_prebuilt_safety_metric
2805 | | -
2806 | | -        yield {
2807 | | -            "mock_storage_client": mock_storage_client,
2808 | | -            "mock_bq_client": mock_bq_client,
2809 | | -            "mock_evaluate_instances": mock_evaluate_instances,
2810 | | -            "mock_upload_to_gcs": mock_upload_to_gcs,
2811 | | -            "mock_fetch_prebuilt_metric": mock_fetch_prebuilt_metric,
2812 | | -            "mock_prebuilt_safety_metric": mock_prebuilt_safety_metric,
2813 | | -        }
2814 | | -
2815 | | -
2816 | 2989 | class TestEvalsRunEvaluation:
2817 | 2990 |     """Unit tests for the evaluate method."""
2818 | 2991 |
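Closing note on the new error-path tests above: the TypeError case expects unknown string model names to be rejected, and the ImportError case patches the litellm module attribute to None to simulate a missing install. A rough sketch of the optional-dependency guard being simulated — the names here are illustrative, and only the error message is taken from the test's match string:

# Illustrative pattern, not the actual vertexai._genai._evals_common source.
try:
    import litellm  # optional third-party dependency
except ImportError:
    litellm = None  # the test reproduces this state by patching the attribute


def _require_litellm():
    if litellm is None:
        raise ImportError(
            "The 'litellm' library is required to use third-party models"
        )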