Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Copilot DB integration. Keep DB objects in context to record at the end. #331

Merged
merged 4 commits into from
Dec 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 38 additions & 21 deletions src/codegate/dashboard/post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,26 +86,39 @@ async def parse_output(output_str: str) -> Tuple[Optional[str], Optional[str]]:
logger.warning(f"Error parsing output: {output_str}. {e}")
return None, None

output_message = ""
def _parse_single_output(single_output: dict) -> str:
single_chat_id = single_output.get("id")
single_output_message = ""
for choice in single_output.get("choices", []):
if not isinstance(choice, dict):
continue
content_dict = choice.get("delta", {}) or choice.get("message", {})
single_output_message += content_dict.get("content", "")
return single_output_message, single_chat_id

full_output_message = ""
chat_id = None
if isinstance(output, list):
for output_chunk in output:
if not isinstance(output_chunk, dict):
continue
chat_id = chat_id or output_chunk.get("id")
for choice in output_chunk.get("choices", []):
if not isinstance(choice, dict):
continue
delta_dict = choice.get("delta", {})
output_message += delta_dict.get("content", "")
output_message, output_chat_id = "", None
if isinstance(output_chunk, dict):
output_message, output_chat_id = _parse_single_output(output_chunk)
elif isinstance(output_chunk, str):
try:
output_decoded = json.loads(output_chunk)
output_message, output_chat_id = _parse_single_output(output_decoded)
except Exception:
logger.error(f"Error reading chunk: {output_chunk}")
else:
logger.warning(
f"Could not handle output: {output_chunk}", out_type=type(output_chunk)
)
chat_id = chat_id or output_chat_id
full_output_message += output_message
elif isinstance(output, dict):
chat_id = chat_id or output.get("id")
for choice in output.get("choices", []):
if not isinstance(choice, dict):
continue
output_message += choice.get("message", {}).get("content", "")
full_output_message, chat_id = _parse_single_output(output)

return output_message, chat_id
return full_output_message, chat_id


