Skip to content

Commit 6951aff

Browse files
Copilotdrewoldag
andauthored
Fix vector DB benchmarks to use model_inputs configuration (#470)
* Initial plan * Fix vector DB benchmarks to use model_inputs configuration Co-authored-by: drewoldag <47493171+drewoldag@users.noreply.github.com> * Fix Qdrant vector_size configuration to match actual vector length Co-authored-by: drewoldag <47493171+drewoldag@users.noreply.github.com> * Slight modification to the model_inputs definition. --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: drewoldag <47493171+drewoldag@users.noreply.github.com>
1 parent 0e88a32 commit 6951aff

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

benchmarks/vector_db_benchmarks.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,13 @@ def setup(self, vector_length, vector_db_implementation):
2424

2525
self.h = Hyrax()
2626
self.h.config["general"]["results_dir"] = str(self.input_dir)
27-
self.h.config["data_set"]["name"] = "HyraxRandomDataset"
27+
self.h.config["model_inputs"] = {
28+
"data": {
29+
"dataset_class": "HyraxRandomDataset",
30+
"fields": ["image", "label", "object_id"],
31+
"primary_id_field": "object_id",
32+
}
33+
}
2834
self.h.config["model"]["name"] = "HyraxLoopback"
2935

3036
# Default inference batch size is 512, so this should result in 4 batch files
@@ -83,7 +89,13 @@ def setup(self, shard_size_limit, vector_db_implementation):
8389

8490
self.h = Hyrax()
8591
self.h.config["general"]["results_dir"] = str(self.input_dir)
86-
self.h.config["data_set"]["name"] = "HyraxRandomDataset"
92+
self.h.config["model_inputs"] = {
93+
"data": {
94+
"dataset_class": "HyraxRandomDataset",
95+
"fields": ["image", "label", "object_id"],
96+
"primary_id_field": "object_id",
97+
}
98+
}
8799
self.h.config["data_loader"]["batch_size"] = 4096
88100
self.h.config["model"]["name"] = "HyraxLoopback"
89101

@@ -102,12 +114,12 @@ def setup(self, shard_size_limit, vector_db_implementation):
102114

103115
# Get the list of dataset ids
104116
self.ds = self.h.prepare()
105-
self.data_sample = self.ds[4001]["data"]["image"].numpy()
117+
self.data_sample = self.ds[0]["data"]["image"]
106118

107119
self.h.config["vector_db"]["name"] = vector_db_implementation
108120
self.h.config["vector_db"]["chromadb"]["shard_size_limit"] = shard_size_limit
109121
# Qdrant requires the vector size in order to create its collections
110-
self.h.config["vector_db"]["qdrant"]["vector_size"] = 4096
122+
self.h.config["vector_db"]["qdrant"]["vector_size"] = self.vector_length
111123

112124
# Save inference results to vector database and create a db connection
113125
self.h.save_to_database(output_dir=Path(self.output_dir))

0 commit comments

Comments
 (0)