Skip to content

Commit 5a0df36

Browse files
vertex-sdk-bot authored and
copybara-github committed
feat: Add import_embeddings method in MatchingEngineIndex resource
PiperOrigin-RevId: 775953546
1 parent 4d44f94 commit 5a0df36

File tree

2 files changed

+110
-0
lines changed

2 files changed

+110
-0
lines changed

google/cloud/aiplatform/matching_engine/matching_engine_index.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@
2020
from google.auth import credentials as auth_credentials
2121
from google.protobuf import field_mask_pb2
2222
from google.cloud.aiplatform import base
23+
from google.cloud.aiplatform import compat
2324
from google.cloud.aiplatform.compat.types import (
2425
index_service as gca_index_service,
26+
index_service_v1beta1 as gca_index_service_v1beta1,
2527
matching_engine_deployed_index_ref as gca_matching_engine_deployed_index_ref,
2628
matching_engine_index as gca_matching_engine_index,
2729
encryption_spec as gca_encryption_spec,
@@ -393,6 +395,58 @@ def update_embeddings(
393395

394396
return self
395397

398+
def import_embeddings(
    self,
    config: gca_index_service_v1beta1.ImportIndexRequest.ConnectorConfig,
    is_complete_overwrite: Optional[bool] = None,
    request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
    import_request_timeout: Optional[float] = None,
) -> "MatchingEngineIndex":
    """Imports embeddings into this index from an external source, e.g., BigQuery.

    The request is issued through the v1beta1 version of the API client, and
    this method waits (without a timeout) for the result of the returned
    long-running operation before returning.

    Args:
        config (aiplatform.compat.types.index_service_v1beta1.ImportIndexRequest.ConnectorConfig):
            Required. The configuration for importing data from an external source.
        is_complete_overwrite (bool):
            Optional. If true, completely replace existing index data. Must be
            true for streaming update indexes.
        request_metadata (Sequence[Tuple[str, str]]):
            Optional. Strings which should be sent along with the request as metadata.
        import_request_timeout (float):
            Optional. The timeout for the request in seconds.

    Returns:
        MatchingEngineIndex - This index object, with its underlying resource
        replaced by the result of the import operation.
    """
    self.wait()

    _LOGGER.log_action_start_against_resource("Importing embeddings", "index", self)

    # ImportIndex is requested from the v1beta1 client rather than the default one.
    beta_client = self.api_client.select_version(compat.V1BETA1)
    import_request = gca_index_service_v1beta1.ImportIndexRequest(
        name=self.resource_name,
        config=config,
        is_complete_overwrite=is_complete_overwrite,
    )
    lro = beta_client.import_index(
        request=import_request,
        metadata=request_metadata,
        timeout=import_request_timeout,
    )

    _LOGGER.log_action_started_against_resource_with_lro(
        "Import", "index", self.__class__, lro
    )

    # The operation result becomes the new backing resource for this object.
    self._gca_resource = lro.result(timeout=None)

    _LOGGER.log_action_completed_against_resource("index", "Imported", self)

    return self
449+
396450
@property
397451
def deployed_indexes(
398452
self,

tests/unit/aiplatform/test_matching_engine_index.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from google.cloud.aiplatform import initializer
3131
from google.cloud.aiplatform.compat.services import (
3232
index_service_client,
33+
index_service_client_v1beta1,
3334
)
3435

3536
from google.cloud.aiplatform.matching_engine import (
@@ -40,6 +41,7 @@
4041
index as gca_index,
4142
encryption_spec as gca_encryption_spec,
4243
index_service as gca_index_service,
44+
index_service_v1beta1 as gca_index_service_v1beta1,
4345
)
4446
import constants as test_constants
4547

@@ -66,6 +68,11 @@
6668
_TEST_CONTENTS_DELTA_URI_UPDATE = "gs://contents_update"
6769
_TEST_IS_COMPLETE_OVERWRITE_UPDATE = True
6870

71+
_TEST_BQ_SOURCE_PATH = "bq://my-project.my-dataset.my-table"
72+
_TEST_ID_COLUMN = "id"
73+
_TEST_EMBEDDING_COLUMN = "embedding"
74+
75+
6976
_TEST_INDEX_CONFIG_DIMENSIONS = 100
7077
_TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT = 150
7178
_TEST_LEAF_NODE_EMBEDDING_COUNT = 123
@@ -208,6 +215,19 @@ def update_index_embeddings_mock():
208215
yield update_index_mock
209216

210217

218+
@pytest.fixture
def import_index_mock():
    """Patches the v1beta1 ``IndexServiceClient.import_index`` and yields the mock.

    The patched method returns a mocked long-running operation whose
    ``result()`` is an ``Index`` that carries only ``_TEST_INDEX_NAME``.
    """
    with patch.object(
        index_service_client_v1beta1.IndexServiceClient, "import_index"
    ) as patched_import_index:
        lro_mock = mock.Mock(operation.Operation)
        lro_mock.result.return_value = gca_index.Index(name=_TEST_INDEX_NAME)
        patched_import_index.return_value = lro_mock
        yield patched_import_index
229+
230+
211231
@pytest.fixture
212232
def list_indexes_mock():
213233
with patch.object(
@@ -337,6 +357,42 @@ def test_update_index_embeddings(self, update_index_embeddings_mock):
337357
# The service only returns the name of the Index
338358
assert updated_index.gca_resource == gca_index.Index(name=_TEST_INDEX_NAME)
339359

360+
@pytest.mark.usefixtures("get_index_mock")
@pytest.mark.parametrize("is_complete_overwrite", [True, False, None])
def test_import_embeddings(self, import_index_mock, is_complete_overwrite):
    """Checks that import_embeddings sends the expected ImportIndexRequest.

    ``request_metadata`` is deliberately not passed, so the method's default
    value is what gets compared against ``_TEST_REQUEST_METADATA``.
    """
    aiplatform.init(project=_TEST_PROJECT)

    index = aiplatform.MatchingEngineIndex(index_name=_TEST_INDEX_ID)

    # Local alias to keep the nested proto message names readable.
    connector_config_cls = gca_index_service_v1beta1.ImportIndexRequest.ConnectorConfig
    field_mapping = connector_config_cls.DatapointFieldMapping(
        id_column=_TEST_ID_COLUMN,
        embedding_column=_TEST_EMBEDDING_COLUMN,
    )
    bq_source_config = connector_config_cls.BigQuerySourceConfig(
        table_path=_TEST_BQ_SOURCE_PATH,
        datapoint_field_mapping=field_mapping,
    )
    config = connector_config_cls(big_query_source_config=bq_source_config)

    result_index = index.import_embeddings(
        config=config,
        is_complete_overwrite=is_complete_overwrite,
        import_request_timeout=_TEST_TIMEOUT,
    )

    import_index_mock.assert_called_once_with(
        request=gca_index_service_v1beta1.ImportIndexRequest(
            name=_TEST_INDEX_NAME,
            config=config,
            is_complete_overwrite=is_complete_overwrite,
        ),
        metadata=_TEST_REQUEST_METADATA,
        timeout=_TEST_TIMEOUT,
    )
    # The mocked operation result only carries the index name.
    assert result_index.gca_resource == gca_index.Index(name=_TEST_INDEX_NAME)
395+
340396
def test_list_indexes(self, list_indexes_mock):
341397
aiplatform.init(project=_TEST_PROJECT)
342398

0 commit comments

Comments
 (0)