Skip to content

Commit c0385ec

Browse files
Properly parse conversations into alerts
Closes: #420. This PR: - Shows full conversations in alerts. Until now, `*/alerts` only showed the single message that the alert came from. - Matches the `chat_id` of each alert conversation to the one returned by `*/messages`. - Fixes a grouping issue in which conversations containing secrets were not grouped correctly because of the `REDACTED<>` text.
1 parent 9fa5ce2 commit c0385ec

File tree

5 files changed

+61
-84
lines changed

5 files changed

+61
-84
lines changed

src/codegate/api/dashboard/dashboard.py

+5-6
Original file line number | Diff line number | Diff line change
@@ -54,8 +54,8 @@ def get_messages(db_reader: DbReader = Depends(get_db_reader)) -> List[Conversat
5454
try:
5555
active_ws = asyncio.run(wscrud.get_active_workspace())
5656
prompts_outputs = asyncio.run(db_reader.get_prompts_with_output(active_ws.id))
57-
58-
return asyncio.run(parse_messages_in_conversations(prompts_outputs))
57+
conversations, _ = asyncio.run(parse_messages_in_conversations(prompts_outputs))
58+
return conversations
5959
except Exception as e:
6060
logger.error(f"Error getting messages: {str(e)}")
6161
raise HTTPException(status_code=500, detail="Internal server error")
@@ -70,10 +70,9 @@ def get_alerts(db_reader: DbReader = Depends(get_db_reader)) -> List[Optional[Al
7070
"""
7171
try:
7272
active_ws = asyncio.run(wscrud.get_active_workspace())
73-
alerts_prompt_output = asyncio.run(
74-
db_reader.get_alerts_with_prompt_and_output(active_ws.id)
75-
)
76-
return asyncio.run(parse_get_alert_conversation(alerts_prompt_output))
73+
prompts_outputs = asyncio.run(db_reader.get_prompts_with_output(active_ws.id))
74+
alerts = asyncio.run(db_reader.get_alerts_by_workspace(active_ws.id))
75+
return asyncio.run(parse_get_alert_conversation(alerts, prompts_outputs))
7776
except Exception as e:
7877
logger.error(f"Error getting alerts: {str(e)}")
7978
raise HTTPException(status_code=500, detail="Internal server error")

src/codegate/api/dashboard/post_processing.py

+47-43
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
import json
33
import re
44
from collections import defaultdict
5-
from typing import List, Optional, Union
5+
from typing import Dict, List, Optional, Tuple
66

77
import structlog
88

@@ -14,7 +14,7 @@
1414
PartialQuestions,
1515
QuestionAnswer,
1616
)
17-
from codegate.db.models import GetAlertsWithPromptAndOutputRow, GetPromptWithOutputsRow
17+
from codegate.db.models import Alert, GetPromptWithOutputsRow
1818

1919
logger = structlog.get_logger("codegate")
2020

@@ -124,9 +124,7 @@ def _parse_single_output(single_output: dict) -> str:
124124
return full_output_message
125125

126126

127-
async def _get_question_answer(
128-
row: Union[GetPromptWithOutputsRow, GetAlertsWithPromptAndOutputRow]
129-
) -> Optional[PartialQuestionAnswer]:
127+
async def _get_question_answer(row: GetPromptWithOutputsRow) -> Optional[PartialQuestionAnswer]:
130128
"""
131129
Parse a row from the get_prompt_with_outputs query and return a PartialConversation
132130
@@ -175,6 +173,11 @@ def parse_question_answer(input_text: str) -> str:
175173
return input_text
176174

177175

176+
def _clean_secrets_from_message(message: str) -> str:
177+
pattern = re.compile(r"REDACTED<(\$?[^>]+)>")
178+
return pattern.sub("REDACTED_SECRET", message)
179+
180+
178181
def _group_partial_messages(pq_list: List[PartialQuestions]) -> List[List[PartialQuestions]]:
179182
"""
180183
A PartialQuestion is an object that contains several user messages provided from a
@@ -190,6 +193,10 @@ def _group_partial_messages(pq_list: List[PartialQuestions]) -> List[List[Partia
190193
- Leave any unpaired pq by itself.
191194
- Finally, sort the resulting groups by the earliest timestamp in each group.
192195
"""
196+
# 0) Clean secrets from messages
197+
for pq in pq_list:
198+
pq.messages = [_clean_secrets_from_message(msg) for msg in pq.messages]
199+
193200
# 1) Sort by length of messages descending (largest/most-complete first),
194201
# then by timestamp ascending for stable processing.
195202
pq_list_sorted = sorted(pq_list, key=lambda x: (-len(x.messages), x.timestamp))
@@ -204,7 +211,7 @@ def _group_partial_messages(pq_list: List[PartialQuestions]) -> List[List[Partia
204211

205212
# Find all potential subsets of 'sup' that are not yet used
206213
# (If sup's messages == sub's messages, that also counts, because sub ⊆ sup)
207-
possible_subsets = []
214+
possible_subsets: List[PartialQuestions] = []
208215
for sub in pq_list_sorted:
209216
if sub.message_id == sup.message_id:
210217
continue
@@ -261,7 +268,7 @@ def _get_question_answer_from_partial(
261268

262269
async def match_conversations(
263270
partial_question_answers: List[Optional[PartialQuestionAnswer]],
264-
) -> List[Conversation]:
271+
) -> Tuple[List[Conversation], Dict[str, Conversation]]:
265272
"""
266273
Match partial conversations to form a complete conversation.
267274
"""
@@ -274,45 +281,47 @@ async def match_conversations(
274281

275282
# Create the conversation objects
276283
conversations = []
284+
map_q_id_to_conversation = {}
277285
for group in grouped_partial_questions:
278-
questions_answers = []
286+
questions_answers: List[QuestionAnswer] = []
279287
first_partial_qa = None
280288
for partial_question in sorted(group, key=lambda x: x.timestamp):
281289
# Partial questions don't contain the answer, so we need to find the corresponding
290+
# valid partial question answer
282291
selected_partial_qa = None
283292
for partial_qa in valid_partial_qas:
284293
if partial_question.message_id == partial_qa.partial_questions.message_id:
285294
selected_partial_qa = partial_qa
286295
break
287296

288-
# check if we have an answer, otherwise do not add it
289-
if selected_partial_qa.answer is not None:
290-
# if we don't have a first question, set it
297+
# check if we have a question and answer, otherwise do not add it
298+
if selected_partial_qa and selected_partial_qa.answer is not None:
299+
# if we don't have a first question, set it. We will use it
300+
# to set the conversation timestamp and provider
291301
first_partial_qa = first_partial_qa or selected_partial_qa
292-
question_answer = _get_question_answer_from_partial(selected_partial_qa)
293-
question_answer.question.message = parse_question_answer(
294-
question_answer.question.message
295-
)
296-
questions_answers.append(question_answer)
302+
qa = _get_question_answer_from_partial(selected_partial_qa)
303+
qa.question.message = parse_question_answer(qa.question.message)
304+
questions_answers.append(qa)
297305

298306
# only add conversation if we have some answers
299307
if len(questions_answers) > 0 and first_partial_qa is not None:
300-
conversations.append(
301-
Conversation(
302-
question_answers=questions_answers,
303-
provider=first_partial_qa.partial_questions.provider,
304-
type=first_partial_qa.partial_questions.type,
305-
chat_id=first_partial_qa.partial_questions.message_id,
306-
conversation_timestamp=first_partial_qa.partial_questions.timestamp,
307-
)
308+
conversation = Conversation(
309+
question_answers=questions_answers,
310+
provider=first_partial_qa.partial_questions.provider,
311+
type=first_partial_qa.partial_questions.type,
312+
chat_id=first_partial_qa.partial_questions.message_id,
313+
conversation_timestamp=first_partial_qa.partial_questions.timestamp,
308314
)
315+
for qa in questions_answers:
316+
map_q_id_to_conversation[qa.question.message_id] = conversation
317+
conversations.append(conversation)
309318

310-
return conversations
319+
return conversations, map_q_id_to_conversation
311320

312321

313322
async def parse_messages_in_conversations(
314323
prompts_outputs: List[GetPromptWithOutputsRow],
315-
) -> List[Conversation]:
324+
) -> Tuple[List[Conversation], Dict[str, Conversation]]:
316325
"""
317326
Get all the messages from the database and return them as a list of conversations.
318327
"""
@@ -322,31 +331,21 @@ async def parse_messages_in_conversations(
322331
tasks = [tg.create_task(_get_question_answer(row)) for row in prompts_outputs]
323332
partial_question_answers = [task.result() for task in tasks]
324333

325-
conversations = await match_conversations(partial_question_answers)
326-
return conversations
334+
conversations, map_q_id_to_conversation = await match_conversations(partial_question_answers)
335+
return conversations, map_q_id_to_conversation
327336

328337

329338
async def parse_row_alert_conversation(
330-
row: GetAlertsWithPromptAndOutputRow,
339+
row: Alert, map_q_id_to_conversation: Dict[str, Conversation]
331340
) -> Optional[AlertConversation]:
332341
"""
333342
Parse a row from the get_alerts_with_prompt_and_output query and return a Conversation
334343
335344
The row contains the raw request and output strings from the pipeline.
336345
"""
337-
partial_qa = await _get_question_answer(row)
338-
if not partial_qa:
346+
conversation = map_q_id_to_conversation.get(row.prompt_id)
347+
if conversation is None:
339348
return None
340-
341-
question_answer = _get_question_answer_from_partial(partial_qa)
342-
343-
conversation = Conversation(
344-
question_answers=[question_answer],
345-
provider=row.provider,
346-
type=row.type,
347-
chat_id=row.id,
348-
conversation_timestamp=row.timestamp,
349-
)
350349
code_snippet = json.loads(row.code_snippet) if row.code_snippet else None
351350
trigger_string = None
352351
if row.trigger_string:
@@ -367,14 +366,19 @@ async def parse_row_alert_conversation(
367366

368367

369368
async def parse_get_alert_conversation(
370-
alerts_conversations: List[GetAlertsWithPromptAndOutputRow],
369+
alerts: List[Alert],
370+
prompts_outputs: List[GetPromptWithOutputsRow],
371371
) -> List[AlertConversation]:
372372
"""
373373
Parse a list of rows from the get_alerts_with_prompt_and_output query and return a list of
374374
AlertConversation
375375
376376
The rows contain the raw request and output strings from the pipeline.
377377
"""
378+
_, map_q_id_to_conversation = await parse_messages_in_conversations(prompts_outputs)
378379
async with asyncio.TaskGroup() as tg:
379-
tasks = [tg.create_task(parse_row_alert_conversation(row)) for row in alerts_conversations]
380+
tasks = [
381+
tg.create_task(parse_row_alert_conversation(row, map_q_id_to_conversation))
382+
for row in alerts
383+
]
380384
return [task.result() for task in tasks if task.result() is not None]

src/codegate/api/v1.py

+5-3
Original file line number | Diff line number | Diff line change
@@ -202,8 +202,9 @@ async def get_workspace_alerts(workspace_name: str) -> List[Optional[AlertConver
202202
raise HTTPException(status_code=500, detail="Internal server error")
203203

204204
try:
205-
alerts = await dbreader.get_alerts_with_prompt_and_output(ws.id)
206-
return await dashboard.parse_get_alert_conversation(alerts)
205+
prompts_outputs = await dbreader.get_prompts_with_output(ws.id)
206+
alerts = await dbreader.get_alerts_by_workspace(ws.id)
207+
return await dashboard.parse_get_alert_conversation(alerts, prompts_outputs)
207208
except Exception:
208209
raise HTTPException(status_code=500, detail="Internal server error")
209210

@@ -224,7 +225,8 @@ async def get_workspace_messages(workspace_name: str) -> List[Conversation]:
224225

225226
try:
226227
prompts_outputs = await dbreader.get_prompts_with_output(ws.id)
227-
return await dashboard.parse_messages_in_conversations(prompts_outputs)
228+
conversations, _ = await dashboard.parse_messages_in_conversations(prompts_outputs)
229+
return conversations
228230
except Exception:
229231
raise HTTPException(status_code=500, detail="Internal server error")
230232

src/codegate/db/connection.py

+4-15
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,6 @@
1717
from codegate.db.models import (
1818
ActiveWorkspace,
1919
Alert,
20-
GetAlertsWithPromptAndOutputRow,
2120
GetPromptWithOutputsRow,
2221
GetWorkspaceByNameConditions,
2322
Output,
@@ -430,9 +429,7 @@ async def get_prompts_with_output(self, workpace_id: str) -> List[GetPromptWithO
430429
)
431430
return prompts
432431

433-
async def get_alerts_with_prompt_and_output(
434-
self, workspace_id: str
435-
) -> List[GetAlertsWithPromptAndOutputRow]:
432+
async def get_alerts_by_workspace(self, workspace_id: str) -> List[Alert]:
436433
sql = text(
437434
"""
438435
SELECT
@@ -442,24 +439,16 @@ async def get_alerts_with_prompt_and_output(
442439
a.trigger_string,
443440
a.trigger_type,
444441
a.trigger_category,
445-
a.timestamp,
446-
p.timestamp as prompt_timestamp,
447-
p.provider,
448-
p.request,
449-
p.type,
450-
o.id as output_id,
451-
o.output,
452-
o.timestamp as output_timestamp
442+
a.timestamp
453443
FROM alerts a
454-
LEFT JOIN prompts p ON p.id = a.prompt_id
455-
LEFT JOIN outputs o ON p.id = o.prompt_id
444+
INNER JOIN prompts p ON p.id = a.prompt_id
456445
WHERE p.workspace_id = :workspace_id
457446
ORDER BY a.timestamp DESC
458447
"""
459448
)
460449
conditions = {"workspace_id": workspace_id}
461450
prompts = await self._exec_select_conditions_to_pydantic(
462-
GetAlertsWithPromptAndOutputRow, sql, conditions, should_raise=True
451+
Alert, sql, conditions, should_raise=True
463452
)
464453
return prompts
465454

src/codegate/db/models.py

-17
Original file line number | Diff line number | Diff line change
@@ -67,23 +67,6 @@ class Session(BaseModel):
6767
# Models for select queries
6868

6969

70-
class GetAlertsWithPromptAndOutputRow(BaseModel):
71-
id: Any
72-
prompt_id: Any
73-
code_snippet: Optional[Any]
74-
trigger_string: Optional[Any]
75-
trigger_type: Any
76-
trigger_category: Optional[Any]
77-
timestamp: Any
78-
prompt_timestamp: Optional[Any]
79-
provider: Optional[Any]
80-
request: Optional[Any]
81-
type: Optional[Any]
82-
output_id: Optional[Any]
83-
output: Optional[Any]
84-
output_timestamp: Optional[Any]
85-
86-
8770
class GetPromptWithOutputsRow(BaseModel):
8871
id: Any
8972
timestamp: Any

0 commit comments

Comments
 (0)