Skip to content

Commit d570fc9

Browse files
Frances Hubis Thoma and copybara-github
authored and committed
feat: Add validation assessment for batch prediction.
PiperOrigin-RevId: 775696346
1 parent 700dd72 commit d570fc9

File tree

2 files changed

+76
-0
lines changed

2 files changed

+76
-0
lines changed

google/cloud/aiplatform/preview/datasets.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1556,6 +1556,42 @@ def assess_batch_prediction_resources(
15561556
audio_token_count=assessment_result.audio_token_count,
15571557
)
15581558

1559+
def assess_batch_prediction_validity(
    self,
    *,
    model_name: str,
    template_config: Optional[GeminiTemplateConfig] = None,
    assess_request_timeout: Optional[float] = None,
) -> None:
    """Assess whether the assembled dataset is valid for batch prediction.

    The assessment is performed for the given model. Raises an error if
    the dataset is invalid, otherwise returns None.

    Args:
        model_name (str):
            Required. The name of the model to assess the batch prediction
            validity for.
        template_config (GeminiTemplateConfig):
            Optional. The template config used to assemble the dataset
            before assessing the batch prediction validity. If not
            provided, the template config attached to the dataset will be
            used. Required if no template config is attached to the
            dataset.
        assess_request_timeout (float):
            Optional. The timeout for the assess batch prediction validity
            request.
    """
    request = self._build_assess_data_request(template_config)
    validation_config = gca_dataset_service.AssessDataRequest.BatchPredictionValidationAssessmentConfig(
        model_name=model_name,
    )
    request.batch_prediction_validation_assessment_config = validation_config
    assess_lro = self.api_client.assess_data(
        request=request, timeout=assess_request_timeout
    )
    # Wait for the long-running operation without a client-side timeout.
    # The response payload is intentionally unused: this method only
    # returns None. NOTE(review): an invalid dataset presumably surfaces
    # as an error raised from this call -- confirm against the service.
    assess_lro.result(timeout=None)
1594+
15591595
def _build_assess_data_request(
15601596
self,
15611597
template_config: Optional[GeminiTemplateConfig] = None,

tests/unit/aiplatform/test_multimodal_datasets.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,19 @@ def assess_data_batch_prediction_resources_mock():
297297
yield assess_data_mock
298298

299299

300+
@pytest.fixture
def assess_data_batch_prediction_validation_mock():
    """Patch ``DatasetServiceClient.assess_data`` for validity assessments.

    Yields the patched ``assess_data`` mock. It returns a mocked operation
    whose ``result()`` is an ``AssessDataResponse`` carrying an empty
    ``BatchPredictionValidationAssessmentResult``.
    """
    empty_validation_result = (
        gca_dataset_service.AssessDataResponse.BatchPredictionValidationAssessmentResult()
    )
    canned_response = gca_dataset_service.AssessDataResponse(
        batch_prediction_validation_assessment_result=empty_validation_result
    )
    lro_stub = mock.Mock(operation.Operation)
    lro_stub.result.return_value = canned_response

    with mock.patch.object(
        dataset_service.DatasetServiceClient, "assess_data"
    ) as patched_assess_data:
        patched_assess_data.return_value = lro_stub
        yield patched_assess_data
311+
312+
300313
@pytest.fixture
301314
def assemble_data_mock():
302315
with mock.patch.object(
@@ -810,6 +823,33 @@ def test_assess_batch_prediction_resources_request_column_name(
810823
timeout=None,
811824
)
812825

826+
@pytest.mark.usefixtures("get_dataset_mock")
def test_assess_batch_prediction_validity(
    self, assess_data_batch_prediction_validation_mock
):
    """Check the request sent by ``assess_batch_prediction_validity``.

    The method must issue exactly one ``assess_data`` call carrying the
    model name and the supplied template config, and must return None.
    """
    model = "gemini-1.5-flash-exp"

    aiplatform.init(project=_TEST_PROJECT)
    multimodal_dataset = ummd.MultimodalDataset(dataset_name=_TEST_NAME)
    gemini_template = ummd.GeminiTemplateConfig(
        field_mapping={"question": "questionColumn"},
    )

    outcome = multimodal_dataset.assess_batch_prediction_validity(
        model_name=model,
        template_config=gemini_template,
    )

    # Build the request the client is expected to have received.
    expected_validation_config = gca_dataset_service.AssessDataRequest.BatchPredictionValidationAssessmentConfig(
        model_name=model,
    )
    expected_read_config = gca_dataset_service.GeminiRequestReadConfig(
        template_config=gemini_template._raw_gemini_template_config
    )
    expected_request = gca_dataset_service.AssessDataRequest(
        name=_TEST_NAME,
        batch_prediction_validation_assessment_config=expected_validation_config,
        gemini_request_read_config=expected_read_config,
    )

    assess_data_batch_prediction_validation_mock.assert_called_once_with(
        request=expected_request,
        timeout=None,
    )
    assert outcome is None
852+
813853
@pytest.mark.usefixtures("get_dataset_request_column_name_mock")
814854
def test_assess_tuning_validity_request_column_name(
815855
self, assess_data_tuning_validation_mock

0 commit comments

Comments
 (0)