Extract packages in both input and output using the LLM the user called with #214
src/codegate/llm_utils/__init__.py
@@ -0,0 +1,4 @@
from src.codegate.llm_utils.extractor import PackageExtractor
from src.codegate.llm_utils.llmclient import LLMClient

__all__ = ["LLMClient", "PackageExtractor"]

Review comment: Same comment as above re: "src".
Reply: will fix in a follow up
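
Per the review thread, the `src.` prefix is slated to be dropped in a follow-up. A sketch of what this `__init__.py` would presumably look like after that fix (not part of this PR's diff):

```python
# Sketch of the follow-up suggested in review: package-relative imports, no "src." prefix.
from codegate.llm_utils.extractor import PackageExtractor
from codegate.llm_utils.llmclient import LLMClient

__all__ = ["LLMClient", "PackageExtractor"]
```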
src/codegate/llm_utils/extractor.py
@@ -0,0 +1,39 @@
from typing import List, Optional

import structlog

from codegate.config import Config
from codegate.llm_utils.llmclient import LLMClient

logger = structlog.get_logger("codegate")


class PackageExtractor:
    """
    Utility class to extract package names from code or queries.
    """

    @staticmethod
    async def extract_packages(
        content: str,
        provider: str,
        model: Optional[str] = None,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
    ) -> List[str]:
        """Extract package names from the given content."""
        system_prompt = Config.get_config().prompts.lookup_packages

        result = await LLMClient.complete(
            content=content,
            system_prompt=system_prompt,
            provider=provider,
            model=model,
            api_key=api_key,
            base_url=base_url,
        )

        # Handle both formats: {"packages": [...]} and a direct list [...]
        packages = result if isinstance(result, list) else result.get("packages", [])
        logger.info(f"Extracted packages: {packages}")
        return packages
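
For reference, a minimal sketch of how the extractor might be invoked; the provider and model values below are illustrative placeholders, not taken from this PR:

```python
import asyncio

from codegate.llm_utils.extractor import PackageExtractor


async def main() -> None:
    # The real caller passes along the provider/model the user's own request used,
    # per the PR title; the values here are hypothetical.
    packages = await PackageExtractor.extract_packages(
        content="pip install requests flask",
        provider="llamacpp",  # routed to the local llama.cpp engine by LLMClient
        model="qwen2.5-coder-1.5b-instruct",  # hypothetical local model name
    )
    print(packages)  # e.g. ["requests", "flask"]


asyncio.run(main())
```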
src/codegate/llm_utils/llmclient.py
@@ -0,0 +1,136 @@
import json
from typing import Any, Dict, Optional

import structlog
from litellm import acompletion

from codegate.config import Config
from codegate.inference import LlamaCppInferenceEngine

logger = structlog.get_logger("codegate")

class LLMClient:
    """
    Base class for LLM interactions handling both local and cloud providers.

    This is a kludge before we refactor our providers a bit to be able to pass
    in all the parameters we need.
    """

Review comment: This class only handles chat requests currently. Do we need to implement code-completion support in it eventually?
Reply: Yes.
Reply: For FIM pipelines we will.

    @staticmethod
    async def complete(
        content: str,
        system_prompt: str,
        provider: str,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Send a completion request to either local or cloud LLM.

        Args:
            content: The user message content
            system_prompt: The system prompt to use
            provider: "local" or "litellm"
            model: Model identifier
            api_key: API key for cloud providers
            base_url: Base URL for cloud providers
            **kwargs: Additional arguments for the completion request

        Returns:
            Parsed response from the LLM
        """
        if provider == "llamacpp":
            return await LLMClient._complete_local(content, system_prompt, model, **kwargs)
        return await LLMClient._complete_litellm(
            content,
            system_prompt,
            provider,
            model,
            api_key,
            base_url,
            **kwargs,
        )

    @staticmethod
    async def _create_request(
        content: str, system_prompt: str, model: str, **kwargs
    ) -> Dict[str, Any]:
        """
        Private method to create a request dictionary for LLM completion.
        """
        return {
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": content},
            ],
            "model": model,
            "stream": False,
            "response_format": {"type": "json_object"},
            "temperature": kwargs.get("temperature", 0),
        }

    @staticmethod
    async def _complete_local(
        content: str,
        system_prompt: str,
        model: str,
        **kwargs,
    ) -> Dict[str, Any]:
        # Use the private method to create the request
        request = await LLMClient._create_request(content, system_prompt, model, **kwargs)

        inference_engine = LlamaCppInferenceEngine()
        result = await inference_engine.chat(
            f"{Config.get_config().model_base_path}/{request['model']}.gguf",
            n_ctx=Config.get_config().chat_model_n_ctx,
            n_gpu_layers=Config.get_config().chat_model_n_gpu_layers,
            **request,
        )

        return json.loads(result["choices"][0]["message"]["content"])

    @staticmethod
    async def _complete_litellm(
        content: str,
        system_prompt: str,
        provider: str,
        model: str,
        api_key: Optional[str],
        base_url: Optional[str] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        # Use the private method to create the request
        request = await LLMClient._create_request(content, system_prompt, model, **kwargs)

        # We should reuse the same logic in the provider
        # but let's do that later
        if provider == "vllm":
            if not base_url.endswith("/v1"):
                base_url = f"{base_url}/v1"
        else:
            model = f"{provider}/{model}"

        try:
            response = await acompletion(
                model=model,
                messages=request["messages"],
                api_key=api_key,
                temperature=request["temperature"],
                base_url=base_url,
            )

            content = response["choices"][0]["message"]["content"]

            # Clean up code blocks if present
            if content.startswith("```"):
                content = content.split("\n", 1)[1].rsplit("```", 1)[0].strip()

            return json.loads(content)

        except Exception as e:
            logger.error(f"LiteLLM completion failed: {e}")
            return {}
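
And a minimal sketch of calling the client directly against a cloud provider; the provider, model, prompt, and key below are illustrative only:

```python
import asyncio

from codegate.llm_utils.llmclient import LLMClient


async def main() -> None:
    # Any provider other than "llamacpp" is routed through litellm's acompletion(),
    # with the model prefixed as "<provider>/<model>" (except for vllm).
    result = await LLMClient.complete(
        content="import numpy as np\nimport pandas as pd",
        system_prompt="Return a JSON object listing the packages used.",  # illustrative prompt
        provider="openai",    # hypothetical provider
        model="gpt-4o-mini",  # hypothetical model
        api_key="sk-...",     # placeholder
    )
    print(result)  # parsed JSON dict, or {} if the litellm call raised


asyncio.run(main())
```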
Review comment: We should not use "src" in the package name here. It should be: "from codegate.storage.utils ..."