From d1725d535d64bc9c6e3049681975b61179e330bd Mon Sep 17 00:00:00 2001
From: Pankaj Telang
Date: Fri, 6 Dec 2024 10:17:56 -0500
Subject: [PATCH] Modify StorageEngine to use the singleton pattern

---
 config.yaml                                   |  4 +-
 src/codegate/config.py                        |  4 +-
 .../codegate_context_retriever/codegate.py    |  6 +-
 .../pipeline/extract_snippets/output.py       |  6 +-
 src/codegate/storage/storage_engine.py        | 96 +++++++++----------
 tests/test_storage.py                         | 16 ++--
 6 files changed, 68 insertions(+), 64 deletions(-)

diff --git a/config.yaml b/config.yaml
index fb4be274..60d8a63b 100644
--- a/config.yaml
+++ b/config.yaml
@@ -19,7 +19,7 @@ log_level: "INFO"  # One of: ERROR, WARNING, INFO, DEBUG
 ##
 
 # Model to use for chatting
-chat_model_path: "./models"
+model_base_path: "./models"
 
 # Context length of the model
 chat_model_n_ctx: 32768
@@ -27,3 +27,5 @@ chat_model_n_ctx: 32768
 
 # Number of layers to offload to GPU. If -1, all layers are offloaded.
 chat_model_n_gpu_layers: -1
+# Embedding model
+embedding_model: "all-minilm-L6-v2-q5_k_m.gguf"
\ No newline at end of file
diff --git a/src/codegate/config.py b/src/codegate/config.py
index fafd1a66..1f73cdb4 100644
--- a/src/codegate/config.py
+++ b/src/codegate/config.py
@@ -40,6 +40,7 @@ class Config:
     model_base_path: str = "./models"
     chat_model_n_ctx: int = 32768
     chat_model_n_gpu_layers: int = -1
+    embedding_model: str = "all-minilm-L6-v2-q5_k_m.gguf"
 
     # Provider URLs with defaults
     provider_urls: Dict[str, str] = field(default_factory=lambda: DEFAULT_PROVIDER_URLS.copy())
@@ -117,11 +118,12 @@ def from_file(cls, config_path: Union[str, Path]) -> "Config":
             host=config_data.get("host", cls.host),
             log_level=config_data.get("log_level", cls.log_level.value),
             log_format=config_data.get("log_format", cls.log_format.value),
-            model_base_path=config_data.get("chat_model_path", cls.model_base_path),
+            model_base_path=config_data.get("model_base_path", cls.model_base_path),
             chat_model_n_ctx=config_data.get("chat_model_n_ctx", cls.chat_model_n_ctx),
             chat_model_n_gpu_layers=config_data.get(
                 "chat_model_n_gpu_layers", cls.chat_model_n_gpu_layers
             ),
+            embedding_model=config_data.get("embedding_model", cls.embedding_model),
             prompts=prompts_config,
             provider_urls=provider_urls,
         )
diff --git a/src/codegate/pipeline/codegate_context_retriever/codegate.py b/src/codegate/pipeline/codegate_context_retriever/codegate.py
index 2fab586c..a31cdd8c 100644
--- a/src/codegate/pipeline/codegate_context_retriever/codegate.py
+++ b/src/codegate/pipeline/codegate_context_retriever/codegate.py
@@ -20,9 +20,6 @@ class CodegateContextRetriever(PipelineStep):
     the word "codegate" in the user message.
""" - def __init__(self): - self.storage_engine = StorageEngine() - @property def name(self) -> str: """ @@ -33,7 +30,8 @@ def name(self) -> str: async def get_objects_from_search( self, search: str, packages: list[str] = None ) -> list[object]: - objects = await self.storage_engine.search(search, distance=0.8, packages=packages) + storage_engine = StorageEngine() + objects = await storage_engine.search(search, distance=0.8, packages=packages) return objects def generate_context_str(self, objects: list[object]) -> str: diff --git a/src/codegate/pipeline/extract_snippets/output.py b/src/codegate/pipeline/extract_snippets/output.py index bf9243d6..cfa50425 100644 --- a/src/codegate/pipeline/extract_snippets/output.py +++ b/src/codegate/pipeline/extract_snippets/output.py @@ -16,9 +16,6 @@ class CodeCommentStep(OutputPipelineStep): """Pipeline step that adds comments after code blocks""" - def __init__(self): - self._storage_engine = StorageEngine() - @property def name(self) -> str: return "code-comment" @@ -52,7 +49,8 @@ async def _snippet_comment(self, snippet: CodeSnippet, secrets: PipelineSensitiv base_url=secrets.api_base, ) - libobjects = await self._storage_engine.search_by_property("name", snippet.libraries) + storage_engine = StorageEngine() + libobjects = await storage_engine.search_by_property("name", snippet.libraries) logger.info(f"Found {len(libobjects)} libraries in the storage engine") libraries_text = "" diff --git a/src/codegate/storage/storage_engine.py b/src/codegate/storage/storage_engine.py index d4421117..caf193cd 100644 --- a/src/codegate/storage/storage_engine.py +++ b/src/codegate/storage/storage_engine.py @@ -26,22 +26,60 @@ class StorageEngine: - def get_client(self, data_path): + __storage_engine = None + + def __new__(cls, *args, **kwargs): + if cls.__storage_engine is None: + cls.__storage_engine = super().__new__(cls) + return cls.__storage_engine + + # This function is needed only for the unit testing for the + # mocks to work. 
+ @classmethod + def recreate_instance(cls, *args, **kwargs): + cls.__storage_engine = None + return cls(*args, **kwargs) + + def __init__(self, data_path="./weaviate_data"): + if hasattr(self, "initialized"): + return + + self.initialized = True + self.data_path = data_path + self.inference_engine = LlamaCppInferenceEngine() + self.model_path = ( + f"{Config.get_config().model_base_path}/{Config.get_config().embedding_model}" + ) + self.schema_config = schema_config + + # setup schema for weaviate + self.weaviate_client = self.get_client(self.data_path) + if self.weaviate_client is not None: + try: + self.weaviate_client.connect() + self.setup_schema(self.weaviate_client) + except Exception as e: + logger.error(f"Failed to connect or setup schema: {str(e)}") + + def __del__(self): try: - # Get current config - config = Config.get_config() + self.weaviate_client.close() + except Exception as e: + logger.error(f"Failed to close client: {str(e)}") + def get_client(self, data_path): + try: # Configure Weaviate logging additional_env_vars = { # Basic logging configuration - "LOG_FORMAT": config.log_format.value.lower(), - "LOG_LEVEL": config.log_level.value.lower(), + "LOG_FORMAT": Config.get_config().log_format.value.lower(), + "LOG_LEVEL": Config.get_config().log_level.value.lower(), # Disable colored output "LOG_FORCE_COLOR": "false", # Configure JSON format "LOG_JSON_FIELDS": "timestamp, level,message", # Configure text format - "LOG_METHOD": config.log_format.value.lower(), + "LOG_METHOD": Config.get_config().log_format.value.lower(), "LOG_LEVEL_IN_UPPER": "false", # Keep level lowercase like codegate format # Disable additional fields "LOG_GIT_HASH": "false", @@ -60,28 +98,6 @@ def get_client(self, data_path): logger.error(f"Error during client creation: {str(e)}") return None - def __init__(self, data_path="./weaviate_data"): - self.data_path = data_path - self.inference_engine = LlamaCppInferenceEngine() - self.model_path = "./models/all-minilm-L6-v2-q5_k_m.gguf" - self.schema_config = schema_config - - # setup schema for weaviate - weaviate_client = self.get_client(self.data_path) - if weaviate_client is not None: - try: - weaviate_client.connect() - self.setup_schema(weaviate_client) - except Exception as e: - logger.error(f"Failed to connect or setup schema: {str(e)}") - finally: - try: - weaviate_client.close() - except Exception as e: - logger.error(f"Failed to close client: {str(e)}") - else: - logger.error("Could not find client, skipping schema setup.") - def setup_schema(self, client): for class_config in self.schema_config: if not client.collections.exists(class_config["name"]): @@ -95,18 +111,16 @@ async def search_by_property(self, name: str, properties: List[str]) -> list[obj return [] # Perform the vector search - weaviate_client = self.get_client(self.data_path) - if weaviate_client is None: + if self.weaviate_client is None: logger.error("Could not find client, not returning results.") return [] - if not weaviate_client: + if not self.weaviate_client: logger.error("Invalid client, cannot perform search.") return [] try: - weaviate_client.connect() - packages = weaviate_client.collections.get("Package") + packages = self.weaviate_client.collections.get("Package") response = packages.query.fetch_objects( filters=Filter.by_property(name).contains_any(properties), ) @@ -117,8 +131,6 @@ async def search_by_property(self, name: str, properties: List[str]) -> list[obj except Exception as e: logger.error(f"An error occurred: {str(e)}") return [] - finally: - 
weaviate_client.close()
 
     async def search(self, query: str, limit=5, distance=0.3, packages=None) -> list[object]:
         """
@@ -135,14 +147,8 @@ async def search(self, query: str, limit=5, distance=0.3, packages=None) -> list
         query_vector = await self.inference_engine.embed(self.model_path, [query])
 
         # Perform the vector search
-        weaviate_client = self.get_client(self.data_path)
-        if weaviate_client is None:
-            logger.error("Could not find client, not returning results.")
-            return []
-
         try:
-            weaviate_client.connect()
-            collection = weaviate_client.collections.get("Package")
+            collection = self.weaviate_client.collections.get("Package")
             if packages:
                 response = collection.query.near_vector(
                     query_vector[0],
@@ -159,7 +165,6 @@ async def search(self, query: str, limit=5, distance=0.3, packages=None) -> list
                     return_metadata=MetadataQuery(distance=True),
                 )
 
-            weaviate_client.close()
             if not response:
                 return []
             return response.objects
@@ -167,8 +172,3 @@ async def search(self, query: str, limit=5, distance=0.3, packages=None) -> list
         except Exception as e:
             logger.error(f"Error during search: {str(e)}")
             return []
-        finally:
-            try:
-                weaviate_client.close()
-            except Exception as e:
-                logger.error(f"Failed to close client: {str(e)}")
diff --git a/tests/test_storage.py b/tests/test_storage.py
index 2b7b9c16..965bf071 100644
--- a/tests/test_storage.py
+++ b/tests/test_storage.py
@@ -2,6 +2,7 @@
 
 import pytest
 
+from codegate.config import Config
 from codegate.storage.storage_engine import (
     StorageEngine,
 )  # Adjust the import based on your actual path
@@ -34,17 +35,21 @@ def mock_inference_engine():
 
 @pytest.mark.asyncio
 async def test_search(mock_weaviate_client, mock_inference_engine):
+    Config.load(config_path="./config.yaml")
+
     # Patch the LlamaCppInferenceEngine.embed method (not the entire class)
     with patch(
         "codegate.inference.inference_engine.LlamaCppInferenceEngine.embed",
         mock_inference_engine.embed,
    ):
-
-        # Mock the WeaviateClient as before
-        with patch("weaviate.WeaviateClient", return_value=mock_weaviate_client):
-
+        # Patch get_client so the engine is built with the mocked Weaviate client
+        with patch(
+            "codegate.storage.storage_engine.StorageEngine.get_client",
+            return_value=mock_weaviate_client,
+        ):
             # Initialize StorageEngine
-            storage_engine = StorageEngine(data_path="./weaviate_data")
+            # Recreate the singleton so it picks up the mocked get_client
+            storage_engine = StorageEngine.recreate_instance(data_path="./weaviate_data")
 
             # Invoke the search method
             results = await storage_engine.search("test query", 5, 0.3)
@@ -53,4 +58,3 @@ async def test_search(mock_weaviate_client, mock_inference_engine):
     assert len(results) == 1  # Assert that one result is returned
     assert results[0]["properties"]["name"] == "test"
     mock_weaviate_client.connect.assert_called()
-    mock_weaviate_client.close.assert_called()
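
For reviewers unfamiliar with this flavor of singleton, here is a minimal,
self-contained sketch of the pattern the patch applies. The `Engine` class
and `make_client` helper are illustrative stand-ins, not codegate APIs:

    class Engine:
        """Singleton with an init-once guard and a test-only reset hook."""

        __instance = None  # class-level slot for the single shared instance

        def __new__(cls, *args, **kwargs):
            # Create the instance on first use; return the cached one after.
            if cls.__instance is None:
                cls.__instance = super().__new__(cls)
            return cls.__instance

        @classmethod
        def recreate_instance(cls, *args, **kwargs):
            # Drop the cached instance so construction (and __init__) run
            # again -- used by tests after patching collaborators.
            cls.__instance = None
            return cls(*args, **kwargs)

        def __init__(self, data_path="./data"):
            # __init__ runs on *every* Engine() call, even when __new__
            # returned the cached object, so guard against re-initialization.
            if hasattr(self, "initialized"):
                return
            self.initialized = True
            self.data_path = data_path
            self.client = self.make_client(data_path)

        def make_client(self, data_path):
            # Stand-in for the expensive setup (the real class builds and
            # connects a Weaviate client here).
            return {"data_path": data_path}

    a = Engine(data_path="./x")
    b = Engine(data_path="./y")  # cached instance returned; "./y" is ignored
    assert a is b and a.data_path == "./x"

One consequence of the init-once guard is visible in the last two lines:
constructor arguments passed after the first instantiation are silently
ignored, which is why the tests reset the singleton through
recreate_instance instead of calling the constructor directly.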
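
And a sketch of how a test can combine unittest.mock.patch with the reset
hook, reusing the hypothetical `Engine` class from the sketch above (the
mocked return value is arbitrary):

    from unittest.mock import patch

    def test_engine_uses_mocked_client():
        # Patch the factory first, then reset the singleton so __init__
        # runs again and calls the patched method.
        with patch.object(Engine, "make_client", return_value="mock-client"):
            engine = Engine.recreate_instance(data_path="./test_data")
            assert engine.client == "mock-client"
        # Without recreate_instance, a previously built Engine would be
        # returned as-is and the patched make_client would never run.

    test_engine_uses_mocked_client()

This mirrors the change in tests/test_storage.py: get_client is patched on
the class, and recreate_instance forces the singleton to be rebuilt with
the mock in place.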