|
30 | 30 | from google.cloud.aiplatform import initializer
|
31 | 31 | from google.cloud.aiplatform.compat.services import (
|
32 | 32 | index_service_client,
|
| 33 | + index_service_client_v1beta1, |
33 | 34 | )
|
34 | 35 |
|
35 | 36 | from google.cloud.aiplatform.matching_engine import (
|
|
40 | 41 | index as gca_index,
|
41 | 42 | encryption_spec as gca_encryption_spec,
|
42 | 43 | index_service as gca_index_service,
|
| 44 | + index_service_v1beta1 as gca_index_service_v1beta1, |
43 | 45 | )
|
44 | 46 | import constants as test_constants
|
45 | 47 |
|
|
66 | 68 | _TEST_CONTENTS_DELTA_URI_UPDATE = "gs://contents_update"
|
67 | 69 | _TEST_IS_COMPLETE_OVERWRITE_UPDATE = True
|
68 | 70 |
|
| 71 | +_TEST_BQ_SOURCE_PATH = "bq://my-project.my-dataset.my-table" |
| 72 | +_TEST_ID_COLUMN = "id" |
| 73 | +_TEST_EMBEDDING_COLUMN = "embedding" |
| 74 | + |
| 75 | + |
69 | 76 | _TEST_INDEX_CONFIG_DIMENSIONS = 100
|
70 | 77 | _TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT = 150
|
71 | 78 | _TEST_LEAF_NODE_EMBEDDING_COUNT = 123
|
@@ -208,6 +215,19 @@ def update_index_embeddings_mock():
|
208 | 215 | yield update_index_mock
|
209 | 216 |
|
210 | 217 |
|
| 218 | +@pytest.fixture |
| 219 | +def import_index_mock(): |
| 220 | + with patch.object( |
| 221 | + index_service_client_v1beta1.IndexServiceClient, "import_index" |
| 222 | + ) as import_index_mock: |
| 223 | + import_index_lro_mock = mock.Mock(operation.Operation) |
| 224 | + import_index_lro_mock.result.return_value = gca_index.Index( |
| 225 | + name=_TEST_INDEX_NAME, |
| 226 | + ) |
| 227 | + import_index_mock.return_value = import_index_lro_mock |
| 228 | + yield import_index_mock |
| 229 | + |
| 230 | + |
211 | 231 | @pytest.fixture
|
212 | 232 | def list_indexes_mock():
|
213 | 233 | with patch.object(
|
@@ -337,6 +357,42 @@ def test_update_index_embeddings(self, update_index_embeddings_mock):
|
337 | 357 | # The service only returns the name of the Index
|
338 | 358 | assert updated_index.gca_resource == gca_index.Index(name=_TEST_INDEX_NAME)
|
339 | 359 |
|
| 360 | + @pytest.mark.usefixtures("get_index_mock") |
| 361 | + @pytest.mark.parametrize("is_complete_overwrite", [True, False, None]) |
| 362 | + def test_import_embeddings(self, import_index_mock, is_complete_overwrite): |
| 363 | + aiplatform.init(project=_TEST_PROJECT) |
| 364 | + |
| 365 | + my_index = aiplatform.MatchingEngineIndex(index_name=_TEST_INDEX_ID) |
| 366 | + |
| 367 | + config = gca_index_service_v1beta1.ImportIndexRequest.ConnectorConfig( |
| 368 | + big_query_source_config=gca_index_service_v1beta1.ImportIndexRequest.ConnectorConfig.BigQuerySourceConfig( |
| 369 | + table_path=_TEST_BQ_SOURCE_PATH, |
| 370 | + datapoint_field_mapping=gca_index_service_v1beta1.ImportIndexRequest.ConnectorConfig.DatapointFieldMapping( |
| 371 | + id_column=_TEST_ID_COLUMN, |
| 372 | + embedding_column=_TEST_EMBEDDING_COLUMN, |
| 373 | + ), |
| 374 | + ) |
| 375 | + ) |
| 376 | + |
| 377 | + updated_index = my_index.import_embeddings( |
| 378 | + config=config, |
| 379 | + is_complete_overwrite=is_complete_overwrite, |
| 380 | + import_request_timeout=_TEST_TIMEOUT, |
| 381 | + ) |
| 382 | + |
| 383 | + expected_request = gca_index_service_v1beta1.ImportIndexRequest( |
| 384 | + name=_TEST_INDEX_NAME, |
| 385 | + config=config, |
| 386 | + is_complete_overwrite=is_complete_overwrite, |
| 387 | + ) |
| 388 | + |
| 389 | + import_index_mock.assert_called_once_with( |
| 390 | + request=expected_request, |
| 391 | + metadata=_TEST_REQUEST_METADATA, |
| 392 | + timeout=_TEST_TIMEOUT, |
| 393 | + ) |
| 394 | + assert updated_index.gca_resource == gca_index.Index(name=_TEST_INDEX_NAME) |
| 395 | + |
340 | 396 | def test_list_indexes(self, list_indexes_mock):
|
341 | 397 | aiplatform.init(project=_TEST_PROJECT)
|
342 | 398 |
|
|
0 commit comments