Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

feat: modify embedding method to use embedding class #101

Merged
merged 1 commit into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ build-backend = "poetry.core.masonry.api"
codegate = "codegate.cli:main"

[tool.black]
line-length = 127
line-length = 100
target-version = ["py310"]

[tool.ruff]
line-length = 88
line-length = 100
target-version = "py310"
fix = true

Expand Down
217 changes: 104 additions & 113 deletions scripts/import_packages.py
Original file line number Diff line number Diff line change
@@ -1,129 +1,120 @@
import asyncio
import json

import weaviate
from weaviate.classes.config import DataType, Property
from weaviate.embedded import EmbeddedOptions
from weaviate.util import generate_uuid5

from utils.embedding_util import generate_embeddings

json_files = [
"data/archived.jsonl",
"data/deprecated.jsonl",
"data/malicious.jsonl",
]


def setup_schema(client):
if not client.collections.exists("Package"):
client.collections.create(
"Package",
properties=[
Property(name="name", data_type=DataType.TEXT),
Property(name="type", data_type=DataType.TEXT),
Property(name="status", data_type=DataType.TEXT),
Property(name="description", data_type=DataType.TEXT),
],
)
from codegate.inference.inference_engine import LlamaCppInferenceEngine


def generate_vector_string(package):
vector_str = f"{package['name']}"
# add description
package_url = ""
if package["type"] == "pypi":
vector_str += " is a Python package available on PyPI"
package_url = f"https://trustypkg.dev/pypi/{package['name']}"
elif package["type"] == "npm":
vector_str += " is a JavaScript package available on NPM"
package_url = f"https://trustypkg.dev/npm/{package['name']}"
elif package["type"] == "go":
vector_str += " is a Go package. "
package_url = f"https://trustypkg.dev/go/{package['name']}"
elif package["type"] == "crates":
vector_str += " is a Rust package available on Crates. "
package_url = f"https://trustypkg.dev/crates/{package['name']}"
elif package["type"] == "java":
vector_str += " is a Java package. "
package_url = f"https://trustypkg.dev/java/{package['name']}"

# add extra status
if package["status"] == "archived":
vector_str += f". However, this package is found to be archived and no longer \
maintained. For additional information refer to {package_url}"
elif package["status"] == "deprecated":
vector_str += f". However, this package is found to be deprecated and no \
longer recommended for use. For additional information refer to {package_url}"
elif package["status"] == "malicious":
vector_str += f". However, this package is found to be malicious. For \
additional information refer to {package_url}"
return vector_str


def add_data(client):
collection = client.collections.get("Package")

# read all the data from db, we will only add if there is no data, or is different
existing_packages = list(collection.iterator())
packages_dict = {}
for package in existing_packages:
key = package.properties["name"] + "/" + package.properties["type"]
value = {
"status": package.properties["status"],
"description": package.properties["description"],
class PackageImporter:
def __init__(self):
self.client = weaviate.WeaviateClient(
embedded_options=EmbeddedOptions(
persistence_data_path="./weaviate_data",
grpc_port=50052
)
)
self.json_files = [
"data/archived.jsonl",
"data/deprecated.jsonl",
"data/malicious.jsonl",
]
self.client.connect()
self.inference_engine = LlamaCppInferenceEngine()
self.model_path = "./models/all-minilm-L6-v2-q5_k_m.gguf"

def setup_schema(self):
if not self.client.collections.exists("Package"):
self.client.collections.create(
"Package",
properties=[
Property(name="name", data_type=DataType.TEXT),
Property(name="type", data_type=DataType.TEXT),
Property(name="status", data_type=DataType.TEXT),
Property(name="description", data_type=DataType.TEXT),
],
)

def generate_vector_string(self, package):
vector_str = f"{package['name']}"
package_url = ""
type_map = {
"pypi": "Python package available on PyPI",
"npm": "JavaScript package available on NPM",
"go": "Go package",
"crates": "Rust package available on Crates",
"java": "Java package"
}
status_messages = {
"archived": "However, this package is found to be archived and no longer maintained.",
"deprecated": "However, this package is found to be deprecated and no longer "
"recommended for use.",
"malicious": "However, this package is found to be malicious."
}
vector_str += f" is a {type_map.get(package['type'], 'unknown type')} "
package_url = f"https://trustypkg.dev/{package['type']}/{package['name']}"

# Add extra status
status_suffix = status_messages.get(package["status"], "")
if status_suffix:
vector_str += f"{status_suffix} For additional information refer to {package_url}"
return vector_str

async def process_package(self, batch, package):
vector_str = self.generate_vector_string(package)
vector = await self.inference_engine.embed(self.model_path, [vector_str])
# This is where the synchronous call is made
batch.add_object(properties=package, vector=vector[0])

async def add_data(self):
collection = self.client.collections.get("Package")
existing_packages = list(collection.iterator())
packages_dict = {
f"{package.properties['name']}/{package.properties['type']}": {
"status": package.properties["status"],
"description": package.properties["description"]
} for package in existing_packages
}
packages_dict[key] = value

for json_file in json_files:
with open(json_file, "r") as f:
print("Adding data from", json_file)

# temporary, just for testing
with collection.batch.dynamic() as batch:
for json_file in self.json_files:
with open(json_file, "r") as f:
print("Adding data from", json_file)
packages_to_insert = []
for line in f:
package = json.loads(line)
package["status"] = json_file.split('/')[-1].split('.')[0]
key = f"{package['name']}/{package['type']}"

if key in packages_dict and packages_dict[key] == {
"status": package["status"],
"description": package["description"]
}:
print("Package already exists", key)
continue

vector_str = self.generate_vector_string(package)
vector = await self.inference_engine.embed(self.model_path, [vector_str])
packages_to_insert.append((package, vector[0]))

# Synchronous batch insert after preparing all data
with collection.batch.dynamic() as batch:
for package, vector in packages_to_insert:
batch.add_object(properties=package, vector=vector,
uuid=generate_uuid5(package))

# now add the status column
if "archived" in json_file:
package["status"] = "archived"
elif "deprecated" in json_file:
package["status"] = "deprecated"
elif "malicious" in json_file:
package["status"] = "malicious"
else:
package["status"] = "unknown"

# check for the existing package and only add if different
key = package["name"] + "/" + package["type"]
if key in packages_dict:
if (
packages_dict[key]["status"] == package["status"]
and packages_dict[key]["description"]
== package["description"]
):
print("Package already exists", key)
continue

# prepare the object for embedding
print("Generating data for", key)
vector_str = generate_vector_string(package)
vector = generate_embeddings(vector_str)

batch.add_object(properties=package, vector=vector)


def run_import():
client = weaviate.WeaviateClient(
embedded_options=EmbeddedOptions(
persistence_data_path="./weaviate_data", grpc_port=50052
),
)
with client:
client.connect()
print("is_ready:", client.is_ready())

setup_schema(client)
add_data(client)
async def run_import(self):
self.setup_schema()
await self.add_data()


if __name__ == "__main__":
run_import()
importer = PackageImporter()
asyncio.run(importer.run_import())
try:
assert importer.client.is_live()
pass
finally:
importer.client.close()