Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 01e2a75

Browse files
committed
Run filter using the specified packages/ecosystem
1 parent 241a9f5 commit 01e2a75

File tree

4 files changed

+61
-50
lines changed

4 files changed

+61
-50
lines changed

src/codegate/llm_utils/extractor.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ async def extract_packages(
3030
system_prompt = Config.get_config().prompts.lookup_packages
3131

3232
result = await LLMClient.complete(
33-
content=content,
33+
content=content.lower(),
3434
system_prompt=system_prompt,
3535
provider=provider,
3636
model=model,
@@ -41,6 +41,9 @@ async def extract_packages(
4141

4242
# Handle both formats: {"packages": [...]} and direct list [...]
4343
packages = result if isinstance(result, list) else result.get("packages", [])
44+
45+
# Filter packages based on the content
46+
packages = [package.lower() for package in packages if package.lower() in content]
4447
logger.info(f"Extracted packages: {packages}")
4548
return packages
4649

src/codegate/pipeline/codegate_context_retriever/codegate.py

Lines changed: 15 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@ def name(self) -> str:
2929
"""
3030
return "codegate-context-retriever"
3131

32-
async def get_objects_from_search(
33-
self, search: str, ecosystem, packages: list[str] = None
32+
async def get_objects_from_db(
33+
self, ecosystem, packages: list[str] = None
3434
) -> list[object]:
3535
storage_engine = StorageEngine()
3636
objects = await storage_engine.search(
37-
search, distance=0.8, ecosystem=ecosystem, packages=packages
37+
distance=0.8, ecosystem=ecosystem, packages=packages
3838
)
3939
return objects
4040

@@ -103,39 +103,25 @@ async def process(
103103
# Extract packages from the user message
104104
ecosystem = await self.__lookup_ecosystem(user_messages, context)
105105
packages = await self.__lookup_packages(user_messages, context)
106-
packages = [pkg.lower() for pkg in packages]
107106

108-
# If user message does not reference any packages, then just return
109-
if len(packages) == 0:
110-
return PipelineResult(request=request)
111-
112-
# Look for matches in vector DB using list of packages as filter
113-
searched_objects = await self.get_objects_from_search(user_messages, ecosystem, packages)
107+
context_str = "CodeGate did not find any malicious or archived packages."
114108

115-
logger.info(
116-
f"Found {len(searched_objects)} matches in the database",
117-
searched_objects=searched_objects,
118-
)
109+
if len(packages) > 0:
110+
# Look for matches in DB using packages and ecosystem
111+
searched_objects = await self.get_objects_from_db(ecosystem, packages)
119112

120-
# Remove searched objects that are not in packages. This is needed
121-
# since Weaviate performs substring match in the filter.
122-
updated_searched_objects = []
123-
for searched_object in searched_objects:
124-
if searched_object.properties["name"].lower() in packages:
125-
updated_searched_objects.append(searched_object)
126-
searched_objects = updated_searched_objects
113+
logger.info(
114+
f"Found {len(searched_objects)} matches in the database",
115+
searched_objects=searched_objects,
116+
)
127117

128-
# Generate context string using the searched objects
129-
logger.info(f"Adding {len(searched_objects)} packages to the context")
118+
# Generate context string using the searched objects
119+
logger.info(f"Adding {len(searched_objects)} packages to the context")
130120

131-
if len(searched_objects) > 0:
132-
context_str = self.generate_context_str(searched_objects, context)
133-
else:
134-
context_str = "CodeGate did not find any malicious or archived packages."
121+
if len(searched_objects) > 0:
122+
context_str = self.generate_context_str(searched_objects, context)
135123

136124
last_user_idx = self.get_last_user_message_idx(request)
137-
if last_user_idx == -1:
138-
return PipelineResult(request=request, context=context)
139125

140126
# Make a copy of the request
141127
new_request = request.copy()

src/codegate/storage/storage_engine.py

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,6 @@ async def search_by_property(self, name: str, properties: List[str]) -> list[obj
121121
return []
122122

123123
try:
124-
125124
packages = self.weaviate_client.collections.get("Package")
126125
response = packages.query.fetch_objects(
127126
filters=Filter.by_property(name).contains_any(properties),
@@ -145,10 +144,19 @@ async def search_by_property(self, name: str, properties: List[str]) -> list[obj
145144
return []
146145

147146
async def search(
148-
self, query: str, limit=5, distance=0.3, ecosystem=None, packages=None
147+
self,
148+
query: str = None,
149+
ecosystem: str = None,
150+
packages: List[str] = None,
151+
limit: int = 5,
152+
distance: float = 0.3,
149153
) -> list[object]:
150154
"""
151-
Search the 'Package' collection based on a query string.
155+
Search the 'Package' collection based on a query string, ecosystem and packages.
156+
If packages and ecosystem are both not none, then filter the objects using them.
157+
If packages is not none and ecosystem is none, then filter the objects using
158+
package names.
159+
If packages is none, then perform vector search.
152160
153161
Args:
154162
query (str): The text query for which to search.
@@ -160,26 +168,40 @@ async def search(
160168
Returns:
161169
list: A list of matching results with their properties and distances.
162170
"""
163-
# Generate the vector for the query
164-
query_vector = await self.inference_engine.embed(self.model_path, [query])
165171

166-
# Perform the vector search
167172
try:
168173
collection = self.weaviate_client.collections.get("Package")
169-
if packages:
170-
# filter by packages and ecosystem if present
171-
filters = []
172-
if ecosystem and ecosystem in VALID_ECOSYSTEMS:
173-
filters.append(wvc.query.Filter.by_property("type").equal(ecosystem))
174-
filters.append(wvc.query.Filter.by_property("name").contains_any(packages))
175-
response = collection.query.near_vector(
176-
query_vector[0],
177-
limit=limit,
178-
distance=distance,
179-
filters=wvc.query.Filter.all_of(filters),
180-
return_metadata=MetadataQuery(distance=True),
174+
175+
response = None
176+
if packages and ecosystem and ecosystem in VALID_ECOSYSTEMS:
177+
response = collection.query.fetch_objects(
178+
filters=wvc.query.Filter.all_of([
179+
wvc.query.Filter.by_property("name").contains_any(packages),
180+
wvc.query.Filter.by_property("type").equal(ecosystem)
181+
]),
181182
)
182-
else:
183+
response.objects = [
184+
obj
185+
for obj in response.objects
186+
if obj.properties["name"].lower() in packages
187+
and obj.properties["type"].lower() == ecosystem.lower()
188+
]
189+
elif packages and not ecosystem:
190+
response = collection.query.fetch_objects(
191+
filters=wvc.query.Filter.all_of([
192+
wvc.query.Filter.by_property("name").contains_any(packages),
193+
]),
194+
)
195+
response.objects = [
196+
obj
197+
for obj in response.objects
198+
if obj.properties["name"].lower() in packages
199+
]
200+
elif query:
201+
# Perform the vector search
202+
# Generate the vector for the query
203+
query_vector = await self.inference_engine.embed(self.model_path, [query])
204+
183205
response = collection.query.near_vector(
184206
query_vector[0],
185207
limit=limit,

tests/test_storage.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ async def test_search(mock_weaviate_client, mock_inference_engine):
5252
storage_engine = StorageEngine.recreate_instance(data_path="./weaviate_data")
5353

5454
# Invoke the search method
55-
results = await storage_engine.search("test query", 5, 0.3)
55+
results = await storage_engine.search(query="test query", limit=5, distance=0.3)
5656

5757
# Assertions to validate the expected behavior
5858
assert len(results) == 1 # Assert that one result is returned

0 commit comments

Comments
 (0)