Skip to content

Properly parse conversations into alerts #744

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/codegate/api/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,12 +202,15 @@ async def get_workspace_alerts(workspace_name: str) -> List[Optional[v1_models.A
except crud.WorkspaceDoesNotExistError:
raise HTTPException(status_code=404, detail="Workspace does not exist")
except Exception:
logger.exception("Error while getting workspace")
raise HTTPException(status_code=500, detail="Internal server error")

try:
alerts = await dbreader.get_alerts_with_prompt_and_output(ws.id)
return await v1_processing.parse_get_alert_conversation(alerts)
alerts = await dbreader.get_alerts_by_workspace(ws.id)
prompts_outputs = await dbreader.get_prompts_with_output(ws.id)
return await v1_processing.parse_get_alert_conversation(alerts, prompts_outputs)
except Exception:
logger.exception("Error while getting alerts and messages")
raise HTTPException(status_code=500, detail="Internal server error")


Expand All @@ -223,12 +226,15 @@ async def get_workspace_messages(workspace_name: str) -> List[v1_models.Conversa
except crud.WorkspaceDoesNotExistError:
raise HTTPException(status_code=404, detail="Workspace does not exist")
except Exception:
logger.exception("Error while getting workspace")
raise HTTPException(status_code=500, detail="Internal server error")

try:
prompts_outputs = await dbreader.get_prompts_with_output(ws.id)
return await v1_processing.parse_messages_in_conversations(prompts_outputs)
conversations, _ = await v1_processing.parse_messages_in_conversations(prompts_outputs)
return conversations
except Exception:
logger.exception("Error while getting messages")
raise HTTPException(status_code=500, detail="Internal server error")


Expand Down
90 changes: 47 additions & 43 deletions src/codegate/api/v1_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import re
from collections import defaultdict
from typing import AsyncGenerator, List, Optional, Union
from typing import AsyncGenerator, Dict, List, Optional, Tuple

import requests
import structlog
Expand All @@ -16,7 +16,7 @@
QuestionAnswer,
)
from codegate.db.connection import alert_queue
from codegate.db.models import GetAlertsWithPromptAndOutputRow, GetPromptWithOutputsRow
from codegate.db.models import Alert, GetPromptWithOutputsRow

logger = structlog.get_logger("codegate")

Expand Down Expand Up @@ -144,9 +144,7 @@ def _parse_single_output(single_output: dict) -> str:
return full_output_message


async def _get_question_answer(
row: Union[GetPromptWithOutputsRow, GetAlertsWithPromptAndOutputRow]
) -> Optional[PartialQuestionAnswer]:
async def _get_question_answer(row: GetPromptWithOutputsRow) -> Optional[PartialQuestionAnswer]:
"""
Parse a row from the get_prompt_with_outputs query and return a PartialConversation

Expand Down Expand Up @@ -195,6 +193,11 @@ def parse_question_answer(input_text: str) -> str:
return input_text


def _clean_secrets_from_message(message: str) -> str:
pattern = re.compile(r"REDACTED<(\$?[^>]+)>")
return pattern.sub("REDACTED_SECRET", message)


def _group_partial_messages(pq_list: List[PartialQuestions]) -> List[List[PartialQuestions]]:
"""
A PartialQuestion is an object that contains several user messages provided from a
Expand All @@ -210,6 +213,10 @@ def _group_partial_messages(pq_list: List[PartialQuestions]) -> List[List[Partia
- Leave any unpaired pq by itself.
- Finally, sort the resulting groups by the earliest timestamp in each group.
"""
# 0) Clean secrets from messages
for pq in pq_list:
pq.messages = [_clean_secrets_from_message(msg) for msg in pq.messages]

# 1) Sort by length of messages descending (largest/most-complete first),
# then by timestamp ascending for stable processing.
pq_list_sorted = sorted(pq_list, key=lambda x: (-len(x.messages), x.timestamp))
Expand All @@ -224,7 +231,7 @@ def _group_partial_messages(pq_list: List[PartialQuestions]) -> List[List[Partia

# Find all potential subsets of 'sup' that are not yet used
# (If sup's messages == sub's messages, that also counts, because sub ⊆ sup)
possible_subsets = []
possible_subsets: List[PartialQuestions] = []
for sub in pq_list_sorted:
if sub.message_id == sup.message_id:
continue
Expand Down Expand Up @@ -281,7 +288,7 @@ def _get_question_answer_from_partial(

async def match_conversations(
partial_question_answers: List[Optional[PartialQuestionAnswer]],
) -> List[Conversation]:
) -> Tuple[List[Conversation], Dict[str, Conversation]]:
"""
Match partial conversations to form a complete conversation.
"""
Expand All @@ -294,45 +301,47 @@ async def match_conversations(

# Create the conversation objects
conversations = []
map_q_id_to_conversation = {}
for group in grouped_partial_questions:
questions_answers = []
questions_answers: List[QuestionAnswer] = []
first_partial_qa = None
for partial_question in sorted(group, key=lambda x: x.timestamp):
# Partial questions don't contain the answer, so we need to find the corresponding
# valid partial question answer
selected_partial_qa = None
for partial_qa in valid_partial_qas:
if partial_question.message_id == partial_qa.partial_questions.message_id:
selected_partial_qa = partial_qa
break

# check if we have an answer, otherwise do not add it
if selected_partial_qa.answer is not None:
# if we don't have a first question, set it
# check if we have a question and answer, otherwise do not add it
if selected_partial_qa and selected_partial_qa.answer is not None:
# if we don't have a first question, set it. We will use it
# to set the conversation timestamp and provider
first_partial_qa = first_partial_qa or selected_partial_qa
question_answer = _get_question_answer_from_partial(selected_partial_qa)
question_answer.question.message = parse_question_answer(
question_answer.question.message
)
questions_answers.append(question_answer)
qa = _get_question_answer_from_partial(selected_partial_qa)
qa.question.message = parse_question_answer(qa.question.message)
questions_answers.append(qa)

# only add conversation if we have some answers
if len(questions_answers) > 0 and first_partial_qa is not None:
conversations.append(
Conversation(
question_answers=questions_answers,
provider=first_partial_qa.partial_questions.provider,
type=first_partial_qa.partial_questions.type,
chat_id=first_partial_qa.partial_questions.message_id,
conversation_timestamp=first_partial_qa.partial_questions.timestamp,
)
conversation = Conversation(
question_answers=questions_answers,
provider=first_partial_qa.partial_questions.provider,
type=first_partial_qa.partial_questions.type,
chat_id=first_partial_qa.partial_questions.message_id,
conversation_timestamp=first_partial_qa.partial_questions.timestamp,
)
for qa in questions_answers:
map_q_id_to_conversation[qa.question.message_id] = conversation
conversations.append(conversation)

return conversations
return conversations, map_q_id_to_conversation


async def parse_messages_in_conversations(
prompts_outputs: List[GetPromptWithOutputsRow],
) -> List[Conversation]:
) -> Tuple[List[Conversation], Dict[str, Conversation]]:
"""
Get all the messages from the database and return them as a list of conversations.
"""
Expand All @@ -342,31 +351,21 @@ async def parse_messages_in_conversations(
tasks = [tg.create_task(_get_question_answer(row)) for row in prompts_outputs]
partial_question_answers = [task.result() for task in tasks]

conversations = await match_conversations(partial_question_answers)
return conversations
conversations, map_q_id_to_conversation = await match_conversations(partial_question_answers)
return conversations, map_q_id_to_conversation


async def parse_row_alert_conversation(
row: GetAlertsWithPromptAndOutputRow,
row: Alert, map_q_id_to_conversation: Dict[str, Conversation]
) -> Optional[AlertConversation]:
"""
Parse a row from the get_alerts_with_prompt_and_output query and return a Conversation

The row contains the raw request and output strings from the pipeline.
"""
partial_qa = await _get_question_answer(row)
if not partial_qa:
conversation = map_q_id_to_conversation.get(row.prompt_id)
if conversation is None:
return None

question_answer = _get_question_answer_from_partial(partial_qa)

conversation = Conversation(
question_answers=[question_answer],
provider=row.provider,
type=row.type,
chat_id=row.id,
conversation_timestamp=row.timestamp,
)
code_snippet = json.loads(row.code_snippet) if row.code_snippet else None
trigger_string = None
if row.trigger_string:
Expand All @@ -387,14 +386,19 @@ async def parse_row_alert_conversation(


async def parse_get_alert_conversation(
alerts_conversations: List[GetAlertsWithPromptAndOutputRow],
alerts: List[Alert],
prompts_outputs: List[GetPromptWithOutputsRow],
) -> List[AlertConversation]:
"""
Parse a list of rows from the get_alerts_with_prompt_and_output query and return a list of
AlertConversation

The rows contain the raw request and output strings from the pipeline.
"""
_, map_q_id_to_conversation = await parse_messages_in_conversations(prompts_outputs)
async with asyncio.TaskGroup() as tg:
tasks = [tg.create_task(parse_row_alert_conversation(row)) for row in alerts_conversations]
tasks = [
tg.create_task(parse_row_alert_conversation(row, map_q_id_to_conversation))
for row in alerts
]
return [task.result() for task in tasks if task.result() is not None]
19 changes: 4 additions & 15 deletions src/codegate/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from codegate.db.models import (
ActiveWorkspace,
Alert,
GetAlertsWithPromptAndOutputRow,
GetPromptWithOutputsRow,
GetWorkspaceByNameConditions,
Output,
Expand Down Expand Up @@ -430,9 +429,7 @@ async def get_prompts_with_output(self, workpace_id: str) -> List[GetPromptWithO
)
return prompts

async def get_alerts_with_prompt_and_output(
self, workspace_id: str
) -> List[GetAlertsWithPromptAndOutputRow]:
async def get_alerts_by_workspace(self, workspace_id: str) -> List[Alert]:
sql = text(
"""
SELECT
Expand All @@ -442,24 +439,16 @@ async def get_alerts_with_prompt_and_output(
a.trigger_string,
a.trigger_type,
a.trigger_category,
a.timestamp,
p.timestamp as prompt_timestamp,
p.provider,
p.request,
p.type,
o.id as output_id,
o.output,
o.timestamp as output_timestamp
a.timestamp
FROM alerts a
LEFT JOIN prompts p ON p.id = a.prompt_id
LEFT JOIN outputs o ON p.id = o.prompt_id
INNER JOIN prompts p ON p.id = a.prompt_id
WHERE p.workspace_id = :workspace_id
ORDER BY a.timestamp DESC
"""
)
conditions = {"workspace_id": workspace_id}
prompts = await self._exec_select_conditions_to_pydantic(
GetAlertsWithPromptAndOutputRow, sql, conditions, should_raise=True
Alert, sql, conditions, should_raise=True
)
return prompts

Expand Down
17 changes: 0 additions & 17 deletions src/codegate/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,23 +67,6 @@ class Session(BaseModel):
# Models for select queries


class GetAlertsWithPromptAndOutputRow(BaseModel):
id: Any
prompt_id: Any
code_snippet: Optional[Any]
trigger_string: Optional[Any]
trigger_type: Any
trigger_category: Optional[Any]
timestamp: Any
prompt_timestamp: Optional[Any]
provider: Optional[Any]
request: Optional[Any]
type: Optional[Any]
output_id: Optional[Any]
output: Optional[Any]
output_timestamp: Optional[Any]


class GetPromptWithOutputsRow(BaseModel):
id: Any
timestamp: Any
Expand Down
Loading