async def _get_question_answer(
Expand All @@ -124,19 +137,23 @@ async def _get_question_answer(
output_msg_str, chat_id = output_task.result()

# If we couldn't parse the request or output, return None
if not request_msg_str or not output_msg_str or not chat_id:
if not request_msg_str:
return None, None

request_message = ChatMessage(
message=request_msg_str,
timestamp=row.timestamp,
message_id=row.id,
)
output_message = ChatMessage(
message=output_msg_str,
timestamp=row.output_timestamp,
message_id=row.output_id,
)
if output_msg_str:
output_message = ChatMessage(
message=output_msg_str,
timestamp=row.output_timestamp,
message_id=row.output_id,
)
else:
output_message = None
chat_id = row.id
return QuestionAnswer(question=request_message, answer=output_message), chat_id


Expand Down
2 changes: 1 addition & 1 deletion src/codegate/dashboard/request_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class QuestionAnswer(BaseModel):
"""

question: ChatMessage
answer: ChatMessage
answer: Optional[ChatMessage]


class PartialConversation(BaseModel):
Expand Down
135 changes: 52 additions & 83 deletions src/codegate/db/connection.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
import asyncio
import copy
import datetime
import json
import uuid
from pathlib import Path
from typing import AsyncGenerator, AsyncIterator, List, Optional
from typing import List, Optional

import structlog
from litellm import ChatCompletionRequest, ModelResponse
from pydantic import BaseModel
from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine
Expand All @@ -18,6 +14,7 @@
GetAlertsWithPromptAndOutputRow,
GetPromptWithOutputsRow,
)
from codegate.pipeline.base import PipelineContext

logger = structlog.get_logger("codegate")
alert_queue = asyncio.Queue()
Expand Down Expand Up @@ -103,97 +100,51 @@ async def _insert_pydantic_model(
logger.error(f"Failed to insert model: {model}.", error=str(e))
return None

async def record_request(
self, normalized_request: ChatCompletionRequest, is_fim_request: bool, provider_str: str
) -> Optional[Prompt]:
request_str = None
if isinstance(normalized_request, BaseModel):
request_str = normalized_request.model_dump_json(exclude_none=True, exclude_unset=True)
else:
try:
request_str = json.dumps(normalized_request)
except Exception as e:
logger.error(f"Failed to serialize output: {normalized_request}", error=str(e))

if request_str is None:
logger.warning("No request found to record.")
return

# Create a new prompt record
prompt_params = Prompt(
id=str(uuid.uuid4()), # Generate a new UUID for the prompt
timestamp=datetime.datetime.now(datetime.timezone.utc),
provider=provider_str,
type="fim" if is_fim_request else "chat",
request=request_str,
)
async def record_request(self, prompt_params: Optional[Prompt] = None) -> Optional[Prompt]:
if prompt_params is None:
return None
sql = text(
"""
INSERT INTO prompts (id, timestamp, provider, request, type)
VALUES (:id, :timestamp, :provider, :request, :type)
RETURNING *
"""
)
return await self._insert_pydantic_model(prompt_params, sql)

async def _record_output(self, prompt: Prompt, output_str: str) -> Optional[Output]:
output_params = Output(
id=str(uuid.uuid4()),
prompt_id=prompt.id,
timestamp=datetime.datetime.now(datetime.timezone.utc),
output=output_str,
recorded_request = await self._insert_pydantic_model(prompt_params, sql)
logger.debug(f"Recorded request: {recorded_request}")
return recorded_request

async def record_outputs(self, outputs: List[Output]) -> Optional[Output]:
if not outputs:
return

first_output = outputs[0]
# Create a single entry on DB but encode all of the chunks in the stream as a list
# of JSON objects in the field `output`
output_db = Output(
id=first_output.id,
prompt_id=first_output.prompt_id,
timestamp=first_output.timestamp,
output=first_output.output,
)
full_outputs = []
# Just store the model responses in the list of JSON objects.
for output in outputs:
full_outputs.append(output.output)
output_db.output = json.dumps(full_outputs)

sql = text(
"""
INSERT INTO outputs (id, prompt_id, timestamp, output)
VALUES (:id, :prompt_id, :timestamp, :output)
RETURNING *
"""
)
return await self._insert_pydantic_model(output_params, sql)

async def record_output_stream(
self, prompt: Prompt, model_response: AsyncIterator
) -> AsyncGenerator:
output_chunks = []
async for chunk in model_response:
if isinstance(chunk, BaseModel):
chunk_to_record = chunk.model_dump(exclude_none=True, exclude_unset=True)
output_chunks.append(chunk_to_record)
elif isinstance(chunk, dict):
output_chunks.append(copy.deepcopy(chunk))
else:
output_chunks.append({"chunk": str(chunk)})
yield chunk

if output_chunks:
# Record the output chunks
output_str = json.dumps(output_chunks)
await self._record_output(prompt, output_str)

async def record_output_non_stream(
self, prompt: Optional[Prompt], model_response: ModelResponse
) -> Optional[Output]:
if prompt is None:
logger.warning("No prompt found to record output.")
return
recorded_output = await self._insert_pydantic_model(output_db, sql)
logger.debug(f"Recorded output: {recorded_output}")
return recorded_output

output_str = None
if isinstance(model_response, BaseModel):
output_str = model_response.model_dump_json(exclude_none=True, exclude_unset=True)
else:
try:
output_str = json.dumps(model_response)
except Exception as e:
logger.error(f"Failed to serialize output: {model_response}", error=str(e))

if output_str is None:
logger.warning("No output found to record.")
return

return await self._record_output(prompt, output_str)

async def record_alerts(self, alerts: List[Alert]) -> None:
async def record_alerts(self, alerts: List[Alert]) -> List[Alert]:
if not alerts:
return
sql = text(
Expand All @@ -208,15 +159,33 @@ async def record_alerts(self, alerts: List[Alert]) -> None:
"""
)
# We can insert each alert independently in parallel.
alerts_tasks = []
async with asyncio.TaskGroup() as tg:
for alert in alerts:
try:
result = tg.create_task(self._insert_pydantic_model(alert, sql))
if result and alert.trigger_category == "critical":
await alert_queue.put(f"New alert detected: {alert.timestamp}")
alerts_tasks.append(result)
except Exception as e:
logger.error(f"Failed to record alert: {alert}.", error=str(e))
return None

recorded_alerts = []
for alert_coro in alerts_tasks:
alert_result = alert_coro.result()
recorded_alerts.append(alert_result)
if alert_result and alert_result.trigger_category == "critical":
await alert_queue.put(f"New alert detected: {alert.timestamp}")

logger.debug(f"Recorded alerts: {recorded_alerts}")
return recorded_alerts

async def record_context(self, context: PipelineContext) -> None:
    """Persist a whole pipeline context in the DB: its request, output chunks and alerts."""
    num_chunks = len(context.output_responses)
    num_alerts = len(context.alerts_raised)
    logger.info(
        f"Recording context in DB. Output chunks: {num_chunks}. "
        f"Alerts: {num_alerts}."
    )
    # Record the prompt first so the outputs and alerts that reference it
    # are inserted after their parent row exists.
    await self.record_request(context.input_request)
    await self.record_outputs(context.output_responses)
    await self.record_alerts(context.alerts_raised)


class DbReader(DbCodeGate):
Expand Down
Loading
Loading