
Commit 72bd3db

Merge pull request #39 from koaning/gzip
Use gzip and better JSON db.

2 parents: fa4cef6 + 5b4a975

File tree: 3 files changed (+29, −26 lines)

- setup.py
- simsity/__init__.py
- tests/test.py


setup.py
Lines changed: 2 additions & 2 deletions

@@ -17,10 +17,10 @@
 
 setup(
     name="simsity",
-    version="0.5.2",
+    version="0.5.3",
     author="Vincent D. Warmerdam",
     packages=find_packages(exclude=["notebooks", "docs"]),
-    description="Simple Similarity Service",
+    description="Super Simple Similarity Service",
     long_description=pathlib.Path("README.md").read_text(),
     long_description_content_type="text/markdown",
     url="https://github.com/koaning/simsity/",

simsity/__init__.py
Lines changed: 27 additions & 17 deletions

@@ -8,6 +8,10 @@
 from hnswlib import Index
 from tqdm import tqdm
 
+DB_NAME = "db.gz.json"
+INDEX_NAME = "index.bin"
+METADATA_NAME = "metadata.json"
+
 
 class Transformer(Protocol):
     def transform(self):
@@ -30,28 +34,29 @@ def query(self, query, n=10):
         """
         arr = self.encoder.transform(query)
         return self.query_vector(query=arr, n=n)
-
+
     def query_vector(self, query, n=10):
+        """Query using a vector."""
         labels, distances = self.index.knn_query(query, k=n)
         out = [self.db[int(label)] for label in labels[0]]
         return out, list(distances[0])
-
+
 def walk(index, *args, n=10, depth=3, uniq_id=lambda d: d):
-    """Walk through the index, finding nearest neighbors of nearest neighbors.
+    """Walk through the index, finding nearest neighbors of nearest neighbors.
 
     Arguments:
-
+
     - args: the queries to start the walk off with
     - n : number of items to return per query
    - depth: how deep should the search go
     - uniq_id: function that can determine the uniqness of the item (must be hashable)
     """
     q = LifoQueue()
     seen = {}
-
+
     for i in range(depth):
         new_args = []
-
+
         for arg in args:
             res, dists = index.query(arg, n=n)
             for item in res:
@@ -101,15 +106,20 @@ def create_index(
     if path:
         path = Path(path)
         path.mkdir(parents=True, exist_ok=True)
-        if (path / "db.jsonl").exists():
-            (path / "db.jsonl").unlink()
-        srsly.write_jsonl(
-            path / "db.jsonl", ({"data": item} for i, item in enumerate(data))
-        )
-        index.save_index(str(path / "index.bin"))
+        if (path / DB_NAME).exists():
+            (path / DB_NAME).unlink()
+        srsly.write_gzip_json(path / DB_NAME, {i: item for i, item in enumerate(data)})
+        index.save_index(str(path / INDEX_NAME))
+        metadata = {
+            "created": str(dt.datetime.now())[:19],
+            "dim": dim,
+            "n_items": len(data),
+            "space": space,
+            "encoder": str(encoder),
+        }
         srsly.write_json(
-            path / "metadata.json",
-            {"created": str(dt.datetime.now())[:19], "dim": dim, "space": space},
+            path / METADATA_NAME,
+            metadata,
         )
     db = {i: k for i, k in enumerate(data)}
     return SimSityIndex(index=index, encoder=encoder, db=db)
@@ -118,8 +128,8 @@ def create_index(
 def load_index(path, encoder):
     """Load in a simsity index from a path. Must supply same encoder."""
     path = Path(path)
-    metadata = srsly.read_json(path / "metadata.json")
+    metadata = srsly.read_json(path / METADATA_NAME)
     index = Index(space=metadata["space"], dim=metadata["dim"])
-    index.load_index(str(path / "index.bin"))
-    db = {i: k for i, k in enumerate(srsly.read_jsonl(path / "db.jsonl"))}
+    index.load_index(str(path / INDEX_NAME))
+    db = {i: k for i, k in enumerate(srsly.read_gzip_json(path / DB_NAME))}
     return SimSityIndex(index=index, encoder=encoder, db=db)
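
The net effect of this file's changes: the line-delimited db.jsonl is replaced by a single gzip-compressed JSON file, the on-disk file names move into module-level constants, and metadata.json gains n_items and encoder fields. Below is a minimal sketch of the gzip round-trip the new code relies on, using only srsly; the "demo" path, the records data, and the dim/space values are made up for illustration.

```python
from pathlib import Path

import srsly

# Hypothetical stand-in for the data that create_index would persist.
records = ["pork belly buns", "tofu stir fry", "apple pie"]

path = Path("demo")
path.mkdir(parents=True, exist_ok=True)

# One gzip-compressed JSON blob keyed by row number, mirroring DB_NAME = "db.gz.json".
srsly.write_gzip_json(path / "db.gz.json", {i: item for i, item in enumerate(records)})

# JSON object keys come back as strings after the round-trip.
db = srsly.read_gzip_json(path / "db.gz.json")
assert db["0"] == "pork belly buns"

# Metadata stays as plain JSON, mirroring METADATA_NAME = "metadata.json".
# The dim and space values here are placeholders; create_index writes whatever it was given.
srsly.write_json(path / "metadata.json", {"dim": 384, "space": "cosine", "n_items": len(records)})
assert srsly.read_json(path / "metadata.json")["n_items"] == 3
```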

tests/test.py
Lines changed: 0 additions & 7 deletions

@@ -10,13 +10,6 @@
 # Create an encoder
 encoder = SentenceEncoder()
 
-# Make an index without a path
-index = create_index(recipes, encoder)
-texts, dists = index.query("pork")
-for text in texts:
-    assert "pork" in text
-assert index.index.element_count == 6118
-
 # Make an index with a path
 index = create_index(recipes, encoder, path="demo")
 texts, dists = index.query("pork")
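
The deleted block covered the path-less, in-memory flow; what remains in test.py is the path-based flow, which with this change persists db.gz.json, index.bin, and metadata.json under "demo". A hedged sketch of how that on-disk index would be reloaded and queried, assuming recipes and encoder are set up as earlier in this test file:

```python
from simsity import create_index, load_index

# Assumes `recipes` (a list of texts) and `encoder` (e.g. the SentenceEncoder
# created above) exist, as in the rest of tests/test.py.
index = create_index(recipes, encoder, path="demo")

# "demo" should now contain db.gz.json, index.bin and metadata.json.
reloaded = load_index("demo", encoder=encoder)

texts, dists = reloaded.query("pork", n=10)
for text in texts:
    assert "pork" in text
```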
