Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 162 additions & 11 deletions rust/frontend/src/executor/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ use chroma_distance::normalize;
use chroma_error::ChromaError;
use chroma_log::{BackfillMessage, LocalCompactionManager, PurgeLogsMessage};
use chroma_segment::{
local_segment_manager::LocalSegmentManager, sqlite_metadata::SqliteMetadataReader,
local_hnsw::LocalHnswSegmentReaderError, local_segment_manager::LocalSegmentManager,
sqlite_metadata::SqliteMetadataReader,
};
use chroma_sqlite::db::SqliteDb;
use chroma_system::ComponentHandle;
Expand Down Expand Up @@ -157,19 +158,19 @@ impl LocalExecutor {
None => return empty_result,
};

let allowed_user_ids = match plan.filter {
let (allowed_user_ids, skip_missing_hnsw_ids) = match plan.filter {
Filter {
query_ids: None,
where_clause: None,
} => Vec::new(),
} => (Vec::new(), false),
Filter {
query_ids: Some(uids),
where_clause: _,
} if uids.is_empty() => return empty_result,
Filter {
query_ids: Some(uids),
where_clause: None,
} => uids,
} => (uids, false),
filter => {
let filter_plan = Get {
scan: plan.scan.clone(),
Expand Down Expand Up @@ -201,7 +202,7 @@ impl LocalExecutor {
return empty_result;
}

allowed_uids
(allowed_uids, true)
}
};

