Skip to content

Commit 9fd40ae

Browse files
jsondaicopybara-github
authored andcommitted
chore: Add test_batch_eval to replay tests
PiperOrigin-RevId: 776231295
1 parent e8d18b6 commit 9fd40ae

File tree

2 files changed

+46
-1
lines changed

2 files changed

+46
-1
lines changed
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
# pylint: disable=protected-access,bad-continuation,missing-function-docstring
16+
17+
from tests.unit.vertexai.genai.replays import pytest_helper
18+
from vertexai._genai import types
19+
20+
21+
def test_batch_eval(client):
22+
eval_dataset = types.EvaluationDataset(
23+
gcs_source=types.GcsSource(
24+
uris=["gs://genai-eval-sdk-replay-test/test_data/inference_results.jsonl"]
25+
)
26+
)
27+
28+
batch_eval_operation = client.evals.batch_evaluate(
29+
dataset=eval_dataset,
30+
metrics=[
31+
types.PrebuiltMetric.TEXT_QUALITY,
32+
],
33+
dest="gs://genai-eval-sdk-replay-test/test_data/batch_eval_output",
34+
)
35+
assert "operations" in batch_eval_operation.name
36+
assert "EvaluateDatasetOperationMetadata" in batch_eval_operation.metadata.get(
37+
"@type"
38+
)
39+
40+
41+
pytestmark = pytest_helper.setup(
42+
file=__file__,
43+
globals_for_file=globals(),
44+
test_method="evals.batch_evaluate",
45+
)

vertexai/_genai/_evals_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def read_file_contents(self, gcs_filepath: str) -> str:
182182
)
183183
bucket = self.storage_client.bucket(bucket_name)
184184
blob = bucket.blob(blob_path)
185-
content = blob.download_as_string().decode("utf-8")
185+
content = blob.download_as_bytes().decode("utf-8")
186186
logger.info(f"Successfully read content from '{gcs_filepath}'")
187187
return content
188188

0 commit comments

Comments
 (0)