Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit f7baf94

Browse files
author
Luke Hinds
authored
Merge pull request #69 from yrobla/main
Create github action for syncing and exporting vector DB
2 parents 8febe0a + a2dc54e commit f7baf94

File tree

8 files changed

+66578
-0
lines changed

8 files changed

+66578
-0
lines changed

.github/workflows/import_packages.yml

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
name: Sync vector DB

on:
  workflow_dispatch:

jobs:
  sync_db:
    # The type of runner that the job will run on
    runs-on: ubuntu-latest

    # Steps represent a sequence of tasks that will be executed as part of the job
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v5
        with:
          python-version: '3.12'
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install "."

      - name: Install GitHub CLI
        run: |
          sudo apt-get update
          sudo apt-get install -y gh

      - name: Fetch latest successful workflow run ID
        id: get-run-id
        env:
          GITHUB_TOKEN: ${{ github.token }}
        run: |
          # A run's .name is the workflow's display name ("Sync vector DB"),
          # not its file path, so filter on .path instead — comparing .name
          # against the path never matches any run.
          workflow_path=".github/workflows/import_packages.yml"
          run_id=$(gh api --paginate repos/${{ github.repository }}/actions/runs --jq ".workflow_runs[] | select(.path == \"$workflow_path\" and .conclusion == \"success\") | .id" | head -n 1)
          # ::set-output is deprecated and disabled on current runners;
          # write step outputs to the $GITHUB_OUTPUT file instead.
          echo "run_id=$run_id" >> "$GITHUB_OUTPUT"

      - name: Download the latest artifact
        env:
          GITHUB_TOKEN: ${{ github.token }}
        run: |
          gh run download ${{ steps.get-run-id.outputs.run_id }}

      - name: Run sync
        run: |
          export PYTHONPATH=$PYTHONPATH:./
          python scripts/import_packages.py
      - name: 'Upload Volume'
        uses: actions/upload-artifact@v4
        with:
          name: database_volume
          path: weaviate_data
          retention-days: 5

data/archived.jsonl

Lines changed: 9309 additions & 0 deletions
Large diffs are not rendered by default.

data/deprecated.jsonl

Lines changed: 31572 additions & 0 deletions
Large diffs are not rendered by default.

data/malicious.jsonl

Lines changed: 25480 additions & 0 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,18 @@ description = "Generative AI CodeGen security gateway"
55
readme = "README.md"
66
authors = []
77
packages = [{include = "codegate", from = "src"}]
8+
requires-python = ">=3.11"
89

910
[tool.poetry.dependencies]
1011
python = ">=3.11"
1112
click = ">=8.1.0"
1213
PyYAML = ">=6.0.1"
1314
fastapi = ">=0.115.5"
1415
uvicorn = ">=0.32.1"
16+
weaviate = ">=0.1.2"
17+
weaviate-client = ">=4.9.3"
18+
torch = ">=2.5.1"
19+
transformers = ">=4.46.3"
1520

1621
litellm = "^1.52.15"
1722
[tool.poetry.group.dev.dependencies]

scripts/import_packages.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
import json
2+
from utils.embedding_util import generate_embeddings
3+
import weaviate
4+
from weaviate.embedded import EmbeddedOptions
5+
from weaviate.classes.config import Property, DataType
6+
7+
8+
# JSONL source files to import. The file name doubles as the status label:
# add_data() assigns each record's 'status' based on which file it came from.
json_files = [
    'data/archived.jsonl',
    'data/deprecated.jsonl',
    'data/malicious.jsonl',
]
13+
14+
15+
def setup_schema(client):
    """Ensure the "Package" collection exists, creating it on first run.

    All four properties are plain text fields; creation is skipped when the
    collection is already present so repeated runs are idempotent.
    """
    if client.collections.exists("Package"):
        return

    field_names = ("name", "type", "status", "description")
    client.collections.create(
        "Package",
        properties=[
            Property(name=field, data_type=DataType.TEXT)
            for field in field_names
        ],
    )
26+
27+
28+
def generate_vector_string(package):
    """Build the natural-language description of a package for embedding.

    The sentence combines the package name, an ecosystem-specific phrase,
    and (for archived/deprecated/malicious packages) a warning that links
    to the package's trustypkg.dev page. Unknown types/statuses contribute
    nothing beyond the name.
    """
    # Ecosystem-specific phrases; the trustypkg.dev URL segment matches the
    # package type for every supported ecosystem.
    type_phrases = {
        "pypi": " is a Python package available on PyPI",
        "npm": " is a JavaScript package available on NPM",
        "go": " is a Go package. ",
        "crates": " is a Rust package available on Crates. ",
        "java": " is a Java package. ",
    }
    status_phrases = {
        "archived": ". However, this package is found to be archived and no longer maintained.",
        "deprecated": ". However, this package is found to be deprecated and no longer recommended for use.",
        "malicious": ". However, this package is found to be malicious.",
    }

    vector_str = f"{package['name']}"
    package_url = ""

    pkg_type = package["type"]
    if pkg_type in type_phrases:
        vector_str += type_phrases[pkg_type]
        package_url = f"https://trustypkg.dev/{pkg_type}/{package['name']}"

    status_phrase = status_phrases.get(package["status"])
    if status_phrase is not None:
        vector_str += status_phrase + f" For additional information refer to {package_url}"

    return vector_str
