This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Extract packages in both input and output using the LLM the user called with #214

Merged: 1 commit, Dec 5, 2024

30 changes: 19 additions & 11 deletions scripts/import_packages.py
@@ -3,7 +3,6 @@
 import os
 import shutil
-
 
 import weaviate
 from weaviate.classes.config import DataType, Property
 from weaviate.embedded import EmbeddedOptions
@@ -17,10 +16,12 @@ class PackageImporter:
     def __init__(self):
         self.client = weaviate.WeaviateClient(
             embedded_options=EmbeddedOptions(
-                persistence_data_path="./weaviate_data", grpc_port=50052,
-                additional_env_vars={"ENABLE_MODULES": "backup-filesystem",
-                                     "BACKUP_FILESYSTEM_PATH": os.getenv("BACKUP_FILESYSTEM_PATH",
-                                                                         "/tmp")}
+                persistence_data_path="./weaviate_data",
+                grpc_port=50052,
+                additional_env_vars={
+                    "ENABLE_MODULES": "backup-filesystem",
+                    "BACKUP_FILESYSTEM_PATH": os.getenv("BACKUP_FILESYSTEM_PATH", "/tmp"),
+                },
             )
         )
         self.json_files = [
@@ -35,21 +36,28 @@ def __init__(self):
     def restore_backup(self):
         if os.getenv("BACKUP_FOLDER"):
             try:
-                self.client.backup.restore(backup_id=os.getenv("BACKUP_FOLDER"),
-                                           backend="filesystem", wait_for_completion=True)
+                self.client.backup.restore(
+                    backup_id=os.getenv("BACKUP_FOLDER"),
+                    backend="filesystem",
+                    wait_for_completion=True,
+                )
             except Exception as e:
                 print(f"Failed to restore backup: {e}")
 
     def take_backup(self):
         # if backup folder exists, remove it
-        backup_path = os.path.join(os.getenv("BACKUP_FILESYSTEM_PATH", "/tmp"),
-                                   os.getenv("BACKUP_TARGET_ID", "backup"))
+        backup_path = os.path.join(
+            os.getenv("BACKUP_FILESYSTEM_PATH", "/tmp"), os.getenv("BACKUP_TARGET_ID", "backup")
+        )
         if os.path.exists(backup_path):
             shutil.rmtree(backup_path)
 
         # take a backup of the data
-        self.client.backup.create(backup_id=os.getenv("BACKUP_TARGET_ID", "backup"),
-                                  backend="filesystem", wait_for_completion=True)
+        self.client.backup.create(
+            backup_id=os.getenv("BACKUP_TARGET_ID", "backup"),
+            backend="filesystem",
+            wait_for_completion=True,
+        )
 
     def setup_schema(self):
         if not self.client.collections.exists("Package"):
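
For orientation, a rough sketch of how the backup hooks above might be exercised end to end. The environment variable names come from the diff; the import path and call order are assumptions, since the rest of the script is not shown in this PR.

import os

# Env vars must be set before PackageImporter() is built, because
# BACKUP_FILESYSTEM_PATH is read inside __init__ when embedded Weaviate starts.
os.environ["BACKUP_FILESYSTEM_PATH"] = "/tmp"
os.environ["BACKUP_FOLDER"] = "backup"        # restore from /tmp/backup if present
os.environ["BACKUP_TARGET_ID"] = "backup"     # where take_backup() writes

from scripts.import_packages import PackageImporter  # assumed import path

importer = PackageImporter()   # starts embedded Weaviate with the backup-filesystem module
importer.restore_backup()      # restores only when BACKUP_FOLDER is set
importer.setup_schema()        # creates the "Package" collection if missing
importer.take_backup()         # removes any stale backup dir, then writes a fresh one
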
2 changes: 1 addition & 1 deletion src/codegate/cli.py
@@ -5,13 +5,13 @@
 from typing import Dict, Optional
 
 import click
-from src.codegate.storage.utils import restore_storage_backup
 import structlog
 
 from codegate.codegate_logging import LogFormat, LogLevel, setup_logging
 from codegate.config import Config, ConfigurationError
 from codegate.db.connection import init_db_sync
 from codegate.server import init_app
+from src.codegate.storage.utils import restore_storage_backup
Contributor: We should not use "src" in the package name here. It should be: "from codegate.storage.utils ..."



def validate_port(ctx: click.Context, param: click.Parameter, value: int) -> int:
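
Per the review comment above, the intended import drops the src prefix; a one-line sketch of the suggested follow-up (not part of this PR):

from codegate.storage.utils import restore_storage_backup
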
4 changes: 4 additions & 0 deletions src/codegate/llm_utils/__init__.py
@@ -0,0 +1,4 @@
from src.codegate.llm_utils.extractor import PackageExtractor
from src.codegate.llm_utils.llmclient import LLMClient
Contributor: Same comment as above re: "src".

Author: will fix in a follow up


__all__ = ["LLMClient", "PackageExtractor"]
39 changes: 39 additions & 0 deletions src/codegate/llm_utils/extractor.py
@@ -0,0 +1,39 @@
from typing import List, Optional

import structlog

from codegate.config import Config
from codegate.llm_utils.llmclient import LLMClient

logger = structlog.get_logger("codegate")


class PackageExtractor:
"""
Utility class to extract package names from code or queries.
"""

@staticmethod
async def extract_packages(
content: str,
provider: str,
model: str = None,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
) -> List[str]:
"""Extract package names from the given content."""
system_prompt = Config.get_config().prompts.lookup_packages

result = await LLMClient.complete(
content=content,
system_prompt=system_prompt,
provider=provider,
model=model,
api_key=api_key,
base_url=base_url,
)

# Handle both formats: {"packages": [...]} and direct list [...]
packages = result if isinstance(result, list) else result.get("packages", [])
logger.info(f"Extracted packages: {packages}")
return packages
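
For illustration, a minimal sketch of calling the new extractor. The "llamacpp" provider value matches the local branch in LLMClient below; the model name and input text are placeholders.

import asyncio

from codegate.llm_utils.extractor import PackageExtractor


async def main():
    packages = await PackageExtractor.extract_packages(
        content="pip install requests and numpy",  # placeholder user message
        provider="llamacpp",                        # routes to the local llama.cpp engine
        model="qwen2.5-coder-1.5b",                 # placeholder model name
    )
    print(packages)  # e.g. ["requests", "numpy"]


asyncio.run(main())
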
136 changes: 136 additions & 0 deletions src/codegate/llm_utils/llmclient.py
@@ -0,0 +1,136 @@
import json
from typing import Any, Dict, Optional

import structlog
from litellm import acompletion

from codegate.config import Config
from codegate.inference import LlamaCppInferenceEngine

logger = structlog.get_logger("codegate")


class LLMClient:
"""
Base class for LLM interactions handling both local and cloud providers.
Contributor: This class only handles chat requests currently. Do we need to implement code-completion support in it eventually?

Author: Yes

Author: For FIM pipelines we will

    This is a kludge before we refactor our providers a bit to be able to pass
    in all the parameters we need.
    """

    @staticmethod
    async def complete(
        content: str,
        system_prompt: str,
        provider: str,
        model: str = None,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Send a completion request to either local or cloud LLM.

        Args:
            content: The user message content
            system_prompt: The system prompt to use
            provider: "local" or "litellm"
            model: Model identifier
            api_key: API key for cloud providers
            base_url: Base URL for cloud providers
            **kwargs: Additional arguments for the completion request

        Returns:
            Parsed response from the LLM
        """
        if provider == "llamacpp":
            return await LLMClient._complete_local(content, system_prompt, model, **kwargs)
        return await LLMClient._complete_litellm(
            content,
            system_prompt,
            provider,
            model,
            api_key,
            base_url,
            **kwargs,
        )

    @staticmethod
    async def _create_request(
        content: str, system_prompt: str, model: str, **kwargs
    ) -> Dict[str, Any]:
        """
        Private method to create a request dictionary for LLM completion.
        """
        return {
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": content},
            ],
            "model": model,
            "stream": False,
            "response_format": {"type": "json_object"},
            "temperature": kwargs.get("temperature", 0),
        }

    @staticmethod
    async def _complete_local(
        content: str,
        system_prompt: str,
        model: str,
        **kwargs,
    ) -> Dict[str, Any]:
        # Use the private method to create the request
        request = await LLMClient._create_request(content, system_prompt, model, **kwargs)

        inference_engine = LlamaCppInferenceEngine()
        result = await inference_engine.chat(
            f"{Config.get_config().model_base_path}/{request['model']}.gguf",
            n_ctx=Config.get_config().chat_model_n_ctx,
            n_gpu_layers=Config.get_config().chat_model_n_gpu_layers,
            **request,
        )

        return json.loads(result["choices"][0]["message"]["content"])

    @staticmethod
    async def _complete_litellm(
        content: str,
        system_prompt: str,
        provider: str,
        model: str,
        api_key: str,
        base_url: Optional[str] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        # Use the private method to create the request
        request = await LLMClient._create_request(content, system_prompt, model, **kwargs)

        # We should reuse the same logic in the provider
        # but let's do that later
        if provider == "vllm":
            if not base_url.endswith("/v1"):
                base_url = f"{base_url}/v1"
        else:
            model = f"{provider}/{model}"

        try:
            response = await acompletion(
                model=model,
                messages=request["messages"],
                api_key=api_key,
                temperature=request["temperature"],
                base_url=base_url,
            )

            content = response["choices"][0]["message"]["content"]

            # Clean up code blocks if present
            if content.startswith("```"):
                content = content.split("\n", 1)[1].rsplit("```", 1)[0].strip()

            return json.loads(content)

        except Exception as e:
            logger.error(f"LiteLLM completion failed: {e}")
            return {}
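
A small, self-contained sketch of the provider routing implemented in _complete_litellm above. Only the llamacpp/litellm split, the provider-prefixed model string, and the vllm "/v1" suffix come from the code; the helper name and example values are placeholders.

from typing import Optional, Tuple


def normalize_target(provider: str, model: str, base_url: Optional[str]) -> Tuple[str, Optional[str]]:
    """Mirrors the branching in LLMClient._complete_litellm, for illustration only."""
    if provider == "vllm":
        # vLLM serves an OpenAI-compatible API under /v1
        if base_url and not base_url.endswith("/v1"):
            base_url = f"{base_url}/v1"
    else:
        # litellm expects "<provider>/<model>" for hosted providers
        model = f"{provider}/{model}"
    return model, base_url


print(normalize_target("openai", "gpt-4o-mini", None))
# -> ('openai/gpt-4o-mini', None)
print(normalize_target("vllm", "mistral-7b", "http://localhost:8000"))
# -> ('mistral-7b', 'http://localhost:8000/v1')
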
33 changes: 29 additions & 4 deletions src/codegate/pipeline/base.py
@@ -26,9 +26,10 @@ class CodeSnippet:
         code: The actual code content
     """
 
-    code: str
     language: Optional[str]
     filepath: Optional[str]
+    code: str
+    libraries: List[str] = field(default_factory=list)
 
     def __post_init__(self):
         if self.language is not None:
@@ -44,6 +45,10 @@ class AlertSeverity(Enum):
 class PipelineSensitiveData:
     manager: SecretsManager
     session_id: str
+    api_key: Optional[str] = None
+    model: Optional[str] = None
+    provider: Optional[str] = None
+    api_base: Optional[str] = None
 
     def secure_cleanup(self):
         """Securely cleanup sensitive data for this session"""
@@ -53,6 +58,14 @@ def secure_cleanup(self):
         self.manager.cleanup_session(self.session_id)
         self.session_id = ""
 
+        # Securely wipe the API key using the same method as secrets manager
+        if self.api_key is not None:
+            api_key_bytes = bytearray(self.api_key.encode())
+            self.manager.crypto.wipe_bytearray(api_key_bytes)
+            self.api_key = None
+
+        self.model = None
+
 
 @dataclass
 class PipelineContext:
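
For background on the cleanup added above: Python strings are immutable, so the key is copied into a mutable bytearray that can be overwritten in place. A standalone sketch of the idea, with wipe_bytearray approximated since its implementation is not part of this diff:

def wipe_bytearray(data: bytearray) -> None:
    """Rough stand-in for the secrets manager's wipe: zero the buffer in place."""
    for i in range(len(data)):
        data[i] = 0


api_key = "sk-placeholder"                   # placeholder secret
api_key_bytes = bytearray(api_key.encode())  # mutable copy of the key
wipe_bytearray(api_key_bytes)                # zero the copy in place
api_key = None                               # drop the reference to the original str
print(api_key_bytes)                         # bytearray(b'\x00\x00...')

Note that the original str object is still reclaimed by garbage collection rather than explicitly erased, so this is a best-effort measure.
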
@@ -202,21 +215,33 @@ def __init__(self, pipeline_steps: List[PipelineStep]):
         self.pipeline_steps = pipeline_steps
 
     async def process_request(
-        self, secret_manager: SecretsManager, request: ChatCompletionRequest, prompt_id: str
+        self,
+        secret_manager: SecretsManager,
+        request: ChatCompletionRequest,
+        provider: str,
+        prompt_id: str,
+        model: str,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> PipelineResult:
         """
         Process a request through all pipeline steps
 
         Args:
-            request: The chat completion request to process
             secret_manager: The secrets manager instance to gather sensitive data from the request
+            request: The chat completion request to process
 
         Returns:
             PipelineResult containing either a modified request or response structure
         """
         context = PipelineContext()
         context.sensitive = PipelineSensitiveData(
-            manager=secret_manager, session_id=str(uuid.uuid4())
+            manager=secret_manager,
+            session_id=str(uuid.uuid4()),
+            api_key=api_key,
+            model=model,
+            provider=provider,
+            api_base=api_base,
         )  # Generate a new session ID for each request
         context.metadata["prompt_id"] = prompt_id
         current_request = request
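
To make the new signature concrete, a hedged sketch of a call site. The processor and request objects are placeholders (the processor class name is not shown in this diff); only the parameter list comes from the change above.

from typing import Optional


async def run_pipeline(
    processor,           # pipeline-processor object built with the desired steps
    secrets_manager,     # SecretsManager instance
    chat_request,        # ChatCompletionRequest from the client
    prompt_id: str,
    user_api_key: Optional[str],
):
    # Provider and model values below are placeholders, not taken from the PR.
    return await processor.process_request(
        secret_manager=secrets_manager,
        request=chat_request,
        provider="openai",          # which LLM the user called
        prompt_id=prompt_id,
        model="gpt-4o-mini",        # placeholder model identifier
        api_key=user_api_key,       # stored in PipelineSensitiveData, wiped in secure_cleanup()
        api_base=None,
    )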