diff --git a/prompts/default.yaml b/prompts/default.yaml index c5d104e1..843a094f 100644 --- a/prompts/default.yaml +++ b/prompts/default.yaml @@ -24,6 +24,11 @@ lookup_packages: | Your job is to extract any software packages from user's request. A package is a named entity. You MUST RESPOND with a list of packages in JSON FORMAT: {"packages": [pkg1, pkg2, ...]}. +secrets_redacted: | + The files in the context contain sensitive information that has been redacted. Do not warn the user + about any tokens, passwords or similar sensitive information in the context whose value begins with + the string "REDACTED". + # Security-focused prompts security_audit: "You are a security expert conducting a thorough code review. Identify potential security vulnerabilities, suggest improvements, and explain security best practices." diff --git a/src/codegate/pipeline/secrets/secrets.py b/src/codegate/pipeline/secrets/secrets.py index 351fb253..85aa39e2 100644 --- a/src/codegate/pipeline/secrets/secrets.py +++ b/src/codegate/pipeline/secrets/secrets.py @@ -2,9 +2,10 @@ from typing import Optional import structlog -from litellm import ChatCompletionRequest, ModelResponse +from litellm import ChatCompletionRequest, ChatCompletionSystemMessage, ModelResponse from litellm.types.utils import Delta, StreamingChoices +from codegate.config import Config from codegate.pipeline.base import ( AlertSeverity, PipelineContext, @@ -14,6 +15,7 @@ from codegate.pipeline.output import OutputPipelineContext, OutputPipelineStep from codegate.pipeline.secrets.manager import SecretsManager from codegate.pipeline.secrets.signatures import CodegateSignatures +from codegate.pipeline.systemmsg import add_or_update_system_message logger = structlog.get_logger("codegate") @@ -197,6 +199,12 @@ async def process( # Store the count in context metadata context.metadata["redacted_secrets_count"] = total_redacted + if total_redacted > 0: + system_message = ChatCompletionSystemMessage( + 
import json
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # These names are used only in (string) annotations below, so they are
    # not needed at runtime.
    from litellm import ChatCompletionRequest, ChatCompletionSystemMessage

    from codegate.pipeline.base import PipelineContext


def get_existing_system_message(request: "ChatCompletionRequest") -> Optional[dict]:
    """
    Retrieves the existing system message from the completion request.

    Args:
        request: The original completion request.

    Returns:
        The first message whose role is "system", otherwise None. A request
        with no "messages" key, or with "messages" set to None, yields None.
    """
    # `or []` covers both a missing key and an explicit None value.
    for message in request.get("messages") or []:
        # .get() so a malformed message without a role is skipped, not fatal.
        if message.get("role") == "system":
            return message
    return None


def add_or_update_system_message(
    request: "ChatCompletionRequest",
    system_message: "ChatCompletionSystemMessage",
    context: "PipelineContext",
) -> "ChatCompletionRequest":
    """
    Adds or updates the system message in the completion request.

    If the request has no system message, `system_message` is inserted as the
    first message. Otherwise its content is appended to the first existing
    system message, separated by a blank line.

    The incoming request is not modified: its message list, and the system
    message dict that gets extended, are copied before being changed.

    Args:
        request: The original completion request.
        system_message: The system message to add or update.
        context: The pipeline context for adding alerts.

    Returns:
        The updated completion request.
    """
    new_request = request.copy()

    # request.copy() is shallow, so the "messages" list would still be shared
    # with the caller. Rebuild it (also normalising a missing/None value) so
    # the insert/replace below never touches the caller's list.
    messages = list(new_request.get("messages") or [])
    new_request["messages"] = messages

    system_index = next(
        (index for index, message in enumerate(messages) if message.get("role") == "system"),
        None,
    )

    if system_index is None:
        # Add new system message
        context.add_alert("add-system-message", trigger_string=json.dumps(system_message))
        messages.insert(0, system_message)
    else:
        # Update existing system message
        existing = messages[system_index]
        # NOTE(review): assumes both contents are plain strings - confirm that
        # list-style content parts never reach this step.
        updated_content = existing["content"] + "\n\n" + system_message["content"]
        context.add_alert("update-system-message", trigger_string=updated_content)
        # Replace the entry with an updated copy rather than mutating the dict
        # that the caller's request still references.
        messages[system_index] = {**existing, "content": updated_content}

    return new_request
from unittest.mock import Mock

