diff --git a/src/codegate/pipeline/base.py b/src/codegate/pipeline/base.py index 915c86d0..ba2d24b7 100644 --- a/src/codegate/pipeline/base.py +++ b/src/codegate/pipeline/base.py @@ -78,6 +78,8 @@ class PipelineContext: input_request: Optional[Prompt] = field(default_factory=lambda: None) output_responses: List[Output] = field(default_factory=list) shortcut_response: bool = False + bad_packages_found: bool = False + secrets_found: bool = False def add_code_snippet(self, snippet: CodeSnippet): self.code_snippets.append(snippet) diff --git a/src/codegate/pipeline/codegate_context_retriever/codegate.py b/src/codegate/pipeline/codegate_context_retriever/codegate.py index 75e5bb8a..78197860 100644 --- a/src/codegate/pipeline/codegate_context_retriever/codegate.py +++ b/src/codegate/pipeline/codegate_context_retriever/codegate.py @@ -80,20 +80,25 @@ async def process( # Generate context string using the searched objects logger.info(f"Adding {len(searched_objects)} packages to the context") - if len(searched_objects) > 0: + # Nothing to do if no bad packages are found + if len(searched_objects) == 0: + return PipelineResult(request=request, context=context) + else: + # Add context for bad packages context_str = self.generate_context_str(searched_objects, context) + context.bad_packages_found = True - last_user_idx = self.get_last_user_message_idx(request) + last_user_idx = self.get_last_user_message_idx(request) - # Make a copy of the request - new_request = request.copy() + # Make a copy of the request + new_request = request.copy() - # Add the context to the last user message - # Format: "Context: {context_str} \n Query: {last user message content}" - message = new_request["messages"][last_user_idx] - context_msg = f'Context: {context_str} \n\n Query: {message["content"]}' - message["content"] = context_msg + # Add the context to the last user message + # Format: "Context: {context_str} \n Query: {last user message content}" + message = new_request["messages"][last_user_idx] + context_msg = f'Context: {context_str} \n\n Query: {message["content"]}' + message["content"] = context_msg - logger.debug("Final context message", context_message=context_msg) + logger.debug("Final context message", context_message=context_msg) - return PipelineResult(request=new_request, context=context) + return PipelineResult(request=new_request, context=context) diff --git a/src/codegate/pipeline/factory.py b/src/codegate/pipeline/factory.py index e2c2f85f..7a713332 100644 --- a/src/codegate/pipeline/factory.py +++ b/src/codegate/pipeline/factory.py @@ -29,8 +29,8 @@ def create_input_pipeline(self) -> SequentialPipelineProcessor: CodegateSecrets(), CodegateVersion(), CodeSnippetExtractor(), - SystemPrompt(Config.get_config().prompts.default_chat), CodegateContextRetriever(), + SystemPrompt(Config.get_config().prompts.default_chat), ] return SequentialPipelineProcessor(input_steps, self.secrets_manager, is_fim=False) diff --git a/src/codegate/pipeline/secrets/secrets.py b/src/codegate/pipeline/secrets/secrets.py index 8b2a2638..5dd22b95 100644 --- a/src/codegate/pipeline/secrets/secrets.py +++ b/src/codegate/pipeline/secrets/secrets.py @@ -288,6 +288,7 @@ async def process( if i > last_assistant_idx: total_redacted += redacted_count + context.secrets_found = total_redacted > 0 logger.info(f"Total secrets redacted since last assistant message: {total_redacted}") # Store the count in context metadata diff --git a/src/codegate/pipeline/system_prompt/codegate.py b/src/codegate/pipeline/system_prompt/codegate.py index 8a08ae2a..ee7310da 100644 --- a/src/codegate/pipeline/system_prompt/codegate.py +++ b/src/codegate/pipeline/system_prompt/codegate.py @@ -32,6 +32,11 @@ async def process( Add system prompt if not present, otherwise prepend codegate system prompt to the existing system prompt """ + + # Nothing to do if no secrets or bad_packages are found + if not (context.secrets_found or context.bad_packages_found): + return PipelineResult(request=request, context=context) + new_request = request.copy() if "messages" not in new_request: