Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit a8c7c57

Browse files
committed
feat: modify embedding method to use embedding class
1 parent 7649886 commit a8c7c57

File tree

2 files changed

+106
-115
lines changed

2 files changed

+106
-115
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,11 @@ build-backend = "poetry.core.masonry.api"
4040
codegate = "codegate.cli:main"
4141

4242
[tool.black]
43-
line-length = 88
43+
line-length = 100
4444
target-version = ["py310"]
4545

4646
[tool.ruff]
47-
line-length = 88
47+
line-length = 100
4848
target-version = "py310"
4949
fix = true
5050

scripts/import_packages.py

Lines changed: 104 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -1,129 +1,120 @@
1+
import asyncio
12
import json
23

34
import weaviate
45
from weaviate.classes.config import DataType, Property
56
from weaviate.embedded import EmbeddedOptions
7+
from weaviate.util import generate_uuid5
68

7-
from utils.embedding_util import generate_embeddings
8-
9-
json_files = [
10-
"data/archived.jsonl",
11-
"data/deprecated.jsonl",
12-
"data/malicious.jsonl",
13-
]
14-
15-
16-
def setup_schema(client):
17-
if not client.collections.exists("Package"):
18-
client.collections.create(
19-
"Package",
20-
properties=[
21-
Property(name="name", data_type=DataType.TEXT),
22-
Property(name="type", data_type=DataType.TEXT),
23-
Property(name="status", data_type=DataType.TEXT),
24-
Property(name="description", data_type=DataType.TEXT),
25-
],
26-
)
9+
from codegate.inference.inference_engine import LlamaCppInferenceEngine
2710

2811

