Skip to content

Commit 551e800

Browse files
jsondai and copybara-github
authored and committed
feat: GenAI SDK client - Add batch_evaluate method for asynchronous batch eval. Add transformation support for consistent interface parameters with the evaluate method
PiperOrigin-RevId: 774952915
1 parent e728d8b commit 551e800

File tree

5 files changed

+1649
-1544
lines changed

5 files changed

+1649
-1544
lines changed

tests/unit/vertexai/genai/test_evals.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -158,14 +158,13 @@ def test_eval_run(self):
158158

159159
@pytest.mark.usefixtures("google_auth_mock")
160160
@mock.patch.object(client.Client, "_get_api_client")
161-
@mock.patch.object(evals.Evals, "batch_eval")
162-
def test_eval_batch_eval(self, mock_evaluate, mock_get_api_client):
161+
@mock.patch.object(evals.Evals, "batch_evaluate")
162+
def test_eval_batch_evaluate(self, mock_evaluate, mock_get_api_client):
163163
test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
164-
test_client.evals.batch_eval(
164+
test_client.evals.batch_evaluate(
165165
dataset=vertexai_genai_types.EvaluationDataset(),
166166
metrics=[vertexai_genai_types.Metric(name="test")],
167-
output_config=vertexai_genai_types.OutputConfig(),
168-
autorater_config=vertexai_genai_types.AutoraterConfig(),
167+
dest="gs://bucket/output",
169168
config=vertexai_genai_types.EvaluateDatasetConfig(),
170169
)
171170
mock_evaluate.assert_called_once()
@@ -376,13 +375,14 @@ def test_inference_with_gcs_destination(
376375
mock_generate_content_response
377376
)
378377

379-
gcs_dest_path = "gs://bucket/output.jsonl"
380-
config = vertexai_genai_types.EvalRunInferenceConfig(dest=gcs_dest_path)
378+
gcs_dest_dir = "gs://bucket/output"
379+
config = vertexai_genai_types.EvalRunInferenceConfig(dest=gcs_dest_dir)
381380

382381
inference_result = self.client.evals.run_inference(
383382
model="gemini-pro", src=mock_df, config=config
384383
)
385384

385+
expected_gcs_path = os.path.join(gcs_dest_dir, "inference_results.jsonl")
386386
expected_df_to_save = pd.DataFrame(
387387
{
388388
"prompt": ["test prompt"],
@@ -393,15 +393,15 @@ def test_inference_with_gcs_destination(
393393
pd.testing.assert_frame_equal(saved_df, expected_df_to_save)
394394
mock_gcs_utils.return_value.upload_dataframe.assert_called_once_with(
395395
df=mock.ANY,
396-
gcs_destination_blob_path=gcs_dest_path,
396+
gcs_destination_blob_path=expected_gcs_path,
397397
file_type="jsonl",
398398
)
399399
pd.testing.assert_frame_equal(
400400
inference_result.eval_dataset_df, expected_df_to_save
401401
)
402402
assert inference_result.candidate_name == "gemini-pro"
403403
assert inference_result.gcs_source == vertexai_genai_types.GcsSource(
404-
uris=[gcs_dest_path]
404+
uris=[expected_gcs_path]
405405
)
406406

407407
@mock.patch.object(_evals_common, "Models")
@@ -434,16 +434,17 @@ def test_inference_with_local_destination(
434434
mock_generate_content_response
435435
)
436436

437-
local_dest_path = "/tmp/test/output_dir/results.jsonl"
438-
config = vertexai_genai_types.EvalRunInferenceConfig(dest=local_dest_path)
437+
local_dest_dir = "/tmp/test/output_dir"
438+
config = vertexai_genai_types.EvalRunInferenceConfig(dest=local_dest_dir)
439439

440440
inference_result = self.client.evals.run_inference(
441441
model="gemini-pro", src=mock_df, config=config
442442
)
443443

444-
mock_makedirs.assert_called_once_with("/tmp/test/output_dir", exist_ok=True)
444+
mock_makedirs.assert_called_once_with(local_dest_dir, exist_ok=True)
445+
expected_save_path = os.path.join(local_dest_dir, "inference_results.jsonl")
445446
mock_df_to_json.assert_called_once_with(
446-
local_dest_path, orient="records", lines=True
447+
expected_save_path, orient="records", lines=True
447448
)
448449
expected_df = pd.DataFrame(
449450
{
@@ -457,7 +458,7 @@ def test_inference_with_local_destination(
457458

458459
@mock.patch.object(_evals_common, "Models")
459460
@mock.patch.object(_evals_utils, "EvalDatasetLoader")
460-
def test_inference_from_request_column_save_locally(
461+
def test_inference_from_request_column_save_to_local_dir(
461462
self, mock_eval_dataset_loader, mock_models
462463
):
463464
mock_df = pd.DataFrame(
@@ -494,8 +495,8 @@ def test_inference_from_request_column_save_locally(
494495
mock_generate_content_responses
495496
)
496497

497-
local_dest_path = "/tmp/output.jsonl"
498-
config = vertexai_genai_types.EvalRunInferenceConfig(dest=local_dest_path)
498+
local_dest_dir = "/tmp/test_output_dir"
499+
config = vertexai_genai_types.EvalRunInferenceConfig(dest=local_dest_dir)
499500

500501
inference_result = self.client.evals.run_inference(
501502
model="gemini-pro", src=mock_df, config=config
@@ -530,13 +531,15 @@ def test_inference_from_request_column_save_locally(
530531
expected_df.sort_values(by="request").reset_index(drop=True),
531532
)
532533

533-
with open(local_dest_path, "r") as f:
534+
saved_file_path = os.path.join(local_dest_dir, "inference_results.jsonl")
535+
with open(saved_file_path, "r") as f:
534536
saved_records = [json.loads(line) for line in f]
535537
expected_records = expected_df.to_dict(orient="records")
536538
assert sorted(saved_records, key=lambda x: x["request"]) == sorted(
537539
expected_records, key=lambda x: x["request"]
538540
)
539-
os.remove(local_dest_path)
541+
os.remove(saved_file_path)
542+
os.rmdir(local_dest_dir)
540543
assert inference_result.candidate_name == "gemini-pro"
541544
assert inference_result.gcs_source is None
542545

vertexai/_genai/_evals_common.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -617,19 +617,13 @@ def _execute_inference(
617617

618618
if dest:
619619
file_name = "inference_results.jsonl"
620-
full_dest_path = dest
621620
is_gcs_path = dest.startswith(_evals_utils.GCS_PREFIX)
622621

623622
if is_gcs_path:
624-
if not dest.endswith("/"):
625-
pass
626-
else:
627-
full_dest_path = os.path.join(dest, file_name)
623+
full_dest_path = os.path.join(dest, file_name)
628624
else:
629-
if os.path.isdir(dest):
630-
full_dest_path = os.path.join(dest, file_name)
631-
632-
os.makedirs(os.path.dirname(full_dest_path), exist_ok=True)
625+
os.makedirs(dest, exist_ok=True)
626+
full_dest_path = os.path.join(dest, file_name)
633627

634628
logger.info("Saving inference results to: %s", full_dest_path)
635629
try:

0 commit comments

Comments (0)