Skip to content

Commit ff1285d

Browse files
authored
[BUG] use keys instead of key_overrides in query embedding strings (#5754)
## Description of changes _Summarize the changes made by this PR._ - Improvements & Bug fixes - There was a bug in sparse embedding query texts where `key_overrides` was used during query embedding instead of `keys`, as per the rename. This PR fixes that, and adds tests to ensure the happy path of embedding query strings works as intended - New functionality - ... ## Test plan _How are these changes tested?_ - [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Migration plan _Are there any migrations, or any forwards/backwards compatibility changes needed in order to make sure this change deploys reliably?_ ## Observability plan _What is the plan to instrument and monitor this change?_ ## Documentation Changes _Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs section](https://github.com/chroma-core/chroma/tree/main/docs/docs.trychroma.com)?_
1 parent 8e8fb95 commit ff1285d

File tree

4 files changed

+203
-21
lines changed

4 files changed

+203
-21
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

chromadb/api/models/CollectionCommon.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from uuid import UUID
1818

1919
from chromadb.api.types import (
20-
EMBEDDING_KEY,
2120
URI,
2221
Schema,
2322
SparseVectorIndexConfig,
@@ -840,22 +839,22 @@ def _embed_knn_string_queries(self, knn: Any) -> Any:
840839

841840
# Handle metadata field with potential sparse embedding
842841
schema = self.schema
843-
if schema is None or key not in schema.key_overrides:
842+
if schema is None or key not in schema.keys:
844843
raise ValueError(
845844
f"Cannot embed string query for key '{key}': "
846845
f"key not found in schema. Please provide an embedded vector or "
847846
f"configure an embedding function for this key in the schema."
848847
)
849848

850-
value_type = schema.key_overrides[key]
849+
value_type = schema.keys[key]
851850

852851
# Check for sparse vector with embedding function
853852
if value_type.sparse_vector is not None:
854853
sparse_index = value_type.sparse_vector.sparse_vector_index
855854
if sparse_index is not None and sparse_index.enabled:
856-
config = sparse_index.config
857-
if config.embedding_function is not None:
858-
embedding_func = config.embedding_function
855+
sparse_config = sparse_index.config
856+
if sparse_config.embedding_function is not None:
857+
embedding_func = sparse_config.embedding_function
859858
if not isinstance(embedding_func, SparseEmbeddingFunction):
860859
embedding_func = cast(
861860
SparseEmbeddingFunction[Any], embedding_func
@@ -887,9 +886,9 @@ def _embed_knn_string_queries(self, knn: Any) -> Any:
887886
if value_type.float_list is not None:
888887
vector_index = value_type.float_list.vector_index
889888
if vector_index is not None and vector_index.enabled:
890-
config = vector_index.config
891-
if config.embedding_function is not None:
892-
embedding_func = config.embedding_function
889+
dense_config = vector_index.config
890+
if dense_config.embedding_function is not None:
891+
embedding_func = dense_config.embedding_function
893892
validate_embedding_function(embedding_func)
894893

895894
# Embed the query using the schema's embedding function

chromadb/test/api/test_schema_e2e.py

Lines changed: 103 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from chromadb.api import ClientAPI
1+
from chromadb.api import ClientAPI, ServerAPI
22
from chromadb.api.types import (
33
Schema,
44
SparseVectorIndexConfig,
@@ -30,8 +30,10 @@
3030
register_sparse_embedding_function,
3131
)
3232
from chromadb.api.models.Collection import Collection
33+
from chromadb.api.models.CollectionCommon import CollectionCommon
3334
from chromadb.errors import InvalidArgumentError, InternalError
3435
from chromadb.execution.expression import Knn, Search
36+
from chromadb.types import Collection as CollectionModel
3537
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, cast
3638
from uuid import uuid4
3739
import numpy as np
@@ -606,6 +608,36 @@ def test_search_embeds_string_knn_queries(
606608
assert embedded_rank.query == [11.0, 12.5]
607609

608610

611+
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
612+
def test_search_embeds_string_knn_queries_with_sparse_embedding_function(
613+
client_factories: "ClientFactories",
614+
) -> None:
615+
"""_embed_search_string_queries should embed string KNN queries using the schema's sparse embedding function."""
616+
617+
sparse_ef = DeterministicSparseEmbeddingFunction(label="sparse")
618+
schema = Schema().create_index(
619+
key="sparse_metadata",
620+
config=SparseVectorIndexConfig(
621+
source_key="raw_text", embedding_function=sparse_ef
622+
),
623+
)
624+
collection, _ = _create_isolated_collection(client_factories, schema=schema)
625+
626+
search = Search().rank(Knn(key="sparse_metadata", query="hello world"))
627+
628+
embedded_search = collection._embed_search_string_queries(search)
629+
630+
assert isinstance(search._rank, Knn)
631+
assert search._rank.key == "sparse_metadata"
632+
assert search._rank.query == "hello world"
633+
634+
embedded_rank = embedded_search._rank
635+
assert isinstance(embedded_rank, Knn)
636+
assert embedded_rank.key == "sparse_metadata"
637+
print(embedded_rank.query)
638+
assert embedded_rank.query == SparseVector(indices=[0], values=[11.0])
639+
640+
609641
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
610642
def test_search_embeds_string_queries_in_nested_ranks(
611643
client_factories: "ClientFactories",
@@ -2048,7 +2080,9 @@ def test_sparse_vector_index_config_with_key_types(
20482080
)
20492081

20502082
# Verify sparse embeddings were generated from text_field
2051-
result2 = collection2.get(ids=["sparse-key-1", "sparse-key-2"], include=["metadatas"])
2083+
result2 = collection2.get(
2084+
ids=["sparse-key-1", "sparse-key-2"], include=["metadatas"]
2085+
)
20522086
assert result2["metadatas"] is not None
20532087
assert "sparse2" in result2["metadatas"][0]
20542088
assert "sparse2" in result2["metadatas"][1]
@@ -2072,7 +2106,9 @@ def test_schema_rejects_special_key_in_create_index() -> None:
20722106
schema.create_index(config=StringInvertedIndexConfig(), key="#custom_field")
20732107

20742108
# Test with Key object starting with #
2075-
with pytest.raises(ValueError, match="Cannot create index on special key '#embedding'"):
2109+
with pytest.raises(
2110+
ValueError, match="Cannot create index on special key '#embedding'"
2111+
):
20762112
schema.create_index(config=StringInvertedIndexConfig(), key=Key.EMBEDDING)
20772113

20782114

@@ -2084,7 +2120,9 @@ def test_schema_rejects_special_key_in_delete_index() -> None:
20842120
schema.delete_index(config=StringInvertedIndexConfig(), key="#custom_field")
20852121

20862122
# Test with Key object starting with #
2087-
with pytest.raises(ValueError, match="Cannot delete index on special key '#document'"):
2123+
with pytest.raises(
2124+
ValueError, match="Cannot delete index on special key '#document'"
2125+
):
20882126
schema.delete_index(config=StringInvertedIndexConfig(), key=Key.DOCUMENT)
20892127

20902128

@@ -2122,7 +2160,13 @@ def test_server_validates_schema_with_special_keys(
21222160
# This should be caught server-side by validate_schema()
21232161
schema = Schema()
21242162
# Bypass client-side validation by directly manipulating schema.keys
2125-
from chromadb.api.types import ValueTypes, StringValueType, StringInvertedIndexType, StringInvertedIndexConfig
2163+
from chromadb.api.types import (
2164+
ValueTypes,
2165+
StringValueType,
2166+
StringInvertedIndexType,
2167+
StringInvertedIndexConfig,
2168+
)
2169+
21262170
schema.keys["#invalid_key"] = ValueTypes(
21272171
string=StringValueType(
21282172
string_inverted_index=StringInvertedIndexType(
@@ -2138,7 +2182,9 @@ def test_server_validates_schema_with_special_keys(
21382182

21392183
# Verify server caught the invalid key
21402184
error_msg = str(exc_info.value)
2141-
assert "#" in error_msg or "key" in error_msg.lower() or "invalid" in error_msg.lower()
2185+
assert (
2186+
"#" in error_msg or "key" in error_msg.lower() or "invalid" in error_msg.lower()
2187+
)
21422188

21432189

21442190
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
@@ -2153,14 +2199,18 @@ def test_server_validates_invalid_source_key_in_sparse_vector_config(
21532199

21542200
# Create schema with invalid source_key
21552201
# Bypass client-side validation by directly creating the config
2156-
from chromadb.api.types import ValueTypes, SparseVectorValueType, SparseVectorIndexType
2202+
from chromadb.api.types import (
2203+
ValueTypes,
2204+
SparseVectorValueType,
2205+
SparseVectorIndexType,
2206+
)
21572207

21582208
schema = Schema()
21592209
# Manually construct config with invalid source_key using model_construct to bypass validation
21602210
invalid_config = SparseVectorIndexConfig.model_construct(
21612211
embedding_function=None,
21622212
source_key="#embedding", # Invalid - should be rejected
2163-
bm25=None
2213+
bm25=None,
21642214
)
21652215

21662216
schema.keys["test_sparse"] = ValueTypes(
@@ -2178,11 +2228,17 @@ def test_server_validates_invalid_source_key_in_sparse_vector_config(
21782228

21792229
# Verify server caught the invalid source_key
21802230
error_msg = str(exc_info.value)
2181-
assert "source_key" in error_msg.lower() or "#" in error_msg or "document" in error_msg.lower()
2231+
assert (
2232+
"source_key" in error_msg.lower()
2233+
or "#" in error_msg
2234+
or "document" in error_msg.lower()
2235+
)
21822236

21832237

21842238
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
2185-
def test_modify_collection_no_initial_config_creates_default_schema(client: ClientAPI) -> None:
2239+
def test_modify_collection_no_initial_config_creates_default_schema(
2240+
client: ClientAPI,
2241+
) -> None:
21862242
"""Test that modifying a collection without initial config/schema creates and updates default schema."""
21872243
collection_name = f"test_modify_no_init_{uuid4()}"
21882244

@@ -2259,7 +2315,9 @@ def test_modify_collection_with_initial_spann_schema(client: ClientAPI) -> None:
22592315

22602316

22612317
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
2262-
def test_modify_collection_updates_schema_spann_multiple_fields(client: ClientAPI) -> None:
2318+
def test_modify_collection_updates_schema_spann_multiple_fields(
2319+
client: ClientAPI,
2320+
) -> None:
22632321
"""Test that modifying multiple SPANN fields updates schema correctly."""
22642322
collection_name = f"test_modify_schema_multi_{uuid4()}"
22652323

@@ -2372,3 +2430,37 @@ def test_modify_collection_preserves_other_schema_fields(client: ClientAPI) -> N
23722430
assert refreshed_schema.defaults.boolean is not None
23732431
assert refreshed_schema.defaults.sparse_vector is not None
23742432
assert refreshed_schema.keys["#document"] is not None
2433+
2434+
2435+
def test_embeds_using_schema_embedding_function() -> None:
2436+
"""Test that embeddings are using the schema embedding function."""
2437+
schema = Schema().create_index(
2438+
config=VectorIndexConfig(embedding_function=SimpleEmbeddingFunction()),
2439+
)
2440+
2441+
collection_model = CollectionModel(
2442+
id=uuid4(),
2443+
name="schema_only_collection",
2444+
configuration_json={},
2445+
serialized_schema=schema.serialize_to_json(),
2446+
metadata=None,
2447+
dimension=4,
2448+
tenant="tenant",
2449+
database="database",
2450+
version=0,
2451+
log_position=0,
2452+
)
2453+
2454+
collection = CollectionCommon(
2455+
client=cast(ServerAPI, object()),
2456+
model=collection_model,
2457+
embedding_function=None,
2458+
)
2459+
2460+
assert collection._embedding_function is None
2461+
assert collection.configuration is not None
2462+
assert collection.configuration.get("embedding_function") is None
2463+
2464+
embeddings = collection._embed(["hello world"])
2465+
assert embeddings is not None
2466+
assert np.allclose(embeddings[0], [0.0, 1.0, 2.0, 3.0])

clients/new-js/packages/chromadb/test/search.expression.test.ts

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import type { SearchResponse, SparseVector } from "../src/api";
44
import { CollectionImpl } from "../src/collection";
55
import type { CollectionConfiguration } from "../src/collection-configuration";
66
import type { ChromaClient } from "../src/chroma-client";
7-
import type { EmbeddingFunction } from "../src/embedding-function";
7+
import type { EmbeddingFunction, SparseEmbeddingFunction } from "../src/embedding-function";
88

99
class QueryMockEmbedding implements EmbeddingFunction {
1010
public readonly name = "query_mock";
@@ -402,4 +402,94 @@ describe("search expression DSL", () => {
402402
expect(knnPayload.key).toBe("#embedding");
403403
expect(knnPayload.limit).toBe(7);
404404
});
405+
406+
test("search auto-embeds string knn queries with sparse embedding function", async () => {
407+
const queryText = "hello world";
408+
409+
class DeterministicSparseEmbedding implements SparseEmbeddingFunction {
410+
public readonly name = "deterministic_sparse";
411+
412+
constructor(private readonly label = "sparse") { }
413+
414+
async generate(texts: string[]): Promise<SparseVector[]> {
415+
return texts.map((text) => {
416+
if (text === "hello world") {
417+
return { indices: [0], values: [11.0] };
418+
}
419+
return { indices: [], values: [] };
420+
});
421+
}
422+
423+
getConfig(): Record<string, any> {
424+
return { label: this.label };
425+
}
426+
427+
static buildFromConfig(config: Record<string, any>): DeterministicSparseEmbedding {
428+
return new DeterministicSparseEmbedding(config.label);
429+
}
430+
}
431+
432+
const sparseEf = new DeterministicSparseEmbedding("sparse");
433+
const generateSpy = jest.spyOn(sparseEf, "generate");
434+
435+
const { Schema, SparseVectorIndexConfig } = await import("../src/schema");
436+
const schema = new Schema().createIndex(
437+
new SparseVectorIndexConfig({
438+
sourceKey: "raw_text",
439+
embeddingFunction: sparseEf,
440+
}),
441+
"sparse_metadata",
442+
);
443+
444+
let capturedBody: any;
445+
const mockChromaClient = {
446+
getMaxBatchSize: jest.fn<() => Promise<number>>().mockResolvedValue(1000),
447+
supportsBase64Encoding: jest.fn<() => Promise<boolean>>().mockResolvedValue(false),
448+
_path: jest.fn<() => Promise<{ path: string; tenant: string; database: string }>>().mockResolvedValue({ path: "/api/v1", tenant: "default_tenant", database: "default_database" }),
449+
};
450+
451+
const mockApiClient = {
452+
post: jest.fn().mockImplementation(async (options: any) => {
453+
capturedBody = options.body;
454+
return {
455+
data: {
456+
ids: [],
457+
documents: [],
458+
embeddings: [],
459+
metadatas: [],
460+
scores: [],
461+
select: [],
462+
} as SearchResponse,
463+
};
464+
}),
465+
};
466+
467+
const collection = new CollectionImpl({
468+
chromaClient: mockChromaClient as unknown as ChromaClient,
469+
apiClient: mockApiClient as any,
470+
id: "col-id",
471+
name: "test",
472+
configuration: {} as CollectionConfiguration,
473+
metadata: undefined,
474+
embeddingFunction: undefined,
475+
schema,
476+
});
477+
478+
await collection.search(
479+
new Search().rank(Knn({ key: "sparse_metadata", query: queryText, limit: 10 })),
480+
);
481+
482+
expect(mockApiClient.post).toHaveBeenCalledTimes(1);
483+
expect(generateSpy).toHaveBeenCalledTimes(1);
484+
expect(generateSpy).toHaveBeenCalledWith([queryText]);
485+
486+
expect(capturedBody).toBeDefined();
487+
expect(Array.isArray(capturedBody.searches)).toBe(true);
488+
expect(capturedBody.searches).toHaveLength(1);
489+
490+
const knnPayload = capturedBody.searches[0].rank.$knn;
491+
expect(knnPayload.query).toEqual({ indices: [0], values: [11.0] });
492+
expect(knnPayload.key).toBe("sparse_metadata");
493+
expect(knnPayload.limit).toBe(10);
494+
});
405495
});

0 commit comments

Comments
 (0)