29-
def generate_vector_string(package):
30-
vector_str = f"{package['name']}"
31-
# add description
32-
package_url = ""
33-
if package["type"] == "pypi":
34-
vector_str += " is a Python package available on PyPI"
35-
package_url = f"https://trustypkg.dev/pypi/{package['name']}"
36-
elif package["type"] == "npm":
37-
vector_str += " is a JavaScript package available on NPM"
38-
package_url = f"https://trustypkg.dev/npm/{package['name']}"
39-
elif package["type"] == "go":
40-
vector_str += " is a Go package. "
41-
package_url = f"https://trustypkg.dev/go/{package['name']}"
42-
elif package["type"] == "crates":
43-
vector_str += " is a Rust package available on Crates. "
44-
package_url = f"https://trustypkg.dev/crates/{package['name']}"
45-
elif package["type"] == "java":
46-
vector_str += " is a Java package. "
47-
package_url = f"https://trustypkg.dev/java/{package['name']}"
48-
49-
# add extra status
50-
if package["status"] == "archived":
51-
vector_str += f". However, this package is found to be archived and no longer \
52-
maintained. For additional information refer to {package_url}"
53-
elif package["status"] == "deprecated":
54-
vector_str += f". However, this package is found to be deprecated and no \
55-
longer recommended for use. For additional information refer to {package_url}"
56-
elif package["status"] == "malicious":
57-
vector_str += f". However, this package is found to be malicious. For \
58-
additional information refer to {package_url}"
59-
return vector_str
60-
61-
62-
def add_data(client):
63-
collection = client.collections.get("Package")
64-
65-
# read all the data from db, we will only add if there is no data, or is different
66-
existing_packages = list(collection.iterator())
67-
packages_dict = {}
68-
for package in existing_packages:
69-
key = package.properties["name"] + "/" + package.properties["type"]
70-
value = {
71-
"status": package.properties["status"],
72-
"description": package.properties["description"],
12+
class PackageImporter:
13+
def __init__(self):
14+
self.client = weaviate.WeaviateClient(
15+
embedded_options=EmbeddedOptions(
16+
persistence_data_path="./weaviate_data",
17+
grpc_port=50052
18+
)
19+
)
20+
self.json_files = [
21+
"data/archived.jsonl",
22+
"data/deprecated.jsonl",
23+
"data/malicious.jsonl",
24+
]
25+
self.client.connect()
26+
self.inference_engine = LlamaCppInferenceEngine()
27+
self.model_path = "./models/all-minilm-L6-v2-q5_k_m.gguf"
28+
29+
def setup_schema(self):
30+
if not self.client.collections.exists("Package"):
31+
self.client.collections.create(
32+
"Package",
33+
properties=[
34+
Property(name="name", data_type=DataType.TEXT),
35+
Property(name="type", data_type=DataType.TEXT),
36+
Property(name="status", data_type=DataType.TEXT),
37+
Property(name="description", data_type=DataType.TEXT),
38+
],
39+
)
40+
41+
def generate_vector_string(self, package):
42+
vector_str = f"{package['name']}"
43+
package_url = ""
44+
type_map = {
45+
"pypi": "Python package available on PyPI",
46+
"npm": "JavaScript package available on NPM",
47+
"go": "Go package",
48+
"crates": "Rust package available on Crates",
49+
"java": "Java package"
50+
}
51+
status_messages = {
52+
"archived": "However, this package is found to be archived and no longer maintained.",
53+
"deprecated": "However, this package is found to be deprecated and no longer "
54+
"recommended for use.",
55+
"malicious": "However, this package is found to be malicious."
56+
}
57+
vector_str += f" is a {type_map.get(package['type'], 'unknown type')} "
58+
package_url = f"https://trustypkg.dev/{package['type']}/{package['name']}"
59+
60+
# Add extra status
61+
status_suffix = status_messages.get(package["status"], "")
62+
if status_suffix:
63+
vector_str += f"{status_suffix} For additional information refer to {package_url}"
64+
return vector_str
65+
66+
async def process_package(self, batch, package):
67+
vector_str = self.generate_vector_string(package)
68+
vector = await self.inference_engine.embed(self.model_path, [vector_str])
69+
# This is where the synchronous call is made
70+
batch.add_object(properties=package, vector=vector[0])
71+
72+
async def add_data(self):
73+
collection = self.client.collections.get("Package")
74+
existing_packages = list(collection.iterator())
75+
packages_dict = {
76+
f"{package.properties['name']}/{package.properties['type']}": {
77+
"status": package.properties["status"],
78+
"description": package.properties["description"]
79+
} for package in existing_packages
7380
}
74-
packages_dict[key] = value
75-
76-
for json_file in json_files:
77-
with open(json_file, "r") as f:
78-
print("Adding data from", json_file)
7981

80-
# temporary, just for testing
81-
with collection.batch.dynamic() as batch:
82+
for json_file in self.json_files:
83+
with open(json_file, "r") as f:
84+
print("Adding data from", json_file)
85+
packages_to_insert = []
8286
for line in f:
8387
package = json.loads(line)
88+
package["status"] = json_file.split('/')[-1].split('.')[0]
89+
key = f"{package['name']}/{package['type']}"
90+
91+
if key in packages_dict and packages_dict[key] == {
92+
"status": package["status"],
93+
"description": package["description"]
94+
}:
95+
print("Package already exists", key)
96+
continue
97+
98+
vector_str = self.generate_vector_string(package)
99+
vector = await self.inference_engine.embed(self.model_path, [vector_str])
100+
packages_to_insert.append((package, vector[0]))
101+
102+
# Synchronous batch insert after preparing all data
103+
with collection.batch.dynamic() as batch:
104+
for package, vector in packages_to_insert:
105+
batch.add_object(properties=package, vector=vector,
106+
uuid=generate_uuid5(package))
84107

85-
# now add the status column
86-
if "archived" in json_file:
87-
package["status"] = "archived"
88-
elif "deprecated" in json_file:
89-
package["status"] = "deprecated"
90-
elif "malicious" in json_file:
91-
package["status"] = "malicious"
92-
else:
93-
package["status"] = "unknown"
94-
95-
# check for the existing package and only add if different
96-
key = package["name"] + "/" + package["type"]
97-
if key in packages_dict:
98-
if (
99-
packages_dict[key]["status"] == package["status"]
100-
and packages_dict[key]["description"]
101-
== package["description"]
102-
):
103-
print("Package already exists", key)
104-
continue
105-
106-
# prepare the object for embedding
107-
print("Generating data for", key)
108-
vector_str = generate_vector_string(package)
109-
vector = generate_embeddings(vector_str)
110-
111-
batch.add_object(properties=package, vector=vector)
112-
113-
114-
def run_import():
115-
client = weaviate.WeaviateClient(
116-
embedded_options=EmbeddedOptions(
117-
persistence_data_path="./weaviate_data", grpc_port=50052
118-
),
119-
)
120-
with client:
121-
client.connect()
122-
print("is_ready:", client.is_ready())
123-
124-
setup_schema(client)
125-
add_data(client)
108+
async def run_import(self):
109+
self.setup_schema()
110+
await self.add_data()
126111

127112

128113
if __name__ == "__main__":
129-
run_import()
114+
importer = PackageImporter()
115+
asyncio.run(importer.run_import())
116+
try:
117+
assert importer.client.is_live()
118+
pass
119+
finally:
120+
importer.client.close()

0 commit comments

Comments
 (0)