import pytest

from codegate.pipeline.base import PipelineContext
from codegate.pipeline.systemmsg import add_or_update_system_message, get_existing_system_message


def _mock_context():
    """Build a PipelineContext stand-in whose add_alert calls can be inspected."""
    ctx = Mock(spec=PipelineContext)
    ctx.add_alert = Mock()
    return ctx


class TestAddOrUpdateSystemMessage:
    """Behaviour of add_or_update_system_message for new and existing system prompts."""

    def test_init_with_system_message(self):
        """
        A system message is added to a request with an empty message list
        """
        prompt = {"role": "system", "content": "Test system prompt"}
        ctx = _mock_context()

        outcome = add_or_update_system_message({"messages": []}, prompt, ctx)

        assert len(outcome["messages"]) == 1
        assert outcome["messages"][0]["content"] == prompt["content"]

    @pytest.mark.parametrize(
        "request_setup",
        [{"messages": [{"role": "user", "content": "Test user message"}]}, {"messages": []}, {}],
    )
    def test_system_message_insertion(self, request_setup):
        """
        The system message ends up first regardless of the request's initial shape
        """
        ctx = _mock_context()
        prompt = {"role": "system", "content": "Security analysis system prompt"}

        outcome = add_or_update_system_message(request_setup, prompt, ctx)

        assert len(outcome["messages"]) > 0
        head = outcome["messages"][0]
        assert head["role"] == "system"
        assert head["content"] == prompt["content"]
        ctx.add_alert.assert_called_once()

    def test_update_existing_system_message(self):
        """
        New instructions are appended to an existing system message
        """
        ctx = _mock_context()
        payload = {"messages": [{"role": "system", "content": "Existing system message"}]}
        addition = {"role": "system", "content": "Additional system instructions"}

        outcome = add_or_update_system_message(payload, addition, ctx)

        expected_content = "Existing system message" + "\n\n" + "Additional system instructions"
        assert len(outcome["messages"]) == 1
        assert outcome["messages"][0]["content"] == expected_content
        ctx.add_alert.assert_called_once_with(
            "update-system-message", trigger_string=expected_content
        )

    @pytest.mark.parametrize(
        "edge_case",
        [
            None,  # No messages
            [],  # Empty messages list
        ],
    )
    def test_edge_cases(self, edge_case):
        """
        A missing messages key and an empty messages list both gain one system message
        """
        payload = {} if edge_case is None else {"messages": edge_case}
        ctx = _mock_context()
        prompt = {"role": "system", "content": "Security edge case prompt"}

        outcome = add_or_update_system_message(payload, prompt, ctx)

        assert len(outcome["messages"]) == 1
        only = outcome["messages"][0]
        assert only["role"] == "system"
        assert only["content"] == prompt["content"]
        ctx.add_alert.assert_called_once()


class TestGetExistingSystemMessage:
    """Lookup behaviour of get_existing_system_message."""

    def test_existing_system_message(self):
        """
        A system message present in the request is returned
        """
        sys_msg = {"role": "system", "content": "Existing system message"}
        payload = {"messages": [sys_msg, {"role": "user", "content": "User message"}]}

        assert get_existing_system_message(payload) == sys_msg

    def test_no_system_message(self):
        """
        A request holding only user messages yields None
        """
        payload = {"messages": [{"role": "user", "content": "User message"}]}

        assert get_existing_system_message(payload) is None

    def test_empty_messages(self):
        """
        An empty messages list yields None
        """
        assert get_existing_system_message({"messages": []}) is None

    def test_no_messages_key(self):
        """
        A request without a 'messages' key yields None
        """
        assert get_existing_system_message({}) is None

    def test_multiple_system_messages(self):
        """
        With several system messages, the first one is returned
        """
        first = {"role": "system", "content": "First system message"}
        second = {"role": "system", "content": "Second system message"}

        found = get_existing_system_message({"messages": [first, second]})

        assert found == first
# Default prompts are loaded + "prompts_loaded": 6, # Default prompts are loaded "provider_urls": DEFAULT_PROVIDER_URLS, }, ) @@ -205,7 +205,7 @@ def test_serve_priority_resolution( "port": 8080, "log_level": "ERROR", "log_format": "TEXT", - "prompts_loaded": 5, # Default prompts are loaded + "prompts_loaded": 6, # Default prompts are loaded "provider_urls": DEFAULT_PROVIDER_URLS, }, )