From 82561ebc5fce2922f2bbbeba29c5f3909497bd36 Mon Sep 17 00:00:00 2001
From: Jakub Hrozek
Date: Tue, 4 Mar 2025 21:26:04 +0100
Subject: [PATCH] Strip code snippets through a client interface

Different clients need different ways of stripping away the context in
order to analyze just the human text in the user message. Let's start
with this as the first method of a client interface.

Related: #831
---
 src/codegate/clients/interface.py          | 61 +++++++++++++++++++
 .../codegate_context_retriever/codegate.py | 11 ++--
 2 files changed, 68 insertions(+), 4 deletions(-)
 create mode 100644 src/codegate/clients/interface.py

diff --git a/src/codegate/clients/interface.py b/src/codegate/clients/interface.py
new file mode 100644
index 00000000..53656c1a
--- /dev/null
+++ b/src/codegate/clients/interface.py
@@ -0,0 +1,61 @@
+import re
+from abc import ABC, abstractmethod
+from typing import Dict, Type
+
+from codegate.clients.clients import ClientType
+
+
+class ClientInterface(ABC):
+    """Secure interface for client-specific message processing"""
+
+    @abstractmethod
+    def strip_code_snippets(self, message: str) -> str:
+        """Remove code blocks and file listings to prevent context pollution"""
+        pass
+
+
+class GenericClient(ClientInterface):
+    """Default implementation with strict input validation"""
+
+    _MARKDOWN_CODE_REGEX = re.compile(r"```.*?```", re.DOTALL)
+    _MARKDOWN_FILE_LISTING = re.compile(r"⋮...*?⋮...\n\n", flags=re.DOTALL)
+    _ENVIRONMENT_DETAILS = re.compile(
+        r"<environment_details>.*?</environment_details>", flags=re.DOTALL
+    )
+
+    _CLI_REGEX = re.compile(r"^codegate\s+(.*)$", re.IGNORECASE)
+
+    def strip_code_snippets(self, message: str) -> str:
+        message = self._MARKDOWN_CODE_REGEX.sub("", message)
+        message = self._MARKDOWN_FILE_LISTING.sub("", message)
+        message = self._ENVIRONMENT_DETAILS.sub("", message)
+        return message
+
+
+class ClineClient(ClientInterface):
+    """Cline-specific client interface"""
+
+    _CLINE_FILE_REGEX = re.compile(
+        r"(?i)<\s*file_content\s*[^>]*>.*?</\s*file_content\s*>", re.DOTALL
+    )
+
+    def __init__(self):
+        self.generic_client = GenericClient()
+
+    def strip_code_snippets(self, message: str) -> str:
+        message = self.generic_client.strip_code_snippets(message)
+        return self._CLINE_FILE_REGEX.sub("", message)
+
+
+class ClientFactory:
+    """Secure factory with updated client mappings"""
+
+    _implementations: Dict[ClientType, Type[ClientInterface]] = {
+        ClientType.GENERIC: GenericClient,
+        ClientType.CLINE: ClineClient,
+        ClientType.KODU: ClineClient,
+    }
+
+    @classmethod
+    def create(cls, client_type: ClientType) -> ClientInterface:
+        return cls._implementations.get(client_type, GenericClient)()
diff --git a/src/codegate/pipeline/codegate_context_retriever/codegate.py b/src/codegate/pipeline/codegate_context_retriever/codegate.py
index e22874a6..12a6d7d1 100644
--- a/src/codegate/pipeline/codegate_context_retriever/codegate.py
+++ b/src/codegate/pipeline/codegate_context_retriever/codegate.py
@@ -5,6 +5,7 @@
 from litellm import ChatCompletionRequest
 
 from codegate.clients.clients import ClientType
+from codegate.clients.interface import ClientFactory
 from codegate.db.models import AlertSeverity
 from codegate.extract_snippets.factory import MessageCodeExtractorFactory
 from codegate.pipeline.base import (
@@ -22,6 +23,9 @@
 # Pre-compiled regex patterns for performance
 markdown_code_block = re.compile(r"```.*?```", flags=re.DOTALL)
 markdown_file_listing = re.compile(r"⋮...*?⋮...\n\n", flags=re.DOTALL)
+cline_file_listing = re.compile(
+    r"(?i)<\s*file_content\s*[^>]*>.*?</\s*file_content\s*>", flags=re.DOTALL
+)
 environment_details = re.compile(r"<environment_details>.*?</environment_details>", flags=re.DOTALL)
 
 
@@ -112,12 +116,11 @@ async def process(  # noqa: C901
 
         # Remove code snippets and file listing from the user messages and search for bad packages
         # in the rest of the user query/messsages
-        user_messages = markdown_code_block.sub("", user_message)
-        user_messages = markdown_file_listing.sub("", user_messages)
-        user_messages = environment_details.sub("", user_messages)
+        client_if = ClientFactory.create(context.client)
+        non_code_user_message = client_if.strip_code_snippets(user_message)
 
         # split messages into double newlines, to avoid passing so many content in the search
-        split_messages = re.split(r"</?task>|\n|\\n", user_messages)
+        split_messages = re.split(r"</?task>|\n|\\n", non_code_user_message)
         collected_bad_packages = []
         for item_message in filter(None, map(str.strip, split_messages)):
             # Vector search to find bad packages