@@ -297,6 +297,19 @@ def assess_data_batch_prediction_resources_mock():
297
297
yield assess_data_mock
298
298
299
299
300
+ @pytest .fixture
301
+ def assess_data_batch_prediction_validation_mock ():
302
+ with mock .patch .object (
303
+ dataset_service .DatasetServiceClient , "assess_data"
304
+ ) as assess_data_mock :
305
+ assess_data_lro_mock = mock .Mock (operation .Operation )
306
+ assess_data_lro_mock .result .return_value = gca_dataset_service .AssessDataResponse (
307
+ batch_prediction_validation_assessment_result = gca_dataset_service .AssessDataResponse .BatchPredictionValidationAssessmentResult ()
308
+ )
309
+ assess_data_mock .return_value = assess_data_lro_mock
310
+ yield assess_data_mock
311
+
312
+
300
313
@pytest .fixture
301
314
def assemble_data_mock ():
302
315
with mock .patch .object (
@@ -810,6 +823,33 @@ def test_assess_batch_prediction_resources_request_column_name(
810
823
timeout = None ,
811
824
)
812
825
826
+ @pytest .mark .usefixtures ("get_dataset_mock" )
827
+ def test_assess_batch_prediction_validity (
828
+ self , assess_data_batch_prediction_validation_mock
829
+ ):
830
+ aiplatform .init (project = _TEST_PROJECT )
831
+ dataset = ummd .MultimodalDataset (dataset_name = _TEST_NAME )
832
+ template_config = ummd .GeminiTemplateConfig (
833
+ field_mapping = {"question" : "questionColumn" },
834
+ )
835
+ result = dataset .assess_batch_prediction_validity (
836
+ model_name = "gemini-1.5-flash-exp" ,
837
+ template_config = template_config ,
838
+ )
839
+ assess_data_batch_prediction_validation_mock .assert_called_once_with (
840
+ request = gca_dataset_service .AssessDataRequest (
841
+ name = _TEST_NAME ,
842
+ batch_prediction_validation_assessment_config = gca_dataset_service .AssessDataRequest .BatchPredictionValidationAssessmentConfig (
843
+ model_name = "gemini-1.5-flash-exp" ,
844
+ ),
845
+ gemini_request_read_config = gca_dataset_service .GeminiRequestReadConfig (
846
+ template_config = template_config ._raw_gemini_template_config
847
+ ),
848
+ ),
849
+ timeout = None ,
850
+ )
851
+ assert result is None
852
+
813
853
@pytest .mark .usefixtures ("get_dataset_request_column_name_mock" )
814
854
def test_assess_tuning_validity_request_column_name (
815
855
self , assess_data_tuning_validation_mock
0 commit comments