Expand All @@ -217,11 +218,14 @@ impl LocalExecutor {

let mut allowed_offset_ids = Vec::new();
for user_id in allowed_user_ids {
let offset_id = hnsw_reader
.get_offset_id_by_user_id(&user_id)
.await
.map_err(|err| ExecutorError::Internal(Box::new(err)))?;
allowed_offset_ids.push(offset_id);
match hnsw_reader.get_offset_id_by_user_id(&user_id).await {
Ok(offset_id) => allowed_offset_ids.push(offset_id),
Err(LocalHnswSegmentReaderError::IdNotFound) if skip_missing_hnsw_ids => {}
Err(err) => return Err(ExecutorError::Internal(Box::new(err))),
}
}
if skip_missing_hnsw_ids && allowed_offset_ids.is_empty() {
return empty_result;
}

let hnsw_config = collection_and_segments
Expand Down Expand Up @@ -384,10 +388,14 @@ impl Configurable<LocalExecutorConfig> for LocalExecutor {
mod tests {
use chroma_config::registry::Registry;
use chroma_config::Configurable;
use chroma_segment::local_segment_manager::LocalSegmentManager;
use chroma_sqlite::config::SqliteDBConfig;
use chroma_system::System;
use chroma_types::{
AddCollectionRecordsRequest, CreateCollectionRequest, DatabaseName, IncludeList,
QueryRequest,
InternalCollectionConfiguration, InternalHnswConfiguration, Metadata, MetadataComparison,
MetadataExpression, MetadataValue, PrimitiveOperator, QueryRequest,
VectorIndexConfiguration, Where,
};

use crate::{Frontend, FrontendConfig};
Expand Down Expand Up @@ -478,4 +486,147 @@ mod tests {
.unwrap();
assert_eq!(result.ids[0].len(), 0);
}

/// Regression test: a `query` with a `where` filter must succeed (and return no
/// ids) when the only record matching the filter cannot be resolved by the
/// HNSW reader after the segment manager has been reset.
///
/// NOTE(review): the scenario relies on `sync_threshold`/`batch_size` of 3 so
/// that the first three records are presumably persisted to the HNSW index
/// while the fourth is not — confirm against the HNSW segment writer.
#[tokio::test]
async fn test_query_where_skips_hnsw_missing_candidate_after_cache_reload() {
    const TENANT: &str = "default_tenant";
    const DATABASE: &str = "default_database";

    let registry = Registry::new();
    let system = System::new();

    // Point both the SQLite database and the segment manager at the same
    // temporary directory instead of the in-memory defaults.
    let persist_dir = tempfile::tempdir().unwrap();
    let persist_root = persist_dir.path();
    let sqlite_url = persist_root
        .join("chroma.sqlite3")
        .to_string_lossy()
        .into_owned();
    let segment_persist_path = persist_root.to_string_lossy().into_owned();

    let mut config = FrontendConfig::sqlite_in_memory();
    config.sqlitedb = Some(SqliteDBConfig {
        url: Some(sqlite_url),
        ..Default::default()
    });
    config.segment_manager.as_mut().unwrap().persist_path = Some(segment_persist_path);

    let frontend_inputs = (config, system);
    let mut frontend = Frontend::try_from_config(&frontend_inputs, &registry)
        .await
        .unwrap();

    // Create a collection whose HNSW index uses a threshold/batch size of 3.
    let database_name = DatabaseName::new(DATABASE).expect("database name should be valid");
    let hnsw_configuration = InternalHnswConfiguration {
        sync_threshold: 3,
        batch_size: 3,
        ..Default::default()
    };
    let collection_configuration = InternalCollectionConfiguration {
        vector_index: VectorIndexConfiguration::Hnsw(hnsw_configuration),
        embedding_function: None,
    };
    let create_request = CreateCollectionRequest::try_new(
        TENANT.to_string(),
        database_name,
        "test_missing_hnsw_candidate".to_string(),
        None,
        Some(collection_configuration),
        None,
        false,
    )
    .unwrap();
    let collection = frontend.create_collection(create_request).await.unwrap();
    let collection_id = collection.collection_id;

    // Phase 1: insert three records without metadata.
    let initial_add = AddCollectionRecordsRequest::try_new(
        TENANT.to_string(),
        DATABASE.to_string(),
        collection_id,
        ["id1", "id2", "id3"].map(String::from).to_vec(),
        vec![vec![0.0, 0.0], vec![1.0, 1.0], vec![2.0, 2.0]],
        None,
        None,
        None,
    )
    .unwrap();
    frontend.add(initial_add).await.unwrap();

    // Phase 2: run an unfiltered query and discard its result.
    // NOTE(review): presumably this forces the HNSW segment to be built and
    // synced for the first three records — confirm.
    let warmup_query = QueryRequest::try_new(
        TENANT.to_string(),
        DATABASE.to_string(),
        collection_id,
        None,
        None,
        vec![vec![0.0, 0.0]],
        1,
        IncludeList::default_query(),
    )
    .unwrap();
    Box::pin(frontend.query(warmup_query)).await.unwrap();

    // Phase 3: insert a fourth record tagged with `phase = "unflushed"`.
    let mut unflushed_metadata = Metadata::new();
    unflushed_metadata.insert(
        "phase".to_string(),
        MetadataValue::Str("unflushed".to_string()),
    );
    let late_add = AddCollectionRecordsRequest::try_new(
        TENANT.to_string(),
        DATABASE.to_string(),
        collection_id,
        vec!["id4".to_string()],
        vec![vec![3.0, 3.0]],
        None,
        None,
        Some(vec![Some(unflushed_metadata)]),
    )
    .unwrap();
    frontend.add(late_add).await.unwrap();

    // Phase 4: reset the registered segment manager.
    // NOTE(review): per the test name this is meant to drop cached segments so
    // the next query reloads HNSW state — confirm what `reset` clears.
    let segment_manager = registry.get::<LocalSegmentManager>().unwrap();
    segment_manager.reset().await.unwrap();

    // Phase 5: query with a filter that selects only the fourth record.
    let phase_is_unflushed = Where::Metadata(MetadataExpression {
        key: "phase".to_string(),
        comparison: MetadataComparison::Primitive(
            PrimitiveOperator::Equal,
            MetadataValue::Str("unflushed".to_string()),
        ),
    });
    let filtered_query = QueryRequest::try_new(
        TENANT.to_string(),
        DATABASE.to_string(),
        collection_id,
        None,
        Some(phase_is_unflushed),
        vec![vec![3.0, 3.0]],
        1,
        IncludeList::default_query(),
    )
    .unwrap();
    let result = Box::pin(frontend.query(filtered_query)).await;

    // The query must not fail, and it must report no neighbours for the
    // single query embedding.
    assert!(
        result.is_ok(),
        "query(where=...) should skip candidates missing from HNSW state, got {result:?}"
    );
    assert_eq!(result.unwrap().ids[0], Vec::<String>::new());
}
}
Loading