diff --git a/src/codegate/dashboard/post_processing.py b/src/codegate/dashboard/post_processing.py
index c2e1059b..9e43033c 100644
--- a/src/codegate/dashboard/post_processing.py
+++ b/src/codegate/dashboard/post_processing.py
@@ -86,26 +86,39 @@ async def parse_output(output_str: str) -> Tuple[Optional[str], Optional[str]]:
         logger.warning(f"Error parsing output: {output_str}. {e}")
         return None, None
 
-    output_message = ""
+    def _parse_single_output(single_output: dict) -> Tuple[str, Optional[str]]:
+        single_chat_id = single_output.get("id")
+        single_output_message = ""
+        for choice in single_output.get("choices", []):
+            if not isinstance(choice, dict):
+                continue
+            content_dict = choice.get("delta", {}) or choice.get("message", {})
+            single_output_message += content_dict.get("content", "")
+        return single_output_message, single_chat_id
+
+    full_output_message = ""
     chat_id = None
     if isinstance(output, list):
         for output_chunk in output:
-            if not isinstance(output_chunk, dict):
-                continue
-            chat_id = chat_id or output_chunk.get("id")
-            for choice in output_chunk.get("choices", []):
-                if not isinstance(choice, dict):
-                    continue
-                delta_dict = choice.get("delta", {})
-                output_message += delta_dict.get("content", "")
+            output_message, output_chat_id = "", None
+            if isinstance(output_chunk, dict):
+                output_message, output_chat_id = _parse_single_output(output_chunk)
+            elif isinstance(output_chunk, str):
+                try:
+                    output_decoded = json.loads(output_chunk)
+                    output_message, output_chat_id = _parse_single_output(output_decoded)
+                except Exception:
+                    logger.error(f"Error reading chunk: {output_chunk}")
+            else:
+                logger.warning(
+                    f"Could not handle output: {output_chunk}", out_type=type(output_chunk)
+                )
+            chat_id = chat_id or output_chat_id
+            full_output_message += output_message
     elif isinstance(output, dict):
-        chat_id = chat_id or output.get("id")
-        for choice in output.get("choices", []):
-            if not isinstance(choice, dict):
-                continue
-            output_message += choice.get("message", {}).get("content", "")
+        full_output_message, chat_id = _parse_single_output(output)
 
