This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 92895f4

Merge pull request #272 from stacklok/issue-257

feat: infer ecosystem from the user's query

2 parents 7d0c231 + 9eb48c0, commit 92895f4

File tree: 5 files changed, +71 −9 lines changed

prompts/default.yaml

Lines changed: 9 additions & 0 deletions
@@ -31,6 +31,15 @@ lookup_packages: |
   Assume that a package can be any named entity.
   You MUST RESPOND with a list of packages in JSON FORMAT: {"packages": ["pkg1", "pkg2", ...]}.
 
+lookup_ecosystem: |
+  You are a software expert with knowledge of various programming languages ecosystems.
+  When given a query related to coding or programming tasks, your job is to determine
+  the associated programming language and then infer the corresponding language ecosystem
+  based on the context provided in the query.
+  Valid ecosystems are: pypi (Python), npm (Node.js), maven (Java), crates (Rust), go (golang).
+  If you are not sure or you cannot infer it, please respond with an empty value.
+  You MUST RESPOND with a JSON dictionary on this format: {"ecosystem": "ecosystem_name"}.
+
 secrets_redacted: |
   The files in the context contain sensitive information that has been redacted. Do not warn the user
   about any tokens, passwords or similar sensitive information in the context whose value begins with
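The prompt pins the model to a single JSON object. As a rough, hypothetical illustration of consuming that reply (not part of this commit; the helper name and fallback behaviour are assumptions), a caller could validate the response against the allowed ecosystems:

import json
from typing import Optional

# Ecosystems the prompt declares valid (mirrors the list in prompts/default.yaml).
KNOWN_ECOSYSTEMS = {"pypi", "npm", "maven", "crates", "go"}

def parse_ecosystem_reply(raw: str) -> Optional[str]:
    """Hypothetical helper: parse {"ecosystem": "..."} and drop unknown values."""
    try:
        value = json.loads(raw).get("ecosystem", "") or ""
    except json.JSONDecodeError:
        return None
    value = value.strip().lower()
    return value if value in KNOWN_ECOSYSTEMS else None

print(parse_ecosystem_reply('{"ecosystem": "npm"}'))  # -> npm
print(parse_ecosystem_reply('{"ecosystem": ""}'))     # -> None (model was unsure)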

src/codegate/llm_utils/extractor.py

Lines changed: 26 additions & 0 deletions
@@ -41,3 +41,29 @@ async def extract_packages(
         packages = result if isinstance(result, list) else result.get("packages", [])
         logger.info(f"Extracted packages: {packages}")
         return packages
+
+    @staticmethod
+    async def extract_ecosystem(
+        content: str,
+        provider: str,
+        model: str = None,
+        base_url: Optional[str] = None,
+        api_key: Optional[str] = None,
+    ) -> List[str]:
+        """Extract ecosystem from the given content."""
+        system_prompt = Config.get_config().prompts.lookup_ecosystem
+
+        result = await LLMClient.complete(
+            content=content,
+            system_prompt=system_prompt,
+            provider=provider,
+            model=model,
+            api_key=api_key,
+            base_url=base_url,
+        )
+
+        ecosystem = result if isinstance(result, str) else result.get("ecosystem")
+        if ecosystem:
+            ecosystem = ecosystem.lower()
+        logger.info(f"Extracted ecosystem: {ecosystem}")
+        return ecosystem
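For orientation, a minimal usage sketch of the new static method, assuming a loaded Config and valid provider credentials; the provider, model, and API key below are illustrative placeholders, and the import path is inferred from the file location:

import asyncio

from codegate.llm_utils.extractor import PackageExtractor  # path inferred from the src layout

async def demo() -> None:
    # Placeholder provider settings; substitute a provider CodeGate is configured for.
    ecosystem = await PackageExtractor.extract_ecosystem(
        content="How do I add the requests library to my project?",
        provider="openai",
        model="gpt-4o-mini",
        base_url=None,
        api_key="sk-placeholder",
    )
    # Returns a lowercased ecosystem name such as "pypi", or an empty value if unsure.
    print(ecosystem)

asyncio.run(demo())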

src/codegate/pipeline/codegate_context_retriever/codegate.py

Lines changed: 19 additions & 3 deletions
@@ -30,10 +30,11 @@ def name(self) -> str:
         return "codegate-context-retriever"
 
     async def get_objects_from_search(
-        self, search: str, packages: list[str] = None
+        self, search: str, ecosystem, packages: list[str] = None
     ) -> list[object]:
         storage_engine = StorageEngine()
-        objects = await storage_engine.search(search, distance=0.8, packages=packages)
+        objects = await storage_engine.search(
+            search, distance=0.8, ecosystem=ecosystem, packages=packages)
         return objects
 
     def generate_context_str(self, objects: list[object], context: PipelineContext) -> str:

@@ -69,6 +70,19 @@ async def __lookup_packages(self, user_query: str, context: PipelineContext):
         logger.info(f"Packages in user query: {packages}")
         return packages
 
+    async def __lookup_ecosystem(self, user_query: str, context: PipelineContext):
+        # Use PackageExtractor to extract ecosystem from the user query
+        ecosystem = await PackageExtractor.extract_ecosystem(
+            content=user_query,
+            provider=context.sensitive.provider,
+            model=context.sensitive.model,
+            api_key=context.sensitive.api_key,
+            base_url=context.sensitive.api_base,
+        )
+
+        logger.info(f"Ecosystem in user query: {ecosystem}")
+        return ecosystem
+
     async def process(
         self, request: ChatCompletionRequest, context: PipelineContext
     ) -> PipelineResult:

@@ -85,6 +99,7 @@ async def process(
 
         # Extract packages from the user message
         last_user_message_str, last_user_idx = last_user_message
+        ecosystem = await self.__lookup_ecosystem(last_user_message_str, context)
         packages = await self.__lookup_packages(last_user_message_str, context)
         packages = [pkg.lower() for pkg in packages]
 

@@ -93,7 +108,8 @@
             return PipelineResult(request=request)
 
         # Look for matches in vector DB using list of packages as filter
-        searched_objects = await self.get_objects_from_search(last_user_message_str, packages)
+        searched_objects = await self.get_objects_from_search(
+            last_user_message_str, ecosystem, packages)
 
         logger.info(
             f"Found {len(searched_objects)} matches in the database",

src/codegate/storage/storage_engine.py

Lines changed: 13 additions & 2 deletions
@@ -11,6 +11,7 @@
 from codegate.inference.inference_engine import LlamaCppInferenceEngine
 
 logger = structlog.get_logger("codegate")
+VALID_ECOSYSTEMS = ["npm", "pypi", "crates", "maven", "go"]
 
 schema_config = [
     {

@@ -120,6 +121,7 @@ async def search_by_property(self, name: str, properties: List[str]) -> list[obj
             return []
 
         try:
+
             packages = self.weaviate_client.collections.get("Package")
             response = packages.query.fetch_objects(
                 filters=Filter.by_property(name).contains_any(properties),

@@ -142,13 +144,17 @@ async def search_by_property(self, name: str, properties: List[str]) -> list[obj
             logger.error(f"An error occurred: {str(e)}")
             return []
 
-    async def search(self, query: str, limit=5, distance=0.3, packages=None) -> list[object]:
+    async def search(self, query: str, limit=5, distance=0.3,
+                     ecosystem=None, packages=None) -> list[object]:
         """
         Search the 'Package' collection based on a query string.
 
         Args:
             query (str): The text query for which to search.
             limit (int): The number of results to return.
+            distance (float): The distance threshold for the search.
+            ecosystem (str): The ecosystem to search in.
+            packages (list): The list of packages to filter the search.
 
         Returns:
             list: A list of matching results with their properties and distances.

@@ -160,11 +166,16 @@ async def search(self, query: str, limit=5, distance=0.3, packages=None) -> list
         try:
             collection = self.weaviate_client.collections.get("Package")
             if packages:
+                # filter by packages and ecosystem if present
+                filters = []
+                if ecosystem and ecosystem in VALID_ECOSYSTEMS:
+                    filters.append(wvc.query.Filter.by_property("type").equal(ecosystem))
+                filters.append(wvc.query.Filter.by_property("name").contains_any(packages))
                 response = collection.query.near_vector(
                     query_vector[0],
                     limit=limit,
                     distance=distance,
-                    filters=wvc.query.Filter.by_property("name").contains_any(packages),
+                    filters=wvc.query.Filter.all_of(filters),
                     return_metadata=MetadataQuery(distance=True),
                 )
             else:
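For readers less familiar with the Weaviate v4 Python client, here is a small standalone sketch of the filter logic introduced above; the property names ("type", "name") follow the Package schema used in this file, and the example values are made up:

import weaviate.classes as wvc

VALID_ECOSYSTEMS = ["npm", "pypi", "crates", "maven", "go"]

def build_package_filters(ecosystem, packages):
    """Mirror of the filter construction in StorageEngine.search()."""
    filters = []
    if ecosystem and ecosystem in VALID_ECOSYSTEMS:
        # "type" holds the ecosystem name in the Package schema.
        filters.append(wvc.query.Filter.by_property("type").equal(ecosystem))
    filters.append(wvc.query.Filter.by_property("name").contains_any(packages))
    # all_of ANDs the individual filters; with a single entry it is a plain name filter,
    # so an empty or invalid ecosystem falls back to the previous behaviour.
    return wvc.query.Filter.all_of(filters)

# Example: restrict a near_vector query to npm packages named "express" or "lodash".
f = build_package_filters("npm", ["express", "lodash"])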

tests/test_cli.py

Lines changed: 4 additions & 4 deletions
@@ -74,7 +74,7 @@ def test_serve_default_options(
         "port": 8989,
         "log_level": "INFO",
         "log_format": "JSON",
-        "prompts_loaded": 6,
+        "prompts_loaded": 7,
         "provider_urls": DEFAULT_PROVIDER_URLS,
     }
 

@@ -123,7 +123,7 @@ def test_serve_custom_options(
         "port": 8989,
         "log_level": "DEBUG",
         "log_format": "TEXT",
-        "prompts_loaded": 6,  # Default prompts are loaded
+        "prompts_loaded": 7,  # Default prompts are loaded
         "provider_urls": DEFAULT_PROVIDER_URLS,
     }
 

@@ -170,7 +170,7 @@ def test_serve_with_config_file(
         "port": 8989,
         "log_level": "DEBUG",
         "log_format": "JSON",
-        "prompts_loaded": 6,  # Default prompts are loaded
+        "prompts_loaded": 7,  # Default prompts are loaded
         "provider_urls": DEFAULT_PROVIDER_URLS,
     }
 

@@ -229,7 +229,7 @@ def test_serve_priority_resolution(
         "port": 8080,
         "log_level": "ERROR",
         "log_format": "TEXT",
-        "prompts_loaded": 6,  # Default prompts are loaded
+        "prompts_loaded": 7,  # Default prompts are loaded
         "provider_urls": DEFAULT_PROVIDER_URLS,
     }
 
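The expected prompts_loaded count moves from 6 to 7 because default.yaml now defines the extra lookup_ecosystem prompt. A quick, hypothetical way to sanity-check the count, assuming prompts_loaded simply reflects the top-level keys of the default prompts file:

import yaml  # requires PyYAML

with open("prompts/default.yaml") as f:
    prompts = yaml.safe_load(f)

# After this commit the file should define 7 prompts, including lookup_ecosystem.
print(len(prompts), sorted(prompts))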
