1- from chromadb .api import ClientAPI
1+ from chromadb .api import ClientAPI , ServerAPI
22from chromadb .api .types import (
33 Schema ,
44 SparseVectorIndexConfig ,
3030 register_sparse_embedding_function ,
3131)
3232from chromadb .api .models .Collection import Collection
33+ from chromadb .api .models .CollectionCommon import CollectionCommon
3334from chromadb .errors import InvalidArgumentError , InternalError
3435from chromadb .execution .expression import Knn , Search
36+ from chromadb .types import Collection as CollectionModel
3537from typing import Any , Callable , Dict , List , Mapping , Optional , Tuple , cast
3638from uuid import uuid4
3739import numpy as np
@@ -606,6 +608,36 @@ def test_search_embeds_string_knn_queries(
606608 assert embedded_rank .query == [11.0 , 12.5 ]
607609
608610
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
def test_search_embeds_string_knn_queries_with_sparse_embedding_function(
    client_factories: "ClientFactories",
) -> None:
    """String KNN queries on a sparse-indexed key are embedded with that index's EF.

    A schema is built with a sparse vector index on ``sparse_metadata`` whose
    config carries a sparse embedding function. A ``Search`` ranked by a
    string ``Knn`` query against that key is passed through
    ``_embed_search_string_queries``. The test checks that the original
    ``Search`` still holds the raw string, and that the returned ``Search``
    holds the expected ``SparseVector`` in its place.
    """

    sparse_ef = DeterministicSparseEmbeddingFunction(label="sparse")
    schema = Schema().create_index(
        key="sparse_metadata",
        config=SparseVectorIndexConfig(
            source_key="raw_text", embedding_function=sparse_ef
        ),
    )
    collection, _ = _create_isolated_collection(client_factories, schema=schema)

    search = Search().rank(Knn(key="sparse_metadata", query="hello world"))

    embedded_search = collection._embed_search_string_queries(search)

    # The input Search must not be mutated: it still carries the string query.
    assert isinstance(search._rank, Knn)
    assert search._rank.key == "sparse_metadata"
    assert search._rank.query == "hello world"

    # The returned Search carries the sparse embedding instead of the string.
    embedded_rank = embedded_search._rank
    assert isinstance(embedded_rank, Knn)
    assert embedded_rank.key == "sparse_metadata"
    assert embedded_rank.query == SparseVector(indices=[0], values=[11.0])
639+
640+
609641@pytest .mark .skipif (is_spann_disabled_mode , reason = skip_reason_spann_disabled )
610642def test_search_embeds_string_queries_in_nested_ranks (
611643 client_factories : "ClientFactories" ,
@@ -2048,7 +2080,9 @@ def test_sparse_vector_index_config_with_key_types(
20482080 )
20492081
20502082 # Verify sparse embeddings were generated from text_field
2051- result2 = collection2 .get (ids = ["sparse-key-1" , "sparse-key-2" ], include = ["metadatas" ])
2083+ result2 = collection2 .get (
2084+ ids = ["sparse-key-1" , "sparse-key-2" ], include = ["metadatas" ]
2085+ )
20522086 assert result2 ["metadatas" ] is not None
20532087 assert "sparse2" in result2 ["metadatas" ][0 ]
20542088 assert "sparse2" in result2 ["metadatas" ][1 ]
@@ -2072,7 +2106,9 @@ def test_schema_rejects_special_key_in_create_index() -> None:
20722106 schema .create_index (config = StringInvertedIndexConfig (), key = "#custom_field" )
20732107
20742108 # Test with Key object starting with #
2075- with pytest .raises (ValueError , match = "Cannot create index on special key '#embedding'" ):
2109+ with pytest .raises (
2110+ ValueError , match = "Cannot create index on special key '#embedding'"
2111+ ):
20762112 schema .create_index (config = StringInvertedIndexConfig (), key = Key .EMBEDDING )
20772113
20782114
@@ -2084,7 +2120,9 @@ def test_schema_rejects_special_key_in_delete_index() -> None:
20842120 schema .delete_index (config = StringInvertedIndexConfig (), key = "#custom_field" )
20852121
20862122 # Test with Key object starting with #
2087- with pytest .raises (ValueError , match = "Cannot delete index on special key '#document'" ):
2123+ with pytest .raises (
2124+ ValueError , match = "Cannot delete index on special key '#document'"
2125+ ):
20882126 schema .delete_index (config = StringInvertedIndexConfig (), key = Key .DOCUMENT )
20892127
20902128
@@ -2122,7 +2160,13 @@ def test_server_validates_schema_with_special_keys(
21222160 # This should be caught server-side by validate_schema()
21232161 schema = Schema ()
21242162 # Bypass client-side validation by directly manipulating schema.keys
2125- from chromadb .api .types import ValueTypes , StringValueType , StringInvertedIndexType , StringInvertedIndexConfig
2163+ from chromadb .api .types import (
2164+ ValueTypes ,
2165+ StringValueType ,
2166+ StringInvertedIndexType ,
2167+ StringInvertedIndexConfig ,
2168+ )
2169+
21262170 schema .keys ["#invalid_key" ] = ValueTypes (
21272171 string = StringValueType (
21282172 string_inverted_index = StringInvertedIndexType (
@@ -2138,7 +2182,9 @@ def test_server_validates_schema_with_special_keys(
21382182
21392183 # Verify server caught the invalid key
21402184 error_msg = str (exc_info .value )
2141- assert "#" in error_msg or "key" in error_msg .lower () or "invalid" in error_msg .lower ()
2185+ assert (
2186+ "#" in error_msg or "key" in error_msg .lower () or "invalid" in error_msg .lower ()
2187+ )
21422188
21432189
21442190@pytest .mark .skipif (is_spann_disabled_mode , reason = skip_reason_spann_disabled )
@@ -2153,14 +2199,18 @@ def test_server_validates_invalid_source_key_in_sparse_vector_config(
21532199
21542200 # Create schema with invalid source_key
21552201 # Bypass client-side validation by directly creating the config
2156- from chromadb .api .types import ValueTypes , SparseVectorValueType , SparseVectorIndexType
2202+ from chromadb .api .types import (
2203+ ValueTypes ,
2204+ SparseVectorValueType ,
2205+ SparseVectorIndexType ,
2206+ )
21572207
21582208 schema = Schema ()
21592209 # Manually construct config with invalid source_key using model_construct to bypass validation
21602210 invalid_config = SparseVectorIndexConfig .model_construct (
21612211 embedding_function = None ,
21622212 source_key = "#embedding" , # Invalid - should be rejected
2163- bm25 = None
2213+ bm25 = None ,
21642214 )
21652215
21662216 schema .keys ["test_sparse" ] = ValueTypes (
@@ -2178,11 +2228,17 @@ def test_server_validates_invalid_source_key_in_sparse_vector_config(
21782228
21792229 # Verify server caught the invalid source_key
21802230 error_msg = str (exc_info .value )
2181- assert "source_key" in error_msg .lower () or "#" in error_msg or "document" in error_msg .lower ()
2231+ assert (
2232+ "source_key" in error_msg .lower ()
2233+ or "#" in error_msg
2234+ or "document" in error_msg .lower ()
2235+ )
21822236
21832237
21842238@pytest .mark .skipif (is_spann_disabled_mode , reason = skip_reason_spann_disabled )
2185- def test_modify_collection_no_initial_config_creates_default_schema (client : ClientAPI ) -> None :
2239+ def test_modify_collection_no_initial_config_creates_default_schema (
2240+ client : ClientAPI ,
2241+ ) -> None :
21862242 """Test that modifying a collection without initial config/schema creates and updates default schema."""
21872243 collection_name = f"test_modify_no_init_{ uuid4 ()} "
21882244
@@ -2259,7 +2315,9 @@ def test_modify_collection_with_initial_spann_schema(client: ClientAPI) -> None:
22592315
22602316
22612317@pytest .mark .skipif (is_spann_disabled_mode , reason = skip_reason_spann_disabled )
2262- def test_modify_collection_updates_schema_spann_multiple_fields (client : ClientAPI ) -> None :
2318+ def test_modify_collection_updates_schema_spann_multiple_fields (
2319+ client : ClientAPI ,
2320+ ) -> None :
22632321 """Test that modifying multiple SPANN fields updates schema correctly."""
22642322 collection_name = f"test_modify_schema_multi_{ uuid4 ()} "
22652323
@@ -2372,3 +2430,37 @@ def test_modify_collection_preserves_other_schema_fields(client: ClientAPI) -> N
23722430 assert refreshed_schema .defaults .boolean is not None
23732431 assert refreshed_schema .defaults .sparse_vector is not None
23742432 assert refreshed_schema .keys ["#document" ] is not None
2433+
2434+
def test_embeds_using_schema_embedding_function() -> None:
    """Check that ``_embed`` works when the embedding function comes only from the schema.

    The collection is assembled directly from a ``CollectionModel`` with an
    empty ``configuration_json`` and no explicit ``embedding_function``; the
    only embedding function supplied is the one on the schema's vector index
    config.
    """
    vector_config = VectorIndexConfig(embedding_function=SimpleEmbeddingFunction())
    schema = Schema().create_index(config=vector_config)

    model = CollectionModel(
        id=uuid4(),
        name="schema_only_collection",
        configuration_json={},
        serialized_schema=schema.serialize_to_json(),
        metadata=None,
        dimension=4,
        tenant="tenant",
        database="database",
        version=0,
        log_position=0,
    )

    # A bare object stands in for the server API client in this test.
    placeholder_client = cast(ServerAPI, object())
    collection = CollectionCommon(
        client=placeholder_client,
        model=model,
        embedding_function=None,
    )

    # Neither the constructor argument nor the configuration supplies an EF.
    assert collection._embedding_function is None
    assert collection.configuration is not None
    assert collection.configuration.get("embedding_function") is None

    vectors = collection._embed(["hello world"])
    assert vectors is not None
    expected_first = [0.0, 1.0, 2.0, 3.0]
    assert np.allclose(vectors[0], expected_first)
0 commit comments