 from hnswlib import Index
 from tqdm import tqdm
 
+DB_NAME = "db.gz.json"
+INDEX_NAME = "index.bin"
+METADATA_NAME = "metadata.json"
+
 
 class Transformer(Protocol):
     def transform(self):
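For readers skimming the diff: the `Transformer` protocol above appears to require only a `transform` method. A minimal sketch of an encoder that satisfies it — the `HashEncoder` name and the character-hashing trick are purely illustrative, not part of this package:

```python
import numpy as np


class HashEncoder:
    """Toy encoder satisfying the Transformer protocol (illustrative only)."""

    def __init__(self, dim=300):
        self.dim = dim

    def transform(self, texts):
        # Turn each text into a dense vector by hashing its characters.
        out = np.zeros((len(texts), self.dim), dtype=np.float32)
        for row, text in enumerate(texts):
            for ch in text:
                out[row, hash(ch) % self.dim] += 1.0
        return out


encoder = HashEncoder(dim=300)
vectors = encoder.transform(["hello world", "another text"])  # shape (2, 300)
```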
@@ -30,28 +34,29 @@ def query(self, query, n=10):
         """
         arr = self.encoder.transform(query)
         return self.query_vector(query=arr, n=n)
-
+
     def query_vector(self, query, n=10):
+        """Query using a vector."""
         labels, distances = self.index.knn_query(query, k=n)
         out = [self.db[int(label)] for label in labels[0]]
         return out, list(distances[0])
-
+
 def walk(index, *args, n=10, depth=3, uniq_id=lambda d: d):
-    """Walk through the index, finding nearest neighbors of nearest neighbors.
+    """Walk through the index, finding nearest neighbors of nearest neighbors.
 
     Arguments:
-
+
     - args: the queries to start the walk off with
     - n: number of items to return per query
     - depth: how deep the search should go
     - uniq_id: function that determines the uniqueness of an item (must return something hashable)
     """
     q = LifoQueue()
     seen = {}
-
+
     for i in range(depth):
         new_args = []
-
+
         for arg in args:
             res, dists = index.query(arg, n=n)
             for item in res:
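As a hedged usage sketch of the two query paths touched in this hunk — assuming an index built via `create_index` with some `texts` and `encoder` (both illustrative names, e.g. the toy encoder above):

```python
# Build an index; `texts` and `encoder` are assumed to exist (see the sketch above).
index = create_index(texts, encoder)

# query() runs the raw query through the encoder first, then delegates to query_vector().
items, dists = index.query(["find something like this"], n=5)

# query_vector() takes a pre-encoded vector directly, skipping the encoder step.
vec = encoder.transform(["find something like this"])
items, dists = index.query_vector(vec, n=5)
```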
@@ -101,15 +106,20 @@ def create_index(
     if path:
         path = Path(path)
         path.mkdir(parents=True, exist_ok=True)
-        if (path / "db.jsonl").exists():
-            (path / "db.jsonl").unlink()
-        srsly.write_jsonl(
-            path / "db.jsonl", ({"data": item} for i, item in enumerate(data))
-        )
-        index.save_index(str(path / "index.bin"))
+        if (path / DB_NAME).exists():
+            (path / DB_NAME).unlink()
+        srsly.write_gzip_json(path / DB_NAME, {i: item for i, item in enumerate(data)})
+        index.save_index(str(path / INDEX_NAME))
+        metadata = {
+            "created": str(dt.datetime.now())[:19],
+            "dim": dim,
+            "n_items": len(data),
+            "space": space,
+            "encoder": str(encoder),
+        }
         srsly.write_json(
-            path / "metadata.json",
-            {"created": str(dt.datetime.now())[:19], "dim": dim, "space": space},
+            path / METADATA_NAME,
+            metadata,
         )
     db = {i: k for i, k in enumerate(data)}
     return SimSityIndex(index=index, encoder=encoder, db=db)
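Roughly, when a `path` is passed, the branch above persists three files into that directory. A sketch of a call that triggers this and the resulting layout — the `"out"` directory name is just an example:

```python
# Passing a path persists the index while creating it.
index = create_index(texts, encoder, path="out")

# The directory then holds (per the constants defined at the top of the module):
#   out/db.gz.json     - gzipped JSON mapping row index -> original item
#   out/index.bin      - the serialized hnswlib index
#   out/metadata.json  - created timestamp, dim, n_items, space, encoder repr
```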
@@ -118,8 +128,8 @@ def create_index(
 def load_index(path, encoder):
     """Load in a simsity index from a path. Must supply same encoder."""
     path = Path(path)
-    metadata = srsly.read_json(path / "metadata.json")
+    metadata = srsly.read_json(path / METADATA_NAME)
     index = Index(space=metadata["space"], dim=metadata["dim"])
-    index.load_index(str(path / "index.bin"))
-    db = {i: k for i, k in enumerate(srsly.read_jsonl(path / "db.jsonl"))}
+    index.load_index(str(path / INDEX_NAME))
+    db = {int(i): item for i, item in srsly.read_gzip_json(path / DB_NAME).items()}
     return SimSityIndex(index=index, encoder=encoder, db=db)
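And the round trip, as a sketch: loading requires passing the same encoder that was used at creation time, since only the vectors and the raw items are persisted, not the encoder itself (names below are illustrative):

```python
# Reload a previously persisted index; the encoder is not stored on disk,
# so the caller has to supply the same one again.
index = load_index("out", encoder=encoder)
items, dists = index.query(["another lookup"], n=5)
```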