diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index 1e519698..31a4a22f 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -202,12 +202,15 @@ async def get_workspace_alerts(workspace_name: str) -> List[Optional[v1_models.A except crud.WorkspaceDoesNotExistError: raise HTTPException(status_code=404, detail="Workspace does not exist") except Exception: + logger.exception("Error while getting workspace") raise HTTPException(status_code=500, detail="Internal server error") try: - alerts = await dbreader.get_alerts_with_prompt_and_output(ws.id) - return await v1_processing.parse_get_alert_conversation(alerts) + alerts = await dbreader.get_alerts_by_workspace(ws.id) + prompts_outputs = await dbreader.get_prompts_with_output(ws.id) + return await v1_processing.parse_get_alert_conversation(alerts, prompts_outputs) except Exception: + logger.exception("Error while getting alerts and messages") raise HTTPException(status_code=500, detail="Internal server error") @@ -223,12 +226,15 @@ async def get_workspace_messages(workspace_name: str) -> List[v1_models.Conversa except crud.WorkspaceDoesNotExistError: raise HTTPException(status_code=404, detail="Workspace does not exist") except Exception: + logger.exception("Error while getting workspace") raise HTTPException(status_code=500, detail="Internal server error") try: prompts_outputs = await dbreader.get_prompts_with_output(ws.id) - return await v1_processing.parse_messages_in_conversations(prompts_outputs) + conversations, _ = await v1_processing.parse_messages_in_conversations(prompts_outputs) + return conversations except Exception: + logger.exception("Error while getting messages") raise HTTPException(status_code=500, detail="Internal server error") diff --git a/src/codegate/api/v1_processing.py b/src/codegate/api/v1_processing.py index 906584b2..ed2e119a 100644 --- a/src/codegate/api/v1_processing.py +++ b/src/codegate/api/v1_processing.py @@ -2,7 +2,7 @@ import json import re from collections import defaultdict -from typing import AsyncGenerator, List, Optional, Union +from typing import AsyncGenerator, Dict, List, Optional, Tuple import requests import structlog @@ -16,7 +16,7 @@ QuestionAnswer, ) from codegate.db.connection import alert_queue -from codegate.db.models import GetAlertsWithPromptAndOutputRow, GetPromptWithOutputsRow +from codegate.db.models import Alert, GetPromptWithOutputsRow logger = structlog.get_logger("codegate") @@ -144,9 +144,7 @@ def _parse_single_output(single_output: dict) -> str: return full_output_message -async def _get_question_answer( - row: Union[GetPromptWithOutputsRow, GetAlertsWithPromptAndOutputRow] -) -> Optional[PartialQuestionAnswer]: +async def _get_question_answer(row: GetPromptWithOutputsRow) -> Optional[PartialQuestionAnswer]: """ Parse a row from the get_prompt_with_outputs query and return a PartialConversation @@ -195,6 +193,11 @@ def parse_question_answer(input_text: str) -> str: return input_text +def _clean_secrets_from_message(message: str) -> str: + pattern = re.compile(r"REDACTED<(\$?[^>]+)>") + return pattern.sub("REDACTED_SECRET", message) + + def _group_partial_messages(pq_list: List[PartialQuestions]) -> List[List[PartialQuestions]]: """ A PartialQuestion is an object that contains several user messages provided from a @@ -210,6 +213,10 @@ def _group_partial_messages(pq_list: List[PartialQuestions]) -> List[List[Partia - Leave any unpaired pq by itself. - Finally, sort the resulting groups by the earliest timestamp in each group. """ + # 0) Clean secrets from messages + for pq in pq_list: + pq.messages = [_clean_secrets_from_message(msg) for msg in pq.messages] + # 1) Sort by length of messages descending (largest/most-complete first), # then by timestamp ascending for stable processing. pq_list_sorted = sorted(pq_list, key=lambda x: (-len(x.messages), x.timestamp)) @@ -224,7 +231,7 @@ def _group_partial_messages(pq_list: List[PartialQuestions]) -> List[List[Partia # Find all potential subsets of 'sup' that are not yet used # (If sup's messages == sub's messages, that also counts, because sub ⊆ sup) - possible_subsets = [] + possible_subsets: List[PartialQuestions] = [] for sub in pq_list_sorted: if sub.message_id == sup.message_id: continue @@ -281,7 +288,7 @@ def _get_question_answer_from_partial( async def match_conversations( partial_question_answers: List[Optional[PartialQuestionAnswer]], -) -> List[Conversation]: +) -> Tuple[List[Conversation], Dict[str, Conversation]]: """ Match partial conversations to form a complete conversation. """ @@ -294,45 +301,47 @@ async def match_conversations( # Create the conversation objects conversations = [] + map_q_id_to_conversation = {} for group in grouped_partial_questions: - questions_answers = [] + questions_answers: List[QuestionAnswer] = [] first_partial_qa = None for partial_question in sorted(group, key=lambda x: x.timestamp): # Partial questions don't contain the answer, so we need to find the corresponding + # valid partial question answer selected_partial_qa = None for partial_qa in valid_partial_qas: if partial_question.message_id == partial_qa.partial_questions.message_id: selected_partial_qa = partial_qa break - # check if we have an answer, otherwise do not add it - if selected_partial_qa.answer is not None: - # if we don't have a first question, set it + # check if we have a question and answer, otherwise do not add it + if selected_partial_qa and selected_partial_qa.answer is not None: + # if we don't have a first question, set it. We will use it + # to set the conversation timestamp and provider first_partial_qa = first_partial_qa or selected_partial_qa - question_answer = _get_question_answer_from_partial(selected_partial_qa) - question_answer.question.message = parse_question_answer( - question_answer.question.message - ) - questions_answers.append(question_answer) + qa = _get_question_answer_from_partial(selected_partial_qa) + qa.question.message = parse_question_answer(qa.question.message) + questions_answers.append(qa) # only add conversation if we have some answers if len(questions_answers) > 0 and first_partial_qa is not None: - conversations.append( - Conversation( - question_answers=questions_answers, - provider=first_partial_qa.partial_questions.provider, - type=first_partial_qa.partial_questions.type, - chat_id=first_partial_qa.partial_questions.message_id, - conversation_timestamp=first_partial_qa.partial_questions.timestamp, - ) + conversation = Conversation( + question_answers=questions_answers, + provider=first_partial_qa.partial_questions.provider, + type=first_partial_qa.partial_questions.type, + chat_id=first_partial_qa.partial_questions.message_id, + conversation_timestamp=first_partial_qa.partial_questions.timestamp, ) + for qa in questions_answers: + map_q_id_to_conversation[qa.question.message_id] = conversation + conversations.append(conversation) - return conversations + return conversations, map_q_id_to_conversation async def parse_messages_in_conversations( prompts_outputs: List[GetPromptWithOutputsRow], -) -> List[Conversation]: +) -> Tuple[List[Conversation], Dict[str, Conversation]]: """ Get all the messages from the database and return them as a list of conversations. """ @@ -342,31 +351,21 @@ async def parse_messages_in_conversations( tasks = [tg.create_task(_get_question_answer(row)) for row in prompts_outputs] partial_question_answers = [task.result() for task in tasks] - conversations = await match_conversations(partial_question_answers) - return conversations + conversations, map_q_id_to_conversation = await match_conversations(partial_question_answers) + return conversations, map_q_id_to_conversation async def parse_row_alert_conversation( - row: GetAlertsWithPromptAndOutputRow, + row: Alert, map_q_id_to_conversation: Dict[str, Conversation] ) -> Optional[AlertConversation]: """ Parse a row from the get_alerts_with_prompt_and_output query and return a Conversation The row contains the raw request and output strings from the pipeline. """ - partial_qa = await _get_question_answer(row) - if not partial_qa: + conversation = map_q_id_to_conversation.get(row.prompt_id) + if conversation is None: return None - - question_answer = _get_question_answer_from_partial(partial_qa) - - conversation = Conversation( - question_answers=[question_answer], - provider=row.provider, - type=row.type, - chat_id=row.id, - conversation_timestamp=row.timestamp, - ) code_snippet = json.loads(row.code_snippet) if row.code_snippet else None trigger_string = None if row.trigger_string: @@ -387,7 +386,8 @@ async def parse_row_alert_conversation( async def parse_get_alert_conversation( - alerts_conversations: List[GetAlertsWithPromptAndOutputRow], + alerts: List[Alert], + prompts_outputs: List[GetPromptWithOutputsRow], ) -> List[AlertConversation]: """ Parse a list of rows from the get_alerts_with_prompt_and_output query and return a list of @@ -395,6 +395,10 @@ async def parse_get_alert_conversation( The rows contain the raw request and output strings from the pipeline. """ + _, map_q_id_to_conversation = await parse_messages_in_conversations(prompts_outputs) async with asyncio.TaskGroup() as tg: - tasks = [tg.create_task(parse_row_alert_conversation(row)) for row in alerts_conversations] + tasks = [ + tg.create_task(parse_row_alert_conversation(row, map_q_id_to_conversation)) + for row in alerts + ] return [task.result() for task in tasks if task.result() is not None] diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 51b47a09..9820e292 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -17,7 +17,6 @@ from codegate.db.models import ( ActiveWorkspace, Alert, - GetAlertsWithPromptAndOutputRow, GetPromptWithOutputsRow, GetWorkspaceByNameConditions, Output, @@ -430,9 +429,7 @@ async def get_prompts_with_output(self, workpace_id: str) -> List[GetPromptWithO ) return prompts - async def get_alerts_with_prompt_and_output( - self, workspace_id: str - ) -> List[GetAlertsWithPromptAndOutputRow]: + async def get_alerts_by_workspace(self, workspace_id: str) -> List[Alert]: sql = text( """ SELECT @@ -442,24 +439,16 @@ async def get_alerts_with_prompt_and_output( a.trigger_string, a.trigger_type, a.trigger_category, - a.timestamp, - p.timestamp as prompt_timestamp, - p.provider, - p.request, - p.type, - o.id as output_id, - o.output, - o.timestamp as output_timestamp + a.timestamp FROM alerts a - LEFT JOIN prompts p ON p.id = a.prompt_id - LEFT JOIN outputs o ON p.id = o.prompt_id + INNER JOIN prompts p ON p.id = a.prompt_id WHERE p.workspace_id = :workspace_id ORDER BY a.timestamp DESC """ ) conditions = {"workspace_id": workspace_id} prompts = await self._exec_select_conditions_to_pydantic( - GetAlertsWithPromptAndOutputRow, sql, conditions, should_raise=True + Alert, sql, conditions, should_raise=True ) return prompts diff --git a/src/codegate/db/models.py b/src/codegate/db/models.py index 366712b0..23cbea5d 100644 --- a/src/codegate/db/models.py +++ b/src/codegate/db/models.py @@ -67,23 +67,6 @@ class Session(BaseModel): # Models for select queries -class GetAlertsWithPromptAndOutputRow(BaseModel): - id: Any - prompt_id: Any - code_snippet: Optional[Any] - trigger_string: Optional[Any] - trigger_type: Any - trigger_category: Optional[Any] - timestamp: Any - prompt_timestamp: Optional[Any] - provider: Optional[Any] - request: Optional[Any] - type: Optional[Any] - output_id: Optional[Any] - output: Optional[Any] - output_timestamp: Optional[Any] - - class GetPromptWithOutputsRow(BaseModel): id: Any timestamp: Any