Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Run filter using the specified packages/ecosystem #434

Merged
merged 1 commit into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/codegate/llm_utils/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ async def extract_packages(
system_prompt = Config.get_config().prompts.lookup_packages

result = await LLMClient.complete(
content=content,
content=content.lower(),
system_prompt=system_prompt,
provider=provider,
model=model,
Expand All @@ -41,6 +41,9 @@ async def extract_packages(

# Handle both formats: {"packages": [...]} and direct list [...]
packages = result if isinstance(result, list) else result.get("packages", [])

# Keep only the extracted packages whose lowercased name appears in the content
packages = [package.lower() for package in packages if package.lower() in content]
logger.info(f"Extracted packages: {packages}")
return packages

Expand Down
44 changes: 15 additions & 29 deletions src/codegate/pipeline/codegate_context_retriever/codegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ def name(self) -> str:
"""
return "codegate-context-retriever"

async def get_objects_from_search(
self, search: str, ecosystem, packages: list[str] = None
async def get_objects_from_db(
self, ecosystem, packages: list[str] = None
) -> list[object]:
storage_engine = StorageEngine()
objects = await storage_engine.search(
search, distance=0.8, ecosystem=ecosystem, packages=packages
distance=0.8, ecosystem=ecosystem, packages=packages
)
return objects

Expand Down Expand Up @@ -103,39 +103,25 @@ async def process(
# Extract packages from the user message
ecosystem = await self.__lookup_ecosystem(user_messages, context)
packages = await self.__lookup_packages(user_messages, context)
packages = [pkg.lower() for pkg in packages]

# If user message does not reference any packages, then just return
if len(packages) == 0:
return PipelineResult(request=request)

# Look for matches in vector DB using list of packages as filter
searched_objects = await self.get_objects_from_search(user_messages, ecosystem, packages)
context_str = "CodeGate did not find any malicious or archived packages."

logger.info(
f"Found {len(searched_objects)} matches in the database",
searched_objects=searched_objects,
)
if len(packages) > 0:
# Look for matches in DB using packages and ecosystem
searched_objects = await self.get_objects_from_db(ecosystem, packages)

# Remove searched objects that are not in packages. This is needed
# since Weaviate performs substring match in the filter.
updated_searched_objects = []
for searched_object in searched_objects:
if searched_object.properties["name"].lower() in packages:
updated_searched_objects.append(searched_object)
searched_objects = updated_searched_objects
logger.info(
f"Found {len(searched_objects)} matches in the database",
searched_objects=searched_objects,
)

# Generate context string using the searched objects
logger.info(f"Adding {len(searched_objects)} packages to the context")
# Generate context string using the searched objects
logger.info(f"Adding {len(searched_objects)} packages to the context")

if len(searched_objects) > 0:
context_str = self.generate_context_str(searched_objects, context)
else:
context_str = "CodeGate did not find any malicious or archived packages."
if len(searched_objects) > 0:
context_str = self.generate_context_str(searched_objects, context)

last_user_idx = self.get_last_user_message_idx(request)
if last_user_idx == -1:
return PipelineResult(request=request, context=context)

# Make a copy of the request
new_request = request.copy()
Expand Down
60 changes: 41 additions & 19 deletions src/codegate/storage/storage_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ async def search_by_property(self, name: str, properties: List[str]) -> list[obj
return []

try:

packages = self.weaviate_client.collections.get("Package")
response = packages.query.fetch_objects(
filters=Filter.by_property(name).contains_any(properties),
Expand All @@ -145,10 +144,19 @@ async def search_by_property(self, name: str, properties: List[str]) -> list[obj
return []

async def search(
self, query: str, limit=5, distance=0.3, ecosystem=None, packages=None
self,
query: str = None,
ecosystem: str = None,
packages: List[str] = None,
limit: int = 5,
distance: float = 0.3,
) -> list[object]:
"""
Search the 'Package' collection based on a query string.
Search the 'Package' collection based on a query string, ecosystem and packages.
If both packages and ecosystem are given, filter the objects by package name
and ecosystem.
If packages is given and ecosystem is None, filter the objects by package name only.
If packages is None and a query is given, perform a vector search.

Args:
query (str): The text query for which to search.
Expand All @@ -160,26 +168,40 @@ async def search(
Returns:
list: A list of matching results with their properties and distances.
"""
# Generate the vector for the query
query_vector = await self.inference_engine.embed(self.model_path, [query])

# Perform the vector search
try:
collection = self.weaviate_client.collections.get("Package")
if packages:
# filter by packages and ecosystem if present
filters = []
if ecosystem and ecosystem in VALID_ECOSYSTEMS:
filters.append(wvc.query.Filter.by_property("type").equal(ecosystem))
filters.append(wvc.query.Filter.by_property("name").contains_any(packages))
response = collection.query.near_vector(
query_vector[0],
limit=limit,
distance=distance,
filters=wvc.query.Filter.all_of(filters),
return_metadata=MetadataQuery(distance=True),

response = None
if packages and ecosystem and ecosystem in VALID_ECOSYSTEMS:
response = collection.query.fetch_objects(
filters=wvc.query.Filter.all_of([
wvc.query.Filter.by_property("name").contains_any(packages),
wvc.query.Filter.by_property("type").equal(ecosystem)
]),
)
else:
response.objects = [
obj
for obj in response.objects
if obj.properties["name"].lower() in packages
and obj.properties["type"].lower() == ecosystem.lower()
]
elif packages and not ecosystem:
response = collection.query.fetch_objects(
filters=wvc.query.Filter.all_of([
wvc.query.Filter.by_property("name").contains_any(packages),
]),
)
response.objects = [
obj
for obj in response.objects
if obj.properties["name"].lower() in packages
]
elif query:
# Perform the vector search
# Generate the vector for the query
query_vector = await self.inference_engine.embed(self.model_path, [query])

response = collection.query.near_vector(
query_vector[0],
limit=limit,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ async def test_search(mock_weaviate_client, mock_inference_engine):
storage_engine = StorageEngine.recreate_instance(data_path="./weaviate_data")

# Invoke the search method
results = await storage_engine.search("test query", 5, 0.3)
results = await storage_engine.search(query="test query", limit=5, distance=0.3)

# Assertions to validate the expected behavior
assert len(results) == 1 # Assert that one result is returned
Expand Down
Loading