Skip to content

Commit 6339fcb

Browse files
authored
Merge pull request #1425 from weaviate/wes_suport
Weaviate embedding service support
2 parents 7f36591 + 75ffad5 commit 6339fcb

File tree

7 files changed

+118
-56
lines changed

7 files changed

+118
-56
lines changed

test/collection/test_config.py

Lines changed: 39 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -399,6 +399,22 @@ def test_basic_config():
399399
}
400400
},
401401
),
402+
(
403+
Configure.Vectorizer.text2vec_weaviate(
404+
vectorize_collection_name=False,
405+
model="Snowflake/snowflake-arctic-embed-m-v1.5",
406+
base_url="https://api.embedding.weaviate.io",
407+
dimensions=768,
408+
),
409+
{
410+
"text2vec-weaviate": {
411+
"vectorizeClassName": False,
412+
"model": "Snowflake/snowflake-arctic-embed-m-v1.5",
413+
"baseURL": "https://api.embedding.weaviate.io",
414+
"dimensions": 768,
415+
}
416+
},
417+
),
402418
(
403419
Configure.Vectorizer.img2vec_neural(
404420
image_fields=["test"],
@@ -1495,6 +1511,29 @@ def test_vector_config_flat_pq() -> None:
14951511
}
14961512
},
14971513
),
1514+
(
1515+
[
1516+
Configure.NamedVectors.text2vec_weaviate(
1517+
name="test",
1518+
source_properties=["prop"],
1519+
base_url="https://api.embedding.weaviate.io",
1520+
dimensions=768,
1521+
)
1522+
],
1523+
{
1524+
"test": {
1525+
"vectorizer": {
1526+
"text2vec-weaviate": {
1527+
"properties": ["prop"],
1528+
"vectorizeClassName": True,
1529+
"baseURL": "https://api.embedding.weaviate.io",
1530+
"dimensions": 768,
1531+
}
1532+
},
1533+
"vectorIndexType": "hnsw",
1534+
}
1535+
},
1536+
),
14981537
(
14991538
[
15001539
Configure.NamedVectors.img2vec_neural(

weaviate/collections/batch/grpc_batch_delete.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,6 @@
22

33
from grpc.aio import AioRpcError # type: ignore
44

5-
65
from weaviate.collections.classes.batch import (
76
DeleteManyObject,
87
DeleteManyReturn,
@@ -26,7 +25,6 @@ def __init__(self, connection: ConnectionV4, consistency_level: Optional[Consist
2625
async def batch_delete(
2726
self, name: str, filters: _Filters, verbose: bool, dry_run: bool, tenant: Optional[str]
2827
) -> Union[DeleteManyReturn[List[DeleteManyObject]], DeleteManyReturn[None]]:
29-
metadata = self._get_metadata()
3028
try:
3129
assert self._connection.grpc_stub is not None
3230
res = await self._connection.grpc_stub.BatchDelete(
@@ -38,7 +36,7 @@ async def batch_delete(
3836
tenant=tenant,
3937
filters=_FilterToGRPC.convert(filters),
4038
),
41-
metadata=metadata,
39+
metadata=self._connection.grpc_headers(),
4240
timeout=self._connection.timeout_config.insert,
4341
)
4442
res = cast(batch_delete_pb2.BatchDeleteReply, res)

weaviate/collections/batch/grpc_batch_objects.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -4,17 +4,17 @@
44
import uuid as uuid_package
55
from typing import Any, Dict, List, Optional, Union, cast
66

7-
from grpc.aio import AioRpcError # type: ignore
87
from google.protobuf.struct_pb2 import Struct
8+
from grpc.aio import AioRpcError # type: ignore
99

1010
from weaviate.collections.classes.batch import (
1111
ErrorObject,
1212
_BatchObject,
1313
BatchObjectReturn,
1414
)
1515
from weaviate.collections.classes.config import ConsistencyLevel
16-
from weaviate.collections.classes.types import GeoCoordinate, PhoneNumber
1716
from weaviate.collections.classes.internal import ReferenceToMulti, ReferenceInputs
17+
from weaviate.collections.classes.types import GeoCoordinate, PhoneNumber
1818
from weaviate.collections.grpc.shared import _BaseGRPC
1919
from weaviate.connect import ConnectionV4
2020
from weaviate.exceptions import (
@@ -135,15 +135,14 @@ async def objects(
135135
async def __send_batch(
136136
self, batch: List[batch_pb2.BatchObject], timeout: Union[int, float]
137137
) -> Dict[int, str]:
138-
metadata = self._get_metadata()
139138
try:
140139
assert self._connection.grpc_stub is not None
141140
res = await self._connection.grpc_stub.BatchObjects(
142141
batch_pb2.BatchObjectsRequest(
143142
objects=batch,
144143
consistency_level=self._consistency_level,
145144
),
146-
metadata=metadata,
145+
metadata=self._connection.grpc_headers(),
147146
timeout=timeout,
148147
)
149148
res = cast(batch_pb2.BatchObjectsReply, res)

weaviate/collections/classes/config_named_vectors.py

Lines changed: 25 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -54,6 +54,8 @@
5454
_Text2VecVoyageConfig,
5555
_Multi2VecCohereConfig,
5656
_Multi2VecJinaConfig,
57+
_Text2VecWeaviateConfig,
58+
WeaviateModel,
5759
)
5860
from ...warnings import _Warnings
5961

@@ -1221,6 +1223,29 @@ def text2vec_voyageai(
12211223
vector_index_config=vector_index_config,
12221224
)
12231225

1226+
@staticmethod
1227+
def text2vec_weaviate(
1228+
name: str,
1229+
*,
1230+
source_properties: Optional[List[str]] = None,
1231+
vector_index_config: Optional[_VectorIndexConfigCreate] = None,
1232+
vectorize_collection_name: bool = True,
1233+
model: Optional[Union[WeaviateModel, str]] = None,
1234+
base_url: Optional[str] = None,
1235+
dimensions: Optional[int] = None,
1236+
) -> _NamedVectorConfigCreate:
1237+
return _NamedVectorConfigCreate(
1238+
name=name,
1239+
source_properties=source_properties,
1240+
vectorizer=_Text2VecWeaviateConfig(
1241+
model=model,
1242+
vectorizeClassName=vectorize_collection_name,
1243+
baseURL=base_url,
1244+
dimensions=dimensions,
1245+
),
1246+
vector_index_config=vector_index_config,
1247+
)
1248+
12241249

12251250
class _NamedVectorsUpdate:
12261251
@staticmethod

weaviate/collections/classes/config_vectorizers.py

Lines changed: 30 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -64,6 +64,7 @@
6464
"bedrock",
6565
"sagemaker",
6666
]
67+
WeaviateModel: TypeAlias = Literal["Snowflake/snowflake-arctic-embed-m-v1.5"]
6768

6869

6970
class Vectorizers(str, Enum):
@@ -95,6 +96,8 @@ class Vectorizers(str, Enum):
9596
Weaviate module backed by Jina AI text-based embedding models.
9697
`TEXT2VEC_VOYAGEAI`
9798
Weaviate module backed by Voyage AI text-based embedding models.
99+
`TEXT2VEC_WEAVIATE`
100+
Weaviate module backed by Weaviate's self-hosted text-based embedding models.
98101
`IMG2VEC_NEURAL`
99102
Weaviate module backed by a ResNet-50 neural network for images.
100103
`MULTI2VEC_CLIP`
@@ -121,6 +124,7 @@ class Vectorizers(str, Enum):
121124
TEXT2VEC_TRANSFORMERS = "text2vec-transformers"
122125
TEXT2VEC_JINAAI = "text2vec-jinaai"
123126
TEXT2VEC_VOYAGEAI = "text2vec-voyageai"
127+
TEXT2VEC_WEAVIATE = "text2vec-weaviate"
124128
IMG2VEC_NEURAL = "img2vec-neural"
125129
MULTI2VEC_CLIP = "multi2vec-clip"
126130
MULTI2VEC_COHERE = "multi2vec-cohere"
@@ -343,6 +347,16 @@ class _Text2VecVoyageConfig(_VectorizerConfigCreate):
343347
vectorizeClassName: bool
344348

345349

350+
class _Text2VecWeaviateConfig(_VectorizerConfigCreate):
351+
vectorizer: Union[Vectorizers, _EnumLikeStr] = Field(
352+
default=Vectorizers.TEXT2VEC_WEAVIATE, frozen=True, exclude=True
353+
)
354+
model: Optional[str]
355+
baseURL: Optional[str]
356+
vectorizeClassName: bool
357+
dimensions: Optional[int]
358+
359+
346360
class _Text2VecOllamaConfig(_VectorizerConfigCreate):
347361
vectorizer: Union[Vectorizers, _EnumLikeStr] = Field(
348362
default=Vectorizers.TEXT2VEC_OLLAMA, frozen=True, exclude=True
@@ -1290,3 +1304,19 @@ def text2vec_voyageai(
12901304
truncate=truncate,
12911305
vectorizeClassName=vectorize_collection_name,
12921306
)
1307+
1308+
@staticmethod
1309+
def text2vec_weaviate(
1310+
*,
1311+
model: Optional[Union[WeaviateModel, str]] = None,
1312+
base_url: Optional[str] = None,
1313+
vectorize_collection_name: bool = True,
1314+
dimensions: Optional[int] = None,
1315+
) -> _VectorizerConfigCreate:
1316+
"""TODO: add docstrings when the documentation is available."""
1317+
return _Text2VecWeaviateConfig(
1318+
model=model,
1319+
baseURL=base_url,
1320+
vectorizeClassName=vectorize_collection_name,
1321+
dimensions=dimensions,
1322+
)

weaviate/collections/grpc/shared.py

Lines changed: 1 addition & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, Tuple, List
1+
from typing import Optional
22

33
from weaviate.collections.classes.config import ConsistencyLevel
44
from weaviate.connect import ConnectionV4
@@ -14,24 +14,6 @@ def __init__(
1414
self._connection = connection
1515
self._consistency_level = self._get_consistency_level(consistency_level)
1616

17-
def _get_metadata(self) -> Optional[Tuple[Tuple[str, str], ...]]:
18-
metadata: Optional[Tuple[Tuple[str, str], ...]] = None
19-
access_token = self._connection.get_current_bearer_token()
20-
21-
metadata_list: List[Tuple[str, str]] = []
22-
if len(access_token) > 0:
23-
metadata_list.append(("authorization", access_token))
24-
25-
if len(self._connection.additional_headers):
26-
for key, val in self._connection.additional_headers.items():
27-
if val is not None:
28-
metadata_list.append((key.lower(), val))
29-
30-
if len(metadata_list) > 0:
31-
metadata = tuple(metadata_list)
32-
33-
return metadata
34-
3517
@staticmethod
3618
def _get_consistency_level(
3719
consistency_level: Optional[ConsistencyLevel],

weaviate/connect/v4.py

Lines changed: 19 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -121,6 +121,7 @@ def __init__(
121121
self.__loop = loop
122122

123123
self._headers = {"content-type": "application/json"}
124+
self.__add_weaviate_embedding_service_header(connection_params.http.host)
124125
if additional_headers is not None:
125126
_validate_input(_ValidateArgument([dict], "additional_headers", additional_headers))
126127
self.__additional_headers = additional_headers
@@ -141,6 +142,12 @@ def __init__(
141142

142143
self._prepare_grpc_headers()
143144

145+
def __add_weaviate_embedding_service_header(self, wcd_host: str) -> None:
146+
if not is_weaviate_domain(wcd_host) or not isinstance(self._auth, AuthApiKey):
147+
return
148+
self._headers["X-Weaviate-Api-Key"] = self._auth.api_key
149+
self._headers["X-Weaviate-Cluster-URL"] = "https://" + wcd_host
150+
144151
async def connect(self, skip_init_checks: bool) -> None:
145152
self.__connected = True
146153

@@ -655,7 +662,17 @@ def _prepare_grpc_headers(self) -> None:
655662

656663
if self._auth is not None:
657664
if isinstance(self._auth, AuthApiKey):
658-
self.__metadata_list.append(("authorization", self._auth.api_key))
665+
if (
666+
"X-Weaviate-Cluster-URL" in self._headers
667+
and "X-Weaviate-Api-Key" in self._headers
668+
):
669+
self.__metadata_list.append(
670+
("x-weaviate-cluster-url", self._headers["X-Weaviate-Cluster-URL"])
671+
)
672+
self.__metadata_list.append(
673+
("x-weaviate-api-key", self._headers["X-Weaviate-Api-Key"])
674+
)
675+
self.__metadata_list.append(("authorization", "Bearer " + self._auth.api_key))
659676
else:
660677
self.__metadata_list.append(
661678
("authorization", "dummy_will_be_refreshed_for_each_call")
@@ -667,7 +684,7 @@ def _prepare_grpc_headers(self) -> None:
667684
self.__grpc_headers = None
668685

669686
def grpc_headers(self) -> Optional[Tuple[Tuple[str, str], ...]]:
670-
if self._auth is None or not isinstance(self._auth, AuthApiKey):
687+
if self._auth is None or isinstance(self._auth, AuthApiKey):
671688
return self.__grpc_headers
672689

673690
assert self.__grpc_headers is not None
@@ -676,34 +693,6 @@ def grpc_headers(self) -> Optional[Tuple[Tuple[str, str], ...]]:
676693
self.__metadata_list[len(self.__metadata_list) - 1] = ("authorization", access_token)
677694
return tuple(self.__metadata_list)
678695

679-
# async def _ping_grpc(self) -> None:
680-
# """Performs a grpc health check and raises WeaviateGRPCUnavailableError if not."""
681-
# if not self.is_connected():
682-
# raise WeaviateClosedClientError()
683-
# assert self._grpc_channel is not None
684-
# try:
685-
# request = self._grpc_channel.request(
686-
# "/grpc.health.v1.Health/Check",
687-
# Cardinality.UNARY_UNARY,
688-
# health_pb2.HealthCheckRequest,
689-
# health_pb2.HealthCheckResponse,
690-
# timeout=self.timeout_config.init,
691-
# )
692-
# async with request as stream:
693-
# await stream.send_message(health_pb2.HealthCheckRequest())
694-
# res = await stream.recv_message()
695-
# await stream.end()
696-
# if res is None or res.status != health_pb2.HealthCheckResponse.SERVING:
697-
# self.__connected = False
698-
# raise WeaviateGRPCUnavailableError(
699-
# f"v{self.server_version}", self._connection_params._grpc_address
700-
# )
701-
# except Exception as e:
702-
# self.__connected = False
703-
# raise WeaviateGRPCUnavailableError(
704-
# f"v{self.server_version}", self._connection_params._grpc_address
705-
# ) from e
706-
707696
async def _ping_grpc(self) -> None:
708697
"""Performs a grpc health check and raises WeaviateGRPCUnavailableError if not."""
709698
if not self.is_connected():

0 commit comments

Comments
 (0)