56+
57+
58+
def add_data(client):
    """Import package records from the JSONL data files into Weaviate.

    Existing objects are read up front so that records whose status and
    description are unchanged can be skipped; anything new or changed is
    embedded with generate_embeddings() and added via a dynamic batch.

    :param client: connected Weaviate client with the "Package" collection.
    """
    collection = client.collections.get("Package")

    # read all the data from db, we will only add if there is no data, or is different
    existing_packages = list(collection.iterator())
    packages_dict = {}
    for package in existing_packages:
        key = package.properties['name'] + "/" + package.properties['type']
        packages_dict[key] = {
            'status': package.properties['status'],
            'description': package.properties['description'],
        }

    for json_file in json_files:
        # Explicit encoding: the default is platform-dependent, and the data
        # files must parse identically on every runner.
        with open(json_file, 'r', encoding='utf-8') as f:
            print("Adding data from", json_file)
            with collection.batch.dynamic() as batch:
                for line in f:
                    package = json.loads(line)

                    # The status column is derived from the source file name.
                    if 'archived' in json_file:
                        package['status'] = 'archived'
                    elif 'deprecated' in json_file:
                        package['status'] = 'deprecated'
                    elif 'malicious' in json_file:
                        package['status'] = 'malicious'
                    else:
                        package['status'] = 'unknown'

                    # check for the existing package and only add if different
                    key = package['name'] + "/" + package['type']
                    if key in packages_dict:
                        existing = packages_dict[key]
                        if (existing['status'] == package['status']
                                and existing['description'] == package['description']):
                            print("Package already exists", key)
                            continue

                    # prepare the object for embedding
                    print("Generating data for", key)
                    vector_str = generate_vector_string(package)
                    vector = generate_embeddings(vector_str)

                    batch.add_object(properties=package, vector=vector)
102+
103+
104+
def run_import():
    """Run the full import against an embedded, on-disk Weaviate instance.

    Data is persisted under ./weaviate_data so the workflow can upload the
    directory as an artifact afterwards.
    """
    embedded = EmbeddedOptions(
        persistence_data_path="./weaviate_data",
        grpc_port=50052,
    )
    client = weaviate.WeaviateClient(embedded_options=embedded)

    with client:
        client.connect()
        print('is_ready:', client.is_ready())

        setup_schema(client)
        add_data(client)
117+
118+
119+
# Script entry point: run the full sync when executed directly
# (e.g. `python scripts/import_packages.py` from the workflow).
if __name__ == '__main__':
    run_import()

utils/__init__.py

Whitespace-only changes.

utils/embedding_util.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from transformers import AutoTokenizer, AutoModel
2+
import torch
3+
import torch.nn.functional as F
4+
from torch import Tensor
5+
import os
6+
import warnings
7+
8+
# The transformers library internally is creating this warning, but does not
# impact our app. Safe to ignore.
warnings.filterwarnings(action='ignore', category=ResourceWarning)


# We won't have competing threads in this example app
os.environ["TOKENIZERS_PARALLELISM"] = "false"


# Initialize tokenizer and model for GTE-base.
# NOTE(review): loading happens at import time — the first import on a fresh
# machine presumably downloads the model weights; confirm runners have
# network access (the sync workflow does).
tokenizer = AutoTokenizer.from_pretrained('thenlper/gte-base')
model = AutoModel.from_pretrained('thenlper/gte-base')
20+
21+
22+
def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Mean-pool token embeddings over the sequence axis, ignoring padding.

    Padded positions (attention_mask == 0) are zeroed before summing, and
    the sum is divided by the number of real tokens per sequence.
    """
    keep = attention_mask[..., None].bool()
    summed = last_hidden_states.masked_fill(~keep, 0.0).sum(dim=1)
    token_counts = attention_mask.sum(dim=1)[..., None]
    return summed / token_counts
26+
27+
28+
def generate_embeddings(text):
    """Embed *text* with the module-level GTE-base model.

    Returns the L2-normalized pooled embedding as a plain Python list of
    floats (a single vector, not a batch).
    """
    encoded = tokenizer(text, return_tensors='pt',
                        max_length=512, truncation=True)
    with torch.no_grad():
        model_output = model(**encoded)

    pooled = average_pool(model_output.last_hidden_state,
                          encoded['attention_mask'])

    # (Optionally) normalize embeddings
    normalized = F.normalize(pooled, p=2, dim=1)

    return normalized.numpy().tolist()[0]

0 commit comments

Comments
 (0)