This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Modify storage client to singleton pattern #223

Merged · 1 commit · Dec 9, 2024
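This PR converts StorageEngine from a construct-per-consumer class into a process-wide singleton: `__new__` caches the first instance, an `initialized` guard in `__init__` prevents setup from re-running on the cached object, and a `recreate_instance()` classmethod resets the cache so unit tests can swap in mocks. It also replaces the hardcoded embedding-model path with two config values, `model_base_path` and the new `embedding_model`. A minimal, self-contained sketch of the `__new__`-based idiom the diff applies (illustrative names, not the actual class):

```python
class Singleton:
    _instance = None

    def __new__(cls, *args, **kwargs):
        # Cache the first instance; every later construction returns it.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Python still calls __init__ on every construction, so guard
        # against re-running expensive setup on the cached instance.
        if hasattr(self, "initialized"):
            return
        self.initialized = True


a = Singleton()
b = Singleton()
assert a is b  # both constructions yield the same object
```

The guard matters because `__new__` returning an existing instance does not stop Python from calling `__init__` again; the diff below uses exactly this `hasattr(self, "initialized")` check in `StorageEngine.__init__`.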
4 changes: 3 additions & 1 deletion config.yaml
@@ -19,11 +19,13 @@ log_level: "INFO" # One of: ERROR, WARNING, INFO, DEBUG
 ##
 
 # Model to use for chatting
-chat_model_path: "./models"
+model_base_path: "./models"
 
 # Context length of the model
 chat_model_n_ctx: 32768
 
 # Number of layers to offload to GPU. If -1, all layers are offloaded.
 chat_model_n_gpu_layers: -1
 
+# Embedding model
+embedding_model: "all-minilm-L6-v2-q5_k_m.gguf"
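The rename from `chat_model_path` to `model_base_path` reflects that the directory now holds both the chat model and the embedding model: the new `embedding_model` key names a GGUF file that `StorageEngine` resolves relative to `model_base_path` instead of a hardcoded `./models/all-minilm-L6-v2-q5_k_m.gguf`.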
4 changes: 3 additions & 1 deletion src/codegate/config.py
@@ -40,6 +40,7 @@ class Config:
     model_base_path: str = "./models"
     chat_model_n_ctx: int = 32768
     chat_model_n_gpu_layers: int = -1
+    embedding_model: str = "all-minilm-L6-v2-q5_k_m.gguf"
 
     # Provider URLs with defaults
     provider_urls: Dict[str, str] = field(default_factory=lambda: DEFAULT_PROVIDER_URLS.copy())
@@ -117,11 +118,12 @@ def from_file(cls, config_path: Union[str, Path]) -> "Config":
             host=config_data.get("host", cls.host),
             log_level=config_data.get("log_level", cls.log_level.value),
             log_format=config_data.get("log_format", cls.log_format.value),
-            model_base_path=config_data.get("chat_model_path", cls.model_base_path),
+            model_base_path=config_data.get("model_base_path", cls.model_base_path),
             chat_model_n_ctx=config_data.get("chat_model_n_ctx", cls.chat_model_n_ctx),
             chat_model_n_gpu_layers=config_data.get(
                 "chat_model_n_gpu_layers", cls.chat_model_n_gpu_layers
             ),
+            embedding_model=config_data.get("embedding_model", cls.embedding_model),
             prompts=prompts_config,
             provider_urls=provider_urls,
         )
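For orientation, a short sketch of how the two settings added above are consumed; it mirrors the `model_path` construction in `StorageEngine.__init__` later in this diff and assumes `Config.load()` has already been called:

```python
from codegate.config import Config

config = Config.get_config()
# StorageEngine resolves the embedding model file relative to
# model_base_path instead of hardcoding the full path.
model_path = f"{config.model_base_path}/{config.embedding_model}"
# With the defaults: "./models/all-minilm-L6-v2-q5_k_m.gguf"
```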
6 changes: 2 additions & 4 deletions src/codegate/pipeline/codegate_context_retriever/codegate.py
@@ -20,9 +20,6 @@ class CodegateContextRetriever(PipelineStep):
     the word "codegate" in the user message.
     """
 
-    def __init__(self):
-        self.storage_engine = StorageEngine()
-
     @property
     def name(self) -> str:
         """
@@ -33,7 +30,8 @@ def name(self) -> str:
     async def get_objects_from_search(
         self, search: str, packages: list[str] = None
     ) -> list[object]:
-        objects = await self.storage_engine.search(search, distance=0.8, packages=packages)
+        storage_engine = StorageEngine()
+        objects = await storage_engine.search(search, distance=0.8, packages=packages)
        return objects
 
     def generate_context_str(self, objects: list[object]) -> str:
6 changes: 2 additions & 4 deletions src/codegate/pipeline/extract_snippets/output.py
@@ -16,9 +16,6 @@
 class CodeCommentStep(OutputPipelineStep):
     """Pipeline step that adds comments after code blocks"""
 
-    def __init__(self):
-        self._storage_engine = StorageEngine()
-
     @property
     def name(self) -> str:
         return "code-comment"
@@ -52,7 +49,8 @@ async def _snippet_comment(self, snippet: CodeSnippet, secrets: PipelineSensitiv
             base_url=secrets.api_base,
         )
 
-        libobjects = await self._storage_engine.search_by_property("name", snippet.libraries)
+        storage_engine = StorageEngine()
+        libobjects = await storage_engine.search_by_property("name", snippet.libraries)
         logger.info(f"Found {len(libobjects)} libraries in the storage engine")
 
         libraries_text = ""
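In both pipeline steps, the engine attribute set in `__init__` is replaced by constructing `StorageEngine()` at the call site. With the singleton `__new__` introduced below, every such construction returns the same shared engine, so this is a cheap cache lookup rather than a new Weaviate client per step.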
96 changes: 48 additions & 48 deletions src/codegate/storage/storage_engine.py
@@ -26,22 +26,60 @@
 
 
 class StorageEngine:
-    def get_client(self, data_path):
+    __storage_engine = None
+
+    def __new__(cls, *args, **kwargs):
+        if cls.__storage_engine is None:
+            cls.__storage_engine = super().__new__(cls)
+        return cls.__storage_engine
+
+    # This function is needed only for the unit testing for the
+    # mocks to work.
+    @classmethod
+    def recreate_instance(cls, *args, **kwargs):
+        cls.__storage_engine = None
+        return cls(*args, **kwargs)
+
+    def __init__(self, data_path="./weaviate_data"):
+        if hasattr(self, "initialized"):
+            return
+
+        self.initialized = True
+        self.data_path = data_path
+        self.inference_engine = LlamaCppInferenceEngine()
+        self.model_path = (
+            f"{Config.get_config().model_base_path}/{Config.get_config().embedding_model}"
+        )
+        self.schema_config = schema_config
+
+        # setup schema for weaviate
+        self.weaviate_client = self.get_client(self.data_path)
+        if self.weaviate_client is not None:
+            try:
+                self.weaviate_client.connect()
+                self.setup_schema(self.weaviate_client)
+            except Exception as e:
+                logger.error(f"Failed to connect or setup schema: {str(e)}")
+
+    def __del__(self):
         try:
-            # Get current config
-            config = Config.get_config()
+            self.weaviate_client.close()
+        except Exception as e:
+            logger.error(f"Failed to close client: {str(e)}")
+
+    def get_client(self, data_path):
+        try:
             # Configure Weaviate logging
             additional_env_vars = {
                 # Basic logging configuration
-                "LOG_FORMAT": config.log_format.value.lower(),
-                "LOG_LEVEL": config.log_level.value.lower(),
+                "LOG_FORMAT": Config.get_config().log_format.value.lower(),
+                "LOG_LEVEL": Config.get_config().log_level.value.lower(),
                 # Disable colored output
                 "LOG_FORCE_COLOR": "false",
                 # Configure JSON format
                 "LOG_JSON_FIELDS": "timestamp, level,message",
                 # Configure text format
-                "LOG_METHOD": config.log_format.value.lower(),
+                "LOG_METHOD": Config.get_config().log_format.value.lower(),
                 "LOG_LEVEL_IN_UPPER": "false",  # Keep level lowercase like codegate format
                 # Disable additional fields
                 "LOG_GIT_HASH": "false",
@@ -60,28 +98,6 @@ def get_client(self, data_path):
             logger.error(f"Error during client creation: {str(e)}")
             return None
 
-    def __init__(self, data_path="./weaviate_data"):
-        self.data_path = data_path
-        self.inference_engine = LlamaCppInferenceEngine()
-        self.model_path = "./models/all-minilm-L6-v2-q5_k_m.gguf"
-        self.schema_config = schema_config
-
-        # setup schema for weaviate
-        weaviate_client = self.get_client(self.data_path)
-        if weaviate_client is not None:
-            try:
-                weaviate_client.connect()
-                self.setup_schema(weaviate_client)
-            except Exception as e:
-                logger.error(f"Failed to connect or setup schema: {str(e)}")
-            finally:
-                try:
-                    weaviate_client.close()
-                except Exception as e:
-                    logger.error(f"Failed to close client: {str(e)}")
-        else:
-            logger.error("Could not find client, skipping schema setup.")
-
     def setup_schema(self, client):
         for class_config in self.schema_config:
             if not client.collections.exists(class_config["name"]):
@@ -95,18 +111,16 @@ async def search_by_property(self, name: str, properties: List[str]) -> list[obj
             return []
 
         # Perform the vector search
-        weaviate_client = self.get_client(self.data_path)
-        if weaviate_client is None:
+        if self.weaviate_client is None:
             logger.error("Could not find client, not returning results.")
             return []
 
-        if not weaviate_client:
+        if not self.weaviate_client:
             logger.error("Invalid client, cannot perform search.")
             return []
 
         try:
-            weaviate_client.connect()
-            packages = weaviate_client.collections.get("Package")
+            packages = self.weaviate_client.collections.get("Package")
             response = packages.query.fetch_objects(
                 filters=Filter.by_property(name).contains_any(properties),
             )
@@ -117,8 +131,6 @@ async def search_by_property(self, name: str, properties: List[str]) -> list[obj
         except Exception as e:
             logger.error(f"An error occurred: {str(e)}")
             return []
-        finally:
-            weaviate_client.close()
 
     async def search(self, query: str, limit=5, distance=0.3, packages=None) -> list[object]:
         """
@@ -135,14 +147,8 @@ async def search(self, query: str, limit=5, distance=0.3, packages=None) -> list
         query_vector = await self.inference_engine.embed(self.model_path, [query])
 
         # Perform the vector search
-        weaviate_client = self.get_client(self.data_path)
-        if weaviate_client is None:
-            logger.error("Could not find client, not returning results.")
-            return []
-
         try:
-            weaviate_client.connect()
-            collection = weaviate_client.collections.get("Package")
+            collection = self.weaviate_client.collections.get("Package")
             if packages:
                 response = collection.query.near_vector(
                     query_vector[0],
@@ -159,16 +165,10 @@ async def search(self, query: str, limit=5, distance=0.3, packages=None) -> list
                     return_metadata=MetadataQuery(distance=True),
                 )
 
-            weaviate_client.close()
             if not response:
                 return []
             return response.objects
 
         except Exception as e:
             logger.error(f"Error during search: {str(e)}")
             return []
-        finally:
-            try:
-                weaviate_client.close()
-            except Exception as e:
-                logger.error(f"Failed to close client: {str(e)}")
16 changes: 10 additions & 6 deletions tests/test_storage.py
@@ -2,6 +2,7 @@
 
 import pytest
 
+from codegate.config import Config
 from codegate.storage.storage_engine import (
     StorageEngine,
 )  # Adjust the import based on your actual path
@@ -34,17 +35,21 @@ def mock_inference_engine():
 
 @pytest.mark.asyncio
 async def test_search(mock_weaviate_client, mock_inference_engine):
+    Config.load(config_path="./config.yaml")
+
     # Patch the LlamaCppInferenceEngine.embed method (not the entire class)
     with patch(
         "codegate.inference.inference_engine.LlamaCppInferenceEngine.embed",
         mock_inference_engine.embed,
     ):
-
         # Mock the WeaviateClient as before
         with patch("weaviate.WeaviateClient", return_value=mock_weaviate_client):
-
-            # Initialize StorageEngine
-            storage_engine = StorageEngine(data_path="./weaviate_data")
-
+            with patch(
+                "codegate.storage.storage_engine.StorageEngine.get_client",
+                return_value=mock_weaviate_client,
+            ):
+                # Initialize StorageEngine
+                # Need to recreate instance to use the mock
+                storage_engine = StorageEngine.recreate_instance(data_path="./weaviate_data")
             # Invoke the search method
             results = await storage_engine.search("test query", 5, 0.3)
@@ -53,4 +58,3 @@ async def test_search(mock_weaviate_client, mock_inference_engine):
     assert len(results) == 1  # Assert that one result is returned
     assert results[0]["properties"]["name"] == "test"
     mock_weaviate_client.connect.assert_called()
-    mock_weaviate_client.close.assert_called()
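Because the singleton caches the first instance for the life of the process, the test patches `StorageEngine.get_client` and then calls `recreate_instance()` rather than constructing `StorageEngine(...)` directly; a plain construction could hand back an engine created before the mock was installed, as the in-diff comment notes. The `close.assert_called()` check is dropped because `search` no longer closes the client after each query.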