-    return output_message, chat_id
+    return full_output_message, chat_id
 
 
 async def _get_question_answer(
@@ -124,7 +137,7 @@ async def _get_question_answer(
         output_msg_str, chat_id = output_task.result()
 
     # If we couldn't parse the request or output, return None
-    if not request_msg_str or not output_msg_str or not chat_id:
+    if not request_msg_str:
         return None, None
 
     request_message = ChatMessage(
@@ -132,11 +145,15 @@ async def _get_question_answer(
         timestamp=row.timestamp,
         message_id=row.id,
     )
-    output_message = ChatMessage(
-        message=output_msg_str,
-        timestamp=row.output_timestamp,
-        message_id=row.output_id,
-    )
+    if output_msg_str:
+        output_message = ChatMessage(
+            message=output_msg_str,
+            timestamp=row.output_timestamp,
+            message_id=row.output_id,
+        )
+    else:
+        output_message = None
+        chat_id = row.id
 
     return QuestionAnswer(question=request_message, answer=output_message), chat_id
 
diff --git a/src/codegate/dashboard/request_models.py b/src/codegate/dashboard/request_models.py
index d33e8732..8f13a03c 100644
--- a/src/codegate/dashboard/request_models.py
+++ b/src/codegate/dashboard/request_models.py
@@ -22,7 +22,7 @@ class QuestionAnswer(BaseModel):
     """
 
     question: ChatMessage
-    answer: ChatMessage
+    answer: Optional[ChatMessage]
 
 
 class PartialConversation(BaseModel):
diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py
index be9f9994..5dbe14fa 100644
--- a/src/codegate/db/connection.py
+++ b/src/codegate/db/connection.py
@@ -1,13 +1,9 @@
 import asyncio
-import copy
-import datetime
 import json
-import uuid
 from pathlib import Path
-from typing import AsyncGenerator, AsyncIterator, List, Optional
+from typing import List, Optional
 
 import structlog
-from litellm import ChatCompletionRequest, ModelResponse
 from pydantic import BaseModel
 from sqlalchemy import text
 from sqlalchemy.ext.asyncio import create_async_engine
@@ -18,6 +14,7 @@
     GetAlertsWithPromptAndOutputRow,
     GetPromptWithOutputsRow,
 )
+from codegate.pipeline.base import PipelineContext
 
 logger = structlog.get_logger("codegate")
 alert_queue = asyncio.Queue()
@@ -103,30 +100,9 @@ async def _insert_pydantic_model(
             logger.error(f"Failed to insert model: {model}.", error=str(e))
             return None
 
-    async def record_request(
-        self, normalized_request: ChatCompletionRequest, is_fim_request: bool, provider_str: str
-    ) -> Optional[Prompt]:
-        request_str = None
-        if isinstance(normalized_request, BaseModel):
-            request_str = normalized_request.model_dump_json(exclude_none=True, exclude_unset=True)
-        else:
-            try:
-                request_str = json.dumps(normalized_request)
-            except Exception as e:
-                logger.error(f"Failed to serialize output: {normalized_request}", error=str(e))
-
-        if request_str is None:
-            logger.warning("No request found to record.")
-            return
-
-        # Create a new prompt record
-        prompt_params = Prompt(
-            id=str(uuid.uuid4()),  # Generate a new UUID for the prompt
-            timestamp=datetime.datetime.now(datetime.timezone.utc),
-            provider=provider_str,
-            type="fim" if is_fim_request else "chat",
-            request=request_str,
-        )
+    async def record_request(self, prompt_params: Optional[Prompt] = None) -> Optional[Prompt]:
+        if prompt_params is None:
+            return None
         sql = text(
             """
                 INSERT INTO prompts (id, timestamp, provider, request, type)
@@ -134,15 +110,29 @@ async def record_request(
                 RETURNING *
                 """
         )
-        return await self._insert_pydantic_model(prompt_params, sql)
-
-    async def _record_output(self, prompt: Prompt, output_str: str) -> Optional[Output]:
-        output_params = Output(
-            id=str(uuid.uuid4()),
-            prompt_id=prompt.id,
-            timestamp=datetime.datetime.now(datetime.timezone.utc),
-            output=output_str,
+        recorded_request = await self._insert_pydantic_model(prompt_params, sql)
+        logger.debug(f"Recorded request: {recorded_request}")
+        return recorded_request
+
+    async def record_outputs(self, outputs: List[Output]) -> Optional[Output]:
+        if not outputs:
+            return
+
+        first_output = outputs[0]
+        # Create a single entry in the DB but encode all of the chunks in the stream as a list
+        # of JSON objects in the field `output`
+        output_db = Output(
+            id=first_output.id,
+            prompt_id=first_output.prompt_id,
+            timestamp=first_output.timestamp,
+            output=first_output.output,
         )
+        full_outputs = []
+        # Just store the model responses in the list of JSON objects.
+        for output in outputs:
+            full_outputs.append(output.output)
+        output_db.output = json.dumps(full_outputs)
+
         sql = text(
             """
                 INSERT INTO outputs (id, prompt_id, timestamp, output)
@@ -150,50 +140,11 @@ async def _record_output(self, prompt: Prompt, output_str: str) -> Optional[Outp
                 RETURNING *
                 """
         )
-        return await self._insert_pydantic_model(output_params, sql)
-
-    async def record_output_stream(
-        self, prompt: Prompt, model_response: AsyncIterator
-    ) -> AsyncGenerator:
-        output_chunks = []
-        async for chunk in model_response:
-            if isinstance(chunk, BaseModel):
-                chunk_to_record = chunk.model_dump(exclude_none=True, exclude_unset=True)
-                output_chunks.append(chunk_to_record)
-            elif isinstance(chunk, dict):
-                output_chunks.append(copy.deepcopy(chunk))
-            else:
-                output_chunks.append({"chunk": str(chunk)})
-            yield chunk
-
-        if output_chunks:
-            # Record the output chunks
-            output_str = json.dumps(output_chunks)
-            await self._record_output(prompt, output_str)
-
-    async def record_output_non_stream(
-        self, prompt: Optional[Prompt], model_response: ModelResponse
-    ) -> Optional[Output]:
-        if prompt is None:
-            logger.warning("No prompt found to record output.")
-            return
+        recorded_output = await self._insert_pydantic_model(output_db, sql)
+        logger.debug(f"Recorded output: {recorded_output}")
+        return recorded_output
 
-        output_str = None
-        if isinstance(model_response, BaseModel):
-            output_str = model_response.model_dump_json(exclude_none=True, exclude_unset=True)
-        else:
-            try:
-                output_str = json.dumps(model_response)
-            except Exception as e:
-                logger.error(f"Failed to serialize output: {model_response}", error=str(e))
-
-        if output_str is None:
-            logger.warning("No output found to record.")
-            return
-
-        return await self._record_output(prompt, output_str)
-
-    async def record_alerts(self, alerts: List[Alert]) -> None:
+    async def record_alerts(self, alerts: List[Alert]) -> List[Alert]:
         if not alerts:
             return
         sql = text(
@@ -208,15 +159,33 @@ async def record_alerts(self, alerts: List[Alert]) -> None:
                 """
         )
         # We can insert each alert independently in parallel.
+        alerts_tasks = []
         async with asyncio.TaskGroup() as tg:
             for alert in alerts:
                 try:
                     result = tg.create_task(self._insert_pydantic_model(alert, sql))
-                    if result and alert.trigger_category == "critical":
-                        await alert_queue.put(f"New alert detected: {alert.timestamp}")
+                    alerts_tasks.append(result)
                 except Exception as e:
                     logger.error(f"Failed to record alert: {alert}.", error=str(e))
-        return None
+
+        recorded_alerts = []
+        for alert_coro in alerts_tasks:
+            alert_result = alert_coro.result()
+            recorded_alerts.append(alert_result)
+            if alert_result and alert_result.trigger_category == "critical":
+                await alert_queue.put(f"New alert detected: {alert_result.timestamp}")
+
+        logger.debug(f"Recorded alerts: {recorded_alerts}")
+        return recorded_alerts
+
+    async def record_context(self, context: PipelineContext) -> None:
+        logger.info(
+            f"Recording context in DB. Output chunks: {len(context.output_responses)}. "
+            f"Alerts: {len(context.alerts_raised)}."
+        )
+        await self.record_request(context.input_request)
+        await self.record_outputs(context.output_responses)
+        await self.record_alerts(context.alerts_raised)
 
 
 class DbReader(DbCodeGate):
diff --git a/src/codegate/pipeline/base.py b/src/codegate/pipeline/base.py
index 720fc2af..1b457fd1 100644
--- a/src/codegate/pipeline/base.py
+++ b/src/codegate/pipeline/base.py
@@ -8,9 +8,10 @@
 from typing import Any, Dict, List, Optional
 
 import structlog
-from litellm import ChatCompletionRequest
+from litellm import ChatCompletionRequest, ModelResponse
+from pydantic import BaseModel
 
-from codegate.db.models import Alert
+from codegate.db.models import Alert, Output, Prompt
 from codegate.pipeline.secrets.manager import SecretsManager
 
 logger = structlog.get_logger("codegate")
@@ -73,6 +74,9 @@ class PipelineContext:
     metadata: Dict[str, Any] = field(default_factory=dict)
     sensitive: Optional[PipelineSensitiveData] = field(default_factory=lambda: None)
     alerts_raised: List[Alert] = field(default_factory=list)
+    prompt_id: Optional[str] = field(default_factory=lambda: None)
+    input_request: Optional[Prompt] = field(default_factory=lambda: None)
+    output_responses: List[Output] = field(default_factory=list)
 
     def add_code_snippet(self, snippet: CodeSnippet):
         self.code_snippets.append(snippet)
@@ -90,9 +94,8 @@ def add_alert(
         """
         Add an alert to the pipeline step alerts_raised.
         """
-        if not self.metadata.get("prompt_id"):
-            logger.warning("No prompt_id found in context. Alert will not be created")
-            return
+        if self.prompt_id is None:
+            self.prompt_id = str(uuid.uuid4())
 
         if not code_snippet and not trigger_string:
             logger.warning("No code snippet or trigger string provided for alert. Will not create")
@@ -103,7 +106,7 @@ def add_alert(
         self.alerts_raised.append(
             Alert(
                 id=str(uuid.uuid4()),
-                prompt_id=self.metadata["prompt_id"],
+                prompt_id=self.prompt_id,
                 code_snippet=code_snippet_str,
                 trigger_string=trigger_string,
                 trigger_type=step_name,
@@ -111,6 +114,51 @@ def add_alert(
                 timestamp=datetime.datetime.now(datetime.timezone.utc),
             )
         )
+        logger.debug(f"Added alert to context: {self.alerts_raised[-1]}")
+
+    def add_input_request(
+        self, normalized_request: ChatCompletionRequest, is_fim_request: bool, provider: str
+    ) -> None:
+        try:
+            if self.prompt_id is None:
+                self.prompt_id = str(uuid.uuid4())
+
+            request_str = json.dumps(normalized_request)
+
+            self.input_request = Prompt(
+                id=self.prompt_id,
+                timestamp=datetime.datetime.now(datetime.timezone.utc),
+                provider=provider,
+                type="fim" if is_fim_request else "chat",
+                request=request_str,
+            )
+            logger.debug(f"Added input request to context: {self.input_request}")
+        except Exception as e:
+            logger.warning(f"Failed to serialize input request: {normalized_request}", error=str(e))
+
+    def add_output(self, model_response: ModelResponse) -> None:
+        try:
+            if self.prompt_id is None:
+                logger.warning(f"Tried to record output without a prompt ID: {model_response}")
+                return
+
+            if isinstance(model_response, BaseModel):
+                output_str = model_response.model_dump_json(exclude_none=True, exclude_unset=True)
+            else:
+                output_str = json.dumps(model_response)
+
+            self.output_responses.append(
+                Output(
+                    id=str(uuid.uuid4()),
+                    prompt_id=self.prompt_id,
+                    timestamp=datetime.datetime.now(datetime.timezone.utc),
+                    output=output_str,
+                )
+            )
+            logger.debug(f"Added output to context: {self.output_responses[-1]}")
+        except Exception as e:
+            logger.error(f"Failed to serialize output: {model_response}", error=str(e))
+            return
 
 
 @dataclass
@@ -212,20 +260,23 @@ async def process(
 
 
 class InputPipelineInstance:
-    def __init__(self, pipeline_steps: List[PipelineStep], secret_manager: SecretsManager):
+    def __init__(
+        self, pipeline_steps: List[PipelineStep], secret_manager: SecretsManager, is_fim: bool
+    ):
         self.pipeline_steps = pipeline_steps
         self.secret_manager = secret_manager
+        self.is_fim = is_fim
         self.context = PipelineContext()
 
     async def process_request(
         self,
         request: ChatCompletionRequest,
         provider: str,
-        prompt_id: str,
         model: str,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
         extra_headers: Optional[Dict[str, str]] = None,
+        is_copilot: bool = False,
     ) -> PipelineResult:
         """Process a request through all pipeline steps"""
         self.context.sensitive = PipelineSensitiveData(
@@ -236,16 +287,22 @@ async def process_request(
             provider=provider,
             api_base=api_base,
         )
-        self.context.metadata["prompt_id"] = prompt_id
         self.context.metadata["extra_headers"] = extra_headers
         current_request = request
 
+        # For Copilot the provider is "openai"; use a separate flag so we don't clash with
+        # other places that rely on the provider name.
+        provider_db = "copilot" if is_copilot else provider
+
         for step in self.pipeline_steps:
            result = await step.process(current_request, self.context)
            if result is None:
                continue
 
            if result.shortcuts_processing():
+                # Also record the input when shortcutting
+                self.context.add_input_request(
+                    current_request, is_fim_request=self.is_fim, provider=provider_db
+                )
                return result
 
            if result.request is not None:
@@ -254,30 +311,37 @@
            if result.context is not None:
                self.context = result.context
 
+        # Create the input request at the end so we make sure the secrets are obfuscated
+        self.context.add_input_request(
+            current_request, is_fim_request=self.is_fim, provider=provider_db
+        )
         return PipelineResult(request=current_request, context=self.context)
 
 
 class SequentialPipelineProcessor:
-    def __init__(self, pipeline_steps: List[PipelineStep], secret_manager: SecretsManager):
+    def __init__(
+        self, pipeline_steps: List[PipelineStep], secret_manager: SecretsManager, is_fim: bool
+    ):
         self.pipeline_steps = pipeline_steps
         self.secret_manager = secret_manager
+        self.is_fim = is_fim
 
     def create_instance(self) -> InputPipelineInstance:
         """Create a new pipeline instance for processing a request"""
-        return InputPipelineInstance(self.pipeline_steps, self.secret_manager)
+        return InputPipelineInstance(self.pipeline_steps, self.secret_manager, self.is_fim)
 
     async def process_request(
         self,
         request: ChatCompletionRequest,
         provider: str,
-        prompt_id: str,
         model: str,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
         extra_headers: Optional[Dict[str, str]] = None,
+        is_copilot: bool = False,
     ) -> PipelineResult:
         """Create a new pipeline instance and process the request"""
         instance = self.create_instance()
         return await instance.process_request(
-            request, provider, prompt_id, model, api_key, api_base, extra_headers
+            request, provider, model, api_key, api_base, extra_headers, is_copilot
         )
diff --git a/src/codegate/pipeline/factory.py b/src/codegate/pipeline/factory.py
index 038d1412..e2c2f85f 100644
--- a/src/codegate/pipeline/factory.py
+++ b/src/codegate/pipeline/factory.py
@@ -32,13 +32,13 @@ def create_input_pipeline(self) -> SequentialPipelineProcessor:
             SystemPrompt(Config.get_config().prompts.default_chat),
             CodegateContextRetriever(),
         ]
-        return SequentialPipelineProcessor(input_steps, self.secrets_manager)
+        return SequentialPipelineProcessor(input_steps, self.secrets_manager, is_fim=False)
 
     def create_fim_pipeline(self) -> SequentialPipelineProcessor:
         fim_steps: List[PipelineStep] = [
             CodegateSecrets(),
         ]
-        return SequentialPipelineProcessor(fim_steps, self.secrets_manager)
+        return SequentialPipelineProcessor(fim_steps, self.secrets_manager, is_fim=True)
 
     def create_output_pipeline(self) -> OutputPipelineProcessor:
         output_steps: List[OutputPipelineStep] = [
diff --git a/src/codegate/pipeline/output.py b/src/codegate/pipeline/output.py
index 70a25601..ad4f14b9 100644
--- a/src/codegate/pipeline/output.py
+++ b/src/codegate/pipeline/output.py
@@ -6,6 +6,7 @@
 from litellm import ModelResponse
 from litellm.types.utils import Delta, StreamingChoices
 
+from codegate.db.connection import DbRecorder
 from codegate.pipeline.base import CodeSnippet, PipelineContext
 
 logger = structlog.get_logger("codegate")
@@ -75,6 +76,7 @@ def __init__(
         self,
         pipeline_steps: list[OutputPipelineStep],
         input_context: Optional[PipelineContext] = None,
+        db_recorder: Optional[DbRecorder] = None,
     ):
         self._input_context = input_context
         self._pipeline_steps = pipeline_steps
@@ -83,6 +85,10 @@ def __init__(
         # the remaining content in the buffer when the stream ends, we need
         # to store the parameters like model, timestamp, etc.
         self._buffered_chunk = None
+        if not db_recorder:
+            self._db_recorder = DbRecorder()
+        else:
+            self._db_recorder = db_recorder
 
     def _buffer_chunk(self, chunk: ModelResponse) -> None:
         """
@@ -105,6 +111,11 @@ def _store_chunk_content(self, chunk: ModelResponse) -> None:
             if choice.delta is not None and choice.delta.content is not None:
                 self._context.processed_content.append(choice.delta.content)
 
+    async def _record_to_db(self):
+        if self._input_context and not self._input_context.metadata.get("stored_in_db", False):
+            await self._db_recorder.record_context(self._input_context)
+            self._input_context.metadata["stored_in_db"] = True
+
     async def process_stream(
         self, stream: AsyncIterator[ModelResponse]
     ) -> AsyncIterator[ModelResponse]:
@@ -115,6 +126,7 @@ async def process_stream(
             async for chunk in stream:
                 # Store chunk content in buffer
                 self._buffer_chunk(chunk)
+                self._input_context.add_output(chunk)
 
                 # Process chunk through each step of the pipeline
                 current_chunks = [chunk]
@@ -132,6 +144,13 @@ async def process_stream(
 
                 current_chunks = processed_chunks
 
+                # **Needed for Copilot**. This is a hacky way of recording the context in DB
+                # when we see the last chunk. Ideally this would be done in a `finally` or on
+                # `StopAsyncIteration`, but Copilot streams in an infinite while loop, so that
+                # is not possible.
+                if len(chunk.choices) > 0 and chunk.choices[0].get("finish_reason", "") == "stop":
+                    await self._record_to_db()
+
             # Yield all processed chunks
             for c in current_chunks:
                 self._store_chunk_content(c)
diff --git a/src/codegate/pipeline/secrets/secrets.py b/src/codegate/pipeline/secrets/secrets.py
index a87624e5..67bc0386 100644
--- a/src/codegate/pipeline/secrets/secrets.py
+++ b/src/codegate/pipeline/secrets/secrets.py
@@ -108,9 +108,6 @@ def _redact_text(
             # Get the full value
             full_value = text[start:end]
 
-            context.add_alert(
-                self.name, trigger_string=full_value, severity_category=AlertSeverity.CRITICAL
-            )
             absolute_matches.append((start, end, match._replace(value=full_value)))
 
         # Sort matches in reverse order to replace from end to start
@@ -134,6 +131,10 @@ def _redact_text(
             # Create the replacement string
             replacement = f"REDACTED<${encrypted_value}>"
 
+            # Alert with the redacted text so the secret value is not stored in the DB.
+ context.add_alert( + self.name, trigger_string=replacement, severity_category=AlertSeverity.CRITICAL + ) # Replace the secret in the text protected_text[start:end] = replacement @@ -288,9 +289,11 @@ async def process_chunk( # If value not found, leave as is original_value = match.group(0) # Keep the REDACTED marker - # Unredact the content, post an alert and return the chunk + # Post an alert with the redacted content + input_context.add_alert(self.name, trigger_string=encrypted_value) + + # Unredact the content and return the chunk unredacted_content = buffered_content[: match.start()] + original_value + remaining - input_context.add_alert(self.name, trigger_string=unredacted_content) # Return the unredacted content up to this point chunk.choices = [ StreamingChoices( diff --git a/src/codegate/providers/base.py b/src/codegate/providers/base.py index d9d04566..a0350737 100644 --- a/src/codegate/providers/base.py +++ b/src/codegate/providers/base.py @@ -1,4 +1,3 @@ -import asyncio from abc import ABC, abstractmethod from typing import Any, AsyncIterator, Callable, Dict, Optional, Union @@ -109,7 +108,6 @@ async def _run_input_pipeline( api_key: Optional[str], api_base: Optional[str], is_fim_request: bool, - prompt_id: str, ) -> PipelineResult: # Decide which pipeline processor to use if is_fim_request: @@ -125,7 +123,6 @@ async def _run_input_pipeline( result = await pipeline_processor.process_request( request=normalized_request, provider=self.provider_route_name, - prompt_id=prompt_id, model=normalized_request.get("model"), api_key=api_key, api_base=api_base, @@ -194,10 +191,14 @@ async def _cleanup_after_streaming( async for item in stream: yield item finally: - # Ensure sensitive data is cleaned up after the stream is consumed - if context and context.sensitive: - context.sensitive.secure_cleanup() - await self._db_recorder.record_alerts(context.alerts_raised) + if context: + # Record to DB the objects captured during the stream + if not context.metadata.get("stored_in_db", False): + await self._db_recorder.record_context(context) + context.metadata["stored_in_db"] = True + # Ensure sensitive data is cleaned up after the stream is consumed + if context.sensitive: + context.sensitive.secure_cleanup() async def complete( self, data: Dict, api_key: Optional[str], is_fim_request: bool @@ -215,22 +216,15 @@ async def complete( """ normalized_request = self._input_normalizer.normalize(data) streaming = normalized_request.get("stream", False) - prompt_db = await self._db_recorder.record_request( - normalized_request, is_fim_request, self.provider_route_name - ) - - prompt_db_id = prompt_db.id if prompt_db is not None else None input_pipeline_result = await self._run_input_pipeline( normalized_request, api_key, data.get("base_url"), is_fim_request, - prompt_id=prompt_db_id, ) if input_pipeline_result.response: - await self._db_recorder.record_alerts(input_pipeline_result.context.alerts_raised) return await self._pipeline_response_formatter.handle_pipeline_response( - input_pipeline_result.response, streaming, prompt_db=prompt_db + input_pipeline_result.response, streaming, context=input_pipeline_result.context ) provider_request = self._input_normalizer.denormalize(input_pipeline_result.request) @@ -246,18 +240,9 @@ async def complete( if not streaming: normalized_response = self._output_normalizer.normalize(model_response) pipeline_output = self._run_output_pipeline(normalized_response) - # Record the output and alerts in the database can be done in parallel - async with asyncio.TaskGroup() 
as tg: - tg.create_task( - self._db_recorder.record_output_non_stream(prompt_db, model_response) - ) - if input_pipeline_result and input_pipeline_result.context: - tg.create_task( - self._db_recorder.record_alerts(input_pipeline_result.context.alerts_raised) - ) + await self._db_recorder.record_context(input_pipeline_result.context) return self._output_normalizer.denormalize(pipeline_output) - model_response = self._db_recorder.record_output_stream(prompt_db, model_response) pipeline_output_stream = await self._run_output_stream_pipeline( input_pipeline_result.context, model_response, is_fim_request=is_fim_request ) diff --git a/src/codegate/providers/copilot/pipeline.py b/src/codegate/providers/copilot/pipeline.py index d20dee86..0f356770 100644 --- a/src/codegate/providers/copilot/pipeline.py +++ b/src/codegate/providers/copilot/pipeline.py @@ -1,11 +1,12 @@ import json from abc import ABC, abstractmethod -from typing import Dict +from typing import Dict, Tuple import structlog from litellm.types.llms.openai import ChatCompletionRequest -from codegate.pipeline.base import PipelineContext +from codegate.pipeline.base import PipelineContext, SequentialPipelineProcessor +from codegate.pipeline.factory import PipelineFactory from codegate.providers.normalizer.completion import CompletionNormalizer logger = structlog.get_logger("codegate") @@ -18,7 +19,7 @@ class CopilotPipeline(ABC): factory to create the pipeline itself and run the request """ - def __init__(self, pipeline_factory): + def __init__(self, pipeline_factory: PipelineFactory): self.pipeline_factory = pipeline_factory self.normalizer = self._create_normalizer() self.provider_name = "openai" @@ -29,7 +30,7 @@ def _create_normalizer(self): pass @abstractmethod - def create_pipeline(self): + def create_pipeline(self) -> SequentialPipelineProcessor: """Each strategy defines which pipeline to create""" pass @@ -63,7 +64,7 @@ def _get_copilot_headers(headers: Dict[str, str]) -> Dict[str, str]: return copilot_headers - async def process_body(self, headers: list[str], body: bytes) -> (bytes, PipelineContext): + async def process_body(self, headers: list[str], body: bytes) -> Tuple[bytes, PipelineContext]: """Common processing logic for all strategies""" try: normalized_body = self.normalizer.normalize(body) @@ -80,11 +81,11 @@ async def process_body(self, headers: list[str], body: bytes) -> (bytes, Pipelin result = await pipeline.process_request( request=normalized_body, provider=self.provider_name, - prompt_id=self._request_id(headers), model=normalized_body.get("model", "gpt-4o-mini"), api_key=headers_dict.get("authorization", "").replace("Bearer ", ""), api_base="https://" + headers_dict.get("host", ""), extra_headers=CopilotPipeline._get_copilot_headers(headers_dict), + is_copilot=True, ) if result.request: @@ -142,7 +143,7 @@ class CopilotFimPipeline(CopilotPipeline): def _create_normalizer(self): return CopilotFimNormalizer() - def create_pipeline(self): + def create_pipeline(self) -> SequentialPipelineProcessor: return self.pipeline_factory.create_fim_pipeline() @@ -155,5 +156,5 @@ class CopilotChatPipeline(CopilotPipeline): def _create_normalizer(self): return CopilotChatNormalizer() - def create_pipeline(self): + def create_pipeline(self) -> SequentialPipelineProcessor: return self.pipeline_factory.create_input_pipeline() diff --git a/src/codegate/providers/copilot/provider.py b/src/codegate/providers/copilot/provider.py index 33eb823d..7c484d1c 100644 --- a/src/codegate/providers/copilot/provider.py +++ 
b/src/codegate/providers/copilot/provider.py @@ -142,7 +142,7 @@ async def _body_through_pipeline( path: str, headers: list[str], body: bytes, - ) -> (bytes, PipelineContext): + ) -> Tuple[bytes, PipelineContext]: logger.debug(f"Processing body through pipeline: {len(body)} bytes") strategy = self._select_pipeline(method, path) if strategy is None: diff --git a/src/codegate/providers/formatting/input_pipeline.py b/src/codegate/providers/formatting/input_pipeline.py index 00a5bdef..ce28e7a7 100644 --- a/src/codegate/providers/formatting/input_pipeline.py +++ b/src/codegate/providers/formatting/input_pipeline.py @@ -5,8 +5,7 @@ from litellm.types.utils import Delta, StreamingChoices from codegate.db.connection import DbRecorder -from codegate.db.models import Prompt -from codegate.pipeline.base import PipelineResponse +from codegate.pipeline.base import PipelineContext, PipelineResponse from codegate.providers.normalizer.base import ModelOutputNormalizer @@ -67,6 +66,7 @@ async def _convert_to_stream( content: str, step_name: str, model: str, + context: PipelineContext, ) -> AsyncIterator[ModelResponse]: """ Converts a single completion response, provided by our pipeline as a shortcut @@ -89,8 +89,21 @@ def __init__( self._output_normalizer = output_normalizer self._db_recorder = db_recorder + async def _cleanup_after_streaming( + self, stream: AsyncIterator[ModelResponse], context: PipelineContext + ) -> AsyncIterator[ModelResponse]: + """Wraps the stream to ensure cleanup after consumption""" + try: + async for item in stream: + context.add_output(item) + yield item + finally: + if context: + # Record to DB the objects captured during the stream + await self._db_recorder.record_context(context) + async def handle_pipeline_response( - self, pipeline_response: PipelineResponse, streaming: bool, prompt_db: Prompt + self, pipeline_response: PipelineResponse, streaming: bool, context: PipelineContext ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]: """ Convert pipeline response to appropriate format based on streaming flag @@ -109,15 +122,14 @@ async def handle_pipeline_response( if not streaming: # If we're not streaming, we just return the response translated # to the provider-specific format - await self._db_recorder.record_output_non_stream(prompt_db, model_response) + context.add_output(model_response) + await self._db_recorder.record_context(context) return self._output_normalizer.denormalize(model_response) # If we're streaming, we need to convert the response to a stream first # then feed the stream into the completion handler's conversion method model_response_stream = _convert_to_stream( - pipeline_response.content, pipeline_response.step_name, pipeline_response.model - ) - model_response_stream = self._db_recorder.record_output_stream( - prompt_db, model_response_stream + pipeline_response.content, pipeline_response.step_name, pipeline_response.model, context ) + model_response_stream = self._cleanup_after_streaming(model_response_stream, context) return self._output_normalizer.denormalize_streaming(model_response_stream) diff --git a/tests/dashboard/test_post_processing.py b/tests/dashboard/test_post_processing.py index e7a8bf3f..5c387ed1 100644 --- a/tests/dashboard/test_post_processing.py +++ b/tests/dashboard/test_post_processing.py @@ -95,7 +95,7 @@ async def test_parse_request(request_dict, expected_str): "model": "gpt-4o-mini", "object": "chat.completion.chunk", "system_fingerprint": "fp_0705bf87c0", - "choices": [{"index": 0, "delta": {"content": "User", "role": 
"assistant"}}], + "choices": [{"index": 0, "delta": {"content": "Hello", "role": "assistant"}}], }, { "id": "chatcmpl-AaQw9O1O2u360mhba5UbMoPwFgqEl", @@ -103,7 +103,7 @@ async def test_parse_request(request_dict, expected_str): "model": "gpt-4o-mini", "object": "chat.completion.chunk", "system_fingerprint": "fp_0705bf87c0", - "choices": [{"index": 0, "delta": {"content": " seeks"}}], + "choices": [{"index": 0, "delta": {"content": " world"}}], }, { "id": "chatcmpl-AaQw9O1O2u360mhba5UbMoPwFgqEl", @@ -114,7 +114,7 @@ async def test_parse_request(request_dict, expected_str): "choices": [{"finish_reason": "stop", "index": 0, "delta": {}}], }, ], - "User seeks", + "Hello world", "chatcmpl-AaQw9O1O2u360mhba5UbMoPwFgqEl", ), ( @@ -150,7 +150,7 @@ async def test_parse_output(output_dict, expected_str, expected_chat_id): @pytest.mark.asyncio @pytest.mark.parametrize("request_msg_str", ["Hello", None]) @pytest.mark.parametrize("output_msg_str", ["Hello, how can I help you?", None]) -@pytest.mark.parametrize("chat_id", ["chatcmpl-AaQw9O1O2u360mhba5UbMoPwFgqEl", None]) +@pytest.mark.parametrize("chat_id", ["chatcmpl-AaQw9O1O2u360mhba5UbMoPwFgqEl"]) @pytest.mark.parametrize( "row", [ @@ -181,12 +181,13 @@ async def test_parse_get_prompt_with_output(request_msg_str, output_msg_str, cha mock_parse_request.assert_called_once() mock_parse_output.assert_called_once() - if any([request_msg_str is None, output_msg_str is None, chat_id is None]): + if request_msg_str is None: assert result is None else: assert result.question_answer.question.message == request_msg_str - assert result.question_answer.answer.message == output_msg_str - assert result.chat_id == chat_id + if output_msg_str is not None: + assert result.question_answer.answer.message == output_msg_str + assert result.chat_id == chat_id assert result.provider == "provider" assert result.type == "chat" assert result.request_timestamp == timestamp_now diff --git a/tests/pipeline/test_output.py b/tests/pipeline/test_output.py index e3d36deb..c700b3cd 100644 --- a/tests/pipeline/test_output.py +++ b/tests/pipeline/test_output.py @@ -1,4 +1,5 @@ from typing import List +from unittest.mock import MagicMock import pytest from litellm import ModelResponse @@ -59,6 +60,15 @@ def create_model_response(content: str, id: str = "test") -> ModelResponse: ) +class MockContext: + + def __init__(self): + self.sensitive = False + + def add_output(self, chunk: ModelResponse): + pass + + class TestOutputPipelineContext: def test_buffer_initialization(self): """Test that buffer is properly initialized""" @@ -84,7 +94,9 @@ class TestOutputPipelineInstance: async def test_single_step_processing(self): """Test processing a stream through a single step""" step = MockOutputPipelineStep("test_step", modify_content=True) - instance = OutputPipelineInstance([step]) + context = MockContext() + db_recorder = MagicMock() + instance = OutputPipelineInstance([step], context, db_recorder) async def mock_stream(): yield create_model_response("Hello") @@ -107,7 +119,9 @@ async def test_multiple_steps_processing(self): MockOutputPipelineStep("step1", modify_content=True), MockOutputPipelineStep("step2", modify_content=True), ] - instance = OutputPipelineInstance(steps) + context = MockContext() + db_recorder = MagicMock() + instance = OutputPipelineInstance(steps, context, db_recorder) async def mock_stream(): yield create_model_response("Hello") @@ -129,7 +143,9 @@ async def test_step_pausing(self): MockOutputPipelineStep("step1", should_pause=True), MockOutputPipelineStep("step2", 
modify_content=True), ] - instance = OutputPipelineInstance(steps) + context = MockContext() + db_recorder = MagicMock() + instance = OutputPipelineInstance(steps, context, db_recorder) async def mock_stream(): yield create_model_response("he") @@ -184,7 +200,9 @@ async def process_chunk( return [chunk] return [] - instance = OutputPipelineInstance([ReplacementStep()]) + context = MockContext() + db_recorder = MagicMock() + instance = OutputPipelineInstance([ReplacementStep()], context, db_recorder) async def mock_stream(): yield create_model_response("he") @@ -207,7 +225,9 @@ async def mock_stream(): async def test_buffer_processing(self): """Test that content is properly buffered and cleared""" step = MockOutputPipelineStep("test_step") - instance = OutputPipelineInstance([step]) + context = MockContext() + db_recorder = MagicMock() + instance = OutputPipelineInstance([step], context, db_recorder) async def mock_stream(): yield create_model_response("Hello") @@ -227,7 +247,9 @@ async def mock_stream(): async def test_empty_stream(self): """Test handling of an empty stream""" step = MockOutputPipelineStep("test_step") - instance = OutputPipelineInstance([step]) + context = MockContext() + db_recorder = MagicMock() + instance = OutputPipelineInstance([step], context, db_recorder) async def mock_stream(): if False: @@ -260,7 +282,10 @@ async def process_chunk( assert input_context.metadata["test"] == "value" return [chunk] - instance = OutputPipelineInstance([ContextCheckingStep()], input_context=input_context) + db_recorder = MagicMock() + instance = OutputPipelineInstance( + [ContextCheckingStep()], input_context=input_context, db_recorder=db_recorder + ) async def mock_stream(): yield create_model_response("test") @@ -272,7 +297,9 @@ async def mock_stream(): async def test_buffer_flush_on_stream_end(self): """Test that buffer is properly flushed when stream ends""" step = MockOutputPipelineStep("test_step", should_pause=True) - instance = OutputPipelineInstance([step]) + context = MockContext() + db_recorder = MagicMock() + instance = OutputPipelineInstance([step], context, db_recorder) async def mock_stream(): yield create_model_response("Hello")