Skip to content

Commit 19b00cc

Browse files
authored
[ENH] Implement collection apis for ChromaClient (#5653)
## Description of changes _Summarize the changes made by this PR._ - Improvements & Bug fixes - N/A - New functionality - Implement collection level apis ## Test plan _How are these changes tested?_ - [ ] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Migration plan _Are there any migrations, or any forwards/backwards compatibility changes needed in order to make sure this change deploys reliably?_ ## Observability plan _What is the plan to instrument and monitor this change?_ ## Documentation Changes _Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the_ [_docs section](https://github.com/chroma-core/chroma/tree/main/docs/docs.trychroma.com)?_
1 parent 3c52c9b commit 19b00cc

File tree

5 files changed

+155
-4
lines changed

5 files changed

+155
-4
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/chroma/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ tracing.workspace = true
1919
parking_lot.workspace = true
2020

2121
chroma-api-types = { workspace = true }
22+
chroma-error = { workspace = true }
2223
chroma-types = { workspace = true }
2324

2425
[features]

rust/chroma/src/client/chroma_client.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use backon::ExponentialBuilder;
22
use backon::Retryable;
3+
use chroma_error::ChromaValidationError;
34
use parking_lot::Mutex;
45
use reqwest::Method;
56
use reqwest::StatusCode;
@@ -24,6 +25,8 @@ pub enum ChromaClientError {
2425
CouldNotResolveDatabaseId(String),
2526
#[error("Serialization/Deserialization error: {0}")]
2627
SerdeError(#[from] serde_json::Error),
28+
#[error("Validation error: {0}")]
29+
ValidationError(#[from] ChromaValidationError),
2730
}
2831

2932
#[derive(Debug)]

rust/chroma/src/collection.rs

Lines changed: 147 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
use std::sync::Arc;
22

3-
use chroma_types::{Collection, InternalSchema, Metadata};
3+
use chroma_types::{
4+
plan::SearchPayload, AddCollectionRecordsRequest, AddCollectionRecordsResponse, Collection,
5+
DeleteCollectionRecordsRequest, DeleteCollectionRecordsResponse, GetRequest, GetResponse,
6+
IncludeList, InternalSchema, Metadata, QueryRequest, QueryResponse, SearchRequest,
7+
SearchResponse, UpdateCollectionRecordsRequest, UpdateCollectionRecordsResponse,
8+
UpdateMetadata, UpsertCollectionRecordsRequest, UpsertCollectionRecordsResponse, Where,
9+
};
410
use reqwest::Method;
511
use serde::{de::DeserializeOwned, Serialize};
612

@@ -33,6 +39,146 @@ impl ChromaCollection {
3339
self.send::<(), u32>("count", Method::GET, None).await
3440
}
3541

42+
pub async fn get(
43+
&self,
44+
ids: Option<Vec<String>>,
45+
r#where: Option<Where>,
46+
limit: Option<u32>,
47+
offset: Option<u32>,
48+
include: Option<IncludeList>,
49+
) -> Result<GetResponse, ChromaClientError> {
50+
let request = GetRequest::try_new(
51+
self.collection.tenant.clone(),
52+
self.collection.database.clone(),
53+
self.collection.collection_id,
54+
ids,
55+
r#where,
56+
limit,
57+
offset.unwrap_or_default(),
58+
include.unwrap_or_else(IncludeList::default_get),
59+
)?;
60+
61+
self.send("get", Method::POST, Some(request)).await
62+
}
63+
64+
pub async fn query(
65+
&self,
66+
query_embeddings: Vec<Vec<f32>>,
67+
n_results: Option<u32>,
68+
r#where: Option<Where>,
69+
ids: Option<Vec<String>>,
70+
include: Option<IncludeList>,
71+
) -> Result<QueryResponse, ChromaClientError> {
72+
let request = QueryRequest::try_new(
73+
self.collection.tenant.clone(),
74+
self.collection.database.clone(),
75+
self.collection.collection_id,
76+
ids,
77+
r#where,
78+
query_embeddings,
79+
n_results.unwrap_or(10),
80+
include.unwrap_or_else(IncludeList::default_query),
81+
)?;
82+
83+
self.send("query", Method::POST, Some(request)).await
84+
}
85+
86+
pub async fn search(
87+
&self,
88+
searches: Vec<SearchPayload>,
89+
) -> Result<SearchResponse, ChromaClientError> {
90+
let request = SearchRequest::try_new(
91+
self.collection.tenant.clone(),
92+
self.collection.database.clone(),
93+
self.collection.collection_id,
94+
searches,
95+
)?;
96+
97+
self.send("search", Method::POST, Some(request)).await
98+
}
99+
100+
pub async fn add(
101+
&self,
102+
ids: Vec<String>,
103+
embeddings: Vec<Vec<f32>>,
104+
documents: Option<Vec<Option<String>>>,
105+
uris: Option<Vec<Option<String>>>,
106+
metadatas: Option<Vec<Option<Metadata>>>,
107+
) -> Result<AddCollectionRecordsResponse, ChromaClientError> {
108+
let request = AddCollectionRecordsRequest::try_new(
109+
self.collection.tenant.clone(),
110+
self.collection.database.clone(),
111+
self.collection.collection_id,
112+
ids,
113+
embeddings,
114+
documents,
115+
uris,
116+
metadatas,
117+
)?;
118+
119+
self.send("add", Method::POST, Some(request)).await
120+
}
121+
122+
pub async fn update(
123+
&self,
124+
ids: Vec<String>,
125+
embeddings: Option<Vec<Option<Vec<f32>>>>,
126+
documents: Option<Vec<Option<String>>>,
127+
uris: Option<Vec<Option<String>>>,
128+
metadatas: Option<Vec<Option<UpdateMetadata>>>,
129+
) -> Result<UpdateCollectionRecordsResponse, ChromaClientError> {
130+
let request = UpdateCollectionRecordsRequest::try_new(
131+
self.collection.tenant.clone(),
132+
self.collection.database.clone(),
133+
self.collection.collection_id,
134+
ids,
135+
embeddings,
136+
documents,
137+
uris,
138+
metadatas,
139+
)?;
140+
141+
self.send("update", Method::POST, Some(request)).await
142+
}
143+
144+
pub async fn upsert(
145+
&self,
146+
ids: Vec<String>,
147+
embeddings: Vec<Vec<f32>>,
148+
documents: Option<Vec<Option<String>>>,
149+
uris: Option<Vec<Option<String>>>,
150+
metadatas: Option<Vec<Option<UpdateMetadata>>>,
151+
) -> Result<UpsertCollectionRecordsResponse, ChromaClientError> {
152+
let request = UpsertCollectionRecordsRequest::try_new(
153+
self.collection.tenant.clone(),
154+
self.collection.database.clone(),
155+
self.collection.collection_id,
156+
ids,
157+
embeddings,
158+
documents,
159+
uris,
160+
metadatas,
161+
)?;
162+
163+
self.send("upsert", Method::POST, Some(request)).await
164+
}
165+
166+
pub async fn delete(
167+
&self,
168+
ids: Option<Vec<String>>,
169+
r#where: Option<Where>,
170+
) -> Result<DeleteCollectionRecordsResponse, ChromaClientError> {
171+
let request = DeleteCollectionRecordsRequest::try_new(
172+
self.collection.tenant.clone(),
173+
self.collection.database.clone(),
174+
self.collection.collection_id,
175+
ids,
176+
r#where,
177+
)?;
178+
179+
self.send("delete", Method::POST, Some(request)).await
180+
}
181+
36182
async fn send<Body: Serialize, Response: DeserializeOwned>(
37183
&self,
38184
operation: &str,

rust/types/src/api_types.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1202,7 +1202,7 @@ impl UpdateCollectionRecordsRequest {
12021202
}
12031203
}
12041204

1205-
#[derive(Serialize, ToSchema)]
1205+
#[derive(Serialize, Deserialize, ToSchema)]
12061206
pub struct UpdateCollectionRecordsResponse {}
12071207

12081208
#[derive(Error, Debug)]
@@ -1266,7 +1266,7 @@ impl UpsertCollectionRecordsRequest {
12661266
}
12671267
}
12681268

1269-
#[derive(Serialize, ToSchema)]
1269+
#[derive(Serialize, Deserialize, ToSchema)]
12701270
pub struct UpsertCollectionRecordsResponse {}
12711271

12721272
#[derive(Error, Debug)]
@@ -1326,7 +1326,7 @@ impl DeleteCollectionRecordsRequest {
13261326
}
13271327
}
13281328

1329-
#[derive(Serialize, ToSchema)]
1329+
#[derive(Serialize, Deserialize, ToSchema)]
13301330
pub struct DeleteCollectionRecordsResponse {}
13311331

13321332
#[derive(Error, Debug)]

0 commit comments

Comments
 (0)