Skip to content

Commit 25a67d0

Browse files
Properly parse conversations into alerts
Closes: #420 This PR introduces: - Have full conversations in alerts. Until now we were only showing the message where the alert came from in `*/alerts` - Match the `chat_id` in the alert conversation to the ones returned in `*/messages` - Fixes a grouping issue in which conversations with secrets were not grouped correctly because of the `REDACTED<>` text.
1 parent e6372d7 commit 25a67d0

File tree

5 files changed

+189
-77
lines changed

5 files changed

+189
-77
lines changed
+134
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
import asyncio
2+
from typing import AsyncGenerator, List, Optional
3+
4+
import requests
5+
import structlog
6+
from fastapi import APIRouter, Depends, HTTPException
7+
from fastapi.responses import StreamingResponse
8+
from fastapi.routing import APIRoute
9+
10+
from codegate import __version__
11+
from codegate.api.dashboard.post_processing import (
12+
parse_get_alert_conversation,
13+
parse_messages_in_conversations,
14+
)
15+
from codegate.api.dashboard.request_models import AlertConversation, Conversation
16+
from codegate.db.connection import DbReader, alert_queue
17+
from codegate.workspaces import crud
18+
19+
logger = structlog.get_logger("codegate")
20+
21+
dashboard_router = APIRouter()
22+
db_reader = None
23+
24+
wscrud = crud.WorkspaceCrud()
25+
26+
27+
def uniq_name(route: APIRoute):
28+
return f"v1_{route.name}"
29+
30+
31+
def get_db_reader():
32+
global db_reader
33+
if db_reader is None:
34+
db_reader = DbReader()
35+
return db_reader
36+
37+
38+
def fetch_latest_version() -> str:
39+
url = "https://api.github.com/repos/stacklok/codegate/releases/latest"
40+
headers = {"Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28"}
41+
response = requests.get(url, headers=headers, timeout=5)
42+
response.raise_for_status()
43+
data = response.json()
44+
return data.get("tag_name", "unknown")
45+
46+
47+
@dashboard_router.get(
48+
"/dashboard/messages", tags=["Dashboard"], generate_unique_id_function=uniq_name
49+
)
50+
def get_messages(db_reader: DbReader = Depends(get_db_reader)) -> List[Conversation]:
51+
"""
52+
Get all the messages from the database and return them as a list of conversations.
53+
"""
54+
try:
55+
active_ws = asyncio.run(wscrud.get_active_workspace())
56+
prompts_outputs = asyncio.run(db_reader.get_prompts_with_output(active_ws.id))
57+
conversations, _ = asyncio.run(parse_messages_in_conversations(prompts_outputs))
58+
return conversations
59+
except Exception as e:
60+
logger.error(f"Error getting messages: {str(e)}")
61+
raise HTTPException(status_code=500, detail="Internal server error")
62+
63+
64+
@dashboard_router.get(
65+
"/dashboard/alerts", tags=["Dashboard"], generate_unique_id_function=uniq_name
66+
)
67+
def get_alerts(db_reader: DbReader = Depends(get_db_reader)) -> List[Optional[AlertConversation]]:
68+
"""
69+
Get all the messages from the database and return them as a list of conversations.
70+
"""
71+
try:
72+
active_ws = asyncio.run(wscrud.get_active_workspace())
73+
prompts_outputs = asyncio.run(db_reader.get_prompts_with_output(active_ws.id))
74+
alerts = asyncio.run(db_reader.get_alerts_by_workspace(active_ws.id))
75+
return asyncio.run(parse_get_alert_conversation(alerts, prompts_outputs))
76+
except Exception as e:
77+
logger.error(f"Error getting alerts: {str(e)}")
78+
raise HTTPException(status_code=500, detail="Internal server error")
79+
80+
81+
async def generate_sse_events() -> AsyncGenerator[str, None]:
82+
"""
83+
SSE generator from queue
84+
"""
85+
while True:
86+
message = await alert_queue.get()
87+
yield f"data: {message}\n\n"
88+
89+
90+
@dashboard_router.get(
91+
"/dashboard/alerts_notification", tags=["Dashboard"], generate_unique_id_function=uniq_name
92+
)
93+
async def stream_sse():
94+
"""
95+
Send alerts event
96+
"""
97+
return StreamingResponse(generate_sse_events(), media_type="text/event-stream")
98+
99+
100+
@dashboard_router.get(
101+
"/dashboard/version", tags=["Dashboard"], generate_unique_id_function=uniq_name
102+
)
103+
def version_check():
104+
try:
105+
latest_version = fetch_latest_version()
106+
107+
# normalize the versions as github will return them with a 'v' prefix
108+
current_version = __version__.lstrip("v")
109+
latest_version_stripped = latest_version.lstrip("v")
110+
111+
is_latest: bool = latest_version_stripped == current_version
112+
113+
return {
114+
"current_version": current_version,
115+
"latest_version": latest_version_stripped,
116+
"is_latest": is_latest,
117+
"error": None,
118+
}
119+
except requests.RequestException as e:
120+
logger.error(f"RequestException: {str(e)}")
121+
return {
122+
"current_version": __version__,
123+
"latest_version": "unknown",
124+
"is_latest": None,
125+
"error": "An error occurred while fetching the latest version",
126+
}
127+
except Exception as e:
128+
logger.error(f"Unexpected error: {str(e)}")
129+
return {
130+
"current_version": __version__,
131+
"latest_version": "unknown",
132+
"is_latest": None,
133+
"error": "An unexpected error occurred",
134+
}

src/codegate/api/v1.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,8 @@ async def get_workspace_alerts(workspace_name: str) -> List[Optional[v1_models.A
206206

207207
try:
208208
alerts = await dbreader.get_alerts_with_prompt_and_output(ws.id)
209-
return await v1_processing.parse_get_alert_conversation(alerts)
209+
prompts_outputs = await dbreader.get_prompts_with_output(ws.id)
210+
return await v1_processing.parse_get_alert_conversation(alerts, prompts_outputs)
210211
except Exception:
211212
raise HTTPException(status_code=500, detail="Internal server error")
212213

@@ -227,7 +228,8 @@ async def get_workspace_messages(workspace_name: str) -> List[v1_models.Conversa
227228

228229
try:
229230
prompts_outputs = await dbreader.get_prompts_with_output(ws.id)
230-
return await v1_processing.parse_messages_in_conversations(prompts_outputs)
231+
conversations, _ = await v1_processing.parse_messages_in_conversations(prompts_outputs)
232+
return conversations
231233
except Exception:
232234
raise HTTPException(status_code=500, detail="Internal server error")
233235

src/codegate/api/v1_processing.py

+47-43
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import json
33
import re
44
from collections import defaultdict
5-
from typing import AsyncGenerator, List, Optional, Union
5+
from typing import AsyncGenerator, Dict, List, Optional, Tuple
66

77
import requests
88
import structlog
@@ -16,7 +16,7 @@
1616
QuestionAnswer,
1717
)
1818
from codegate.db.connection import alert_queue
19-
from codegate.db.models import GetAlertsWithPromptAndOutputRow, GetPromptWithOutputsRow
19+
from codegate.db.models import Alert, GetPromptWithOutputsRow
2020

2121
logger = structlog.get_logger("codegate")
2222

@@ -144,9 +144,7 @@ def _parse_single_output(single_output: dict) -> str:
144144
return full_output_message
145145

146146

147-
async def _get_question_answer(
148-
row: Union[GetPromptWithOutputsRow, GetAlertsWithPromptAndOutputRow]
149-
) -> Optional[PartialQuestionAnswer]:
147+
async def _get_question_answer(row: GetPromptWithOutputsRow) -> Optional[PartialQuestionAnswer]:
150148
"""
151149
Parse a row from the get_prompt_with_outputs query and return a PartialConversation
152150
@@ -195,6 +193,11 @@ def parse_question_answer(input_text: str) -> str:
195193
return input_text
196194

197195

196+
def _clean_secrets_from_message(message: str) -> str:
197+
pattern = re.compile(r"REDACTED<(\$?[^>]+)>")
198+
return pattern.sub("REDACTED_SECRET", message)
199+
200+
198201
def _group_partial_messages(pq_list: List[PartialQuestions]) -> List[List[PartialQuestions]]:
199202
"""
200203
A PartialQuestion is an object that contains several user messages provided from a
@@ -210,6 +213,10 @@ def _group_partial_messages(pq_list: List[PartialQuestions]) -> List[List[Partia
210213
- Leave any unpaired pq by itself.
211214
- Finally, sort the resulting groups by the earliest timestamp in each group.
212215
"""
216+
# 0) Clean secrets from messages
217+
for pq in pq_list:
218+
pq.messages = [_clean_secrets_from_message(msg) for msg in pq.messages]
219+
213220
# 1) Sort by length of messages descending (largest/most-complete first),
214221
# then by timestamp ascending for stable processing.
215222
pq_list_sorted = sorted(pq_list, key=lambda x: (-len(x.messages), x.timestamp))
@@ -224,7 +231,7 @@ def _group_partial_messages(pq_list: List[PartialQuestions]) -> List[List[Partia
224231

225232
# Find all potential subsets of 'sup' that are not yet used
226233
# (If sup's messages == sub's messages, that also counts, because sub ⊆ sup)
227-
possible_subsets = []
234+
possible_subsets: List[PartialQuestions] = []
228235
for sub in pq_list_sorted:
229236
if sub.message_id == sup.message_id:
230237
continue
@@ -281,7 +288,7 @@ def _get_question_answer_from_partial(
281288

282289
async def match_conversations(
283290
partial_question_answers: List[Optional[PartialQuestionAnswer]],
284-
) -> List[Conversation]:
291+
) -> Tuple[List[Conversation], Dict[str, Conversation]]:
285292
"""
286293
Match partial conversations to form a complete conversation.
287294
"""
@@ -294,45 +301,47 @@ async def match_conversations(
294301

295302
# Create the conversation objects
296303
conversations = []
304+
map_q_id_to_conversation = {}
297305
for group in grouped_partial_questions:
298-
questions_answers = []
306+
questions_answers: List[QuestionAnswer] = []
299307
first_partial_qa = None
300308
for partial_question in sorted(group, key=lambda x: x.timestamp):
301309
# Partial questions don't contain the answer, so we need to find the corresponding
310+
# valid partial question answer
302311
selected_partial_qa = None
303312
for partial_qa in valid_partial_qas:
304313
if partial_question.message_id == partial_qa.partial_questions.message_id:
305314
selected_partial_qa = partial_qa
306315
break
307316

308-
# check if we have an answer, otherwise do not add it
309-
if selected_partial_qa.answer is not None:
310-
# if we don't have a first question, set it
317+
# check if we have a question and answer, otherwise do not add it
318+
if selected_partial_qa and selected_partial_qa.answer is not None:
319+
# if we don't have a first question, set it. We will use it
320+
# to set the conversation timestamp and provider
311321
first_partial_qa = first_partial_qa or selected_partial_qa
312-
question_answer = _get_question_answer_from_partial(selected_partial_qa)
313-
question_answer.question.message = parse_question_answer(
314-
question_answer.question.message
315-
)
316-
questions_answers.append(question_answer)
322+
qa = _get_question_answer_from_partial(selected_partial_qa)
323+
qa.question.message = parse_question_answer(qa.question.message)
324+
questions_answers.append(qa)
317325

318326
# only add conversation if we have some answers
319327
if len(questions_answers) > 0 and first_partial_qa is not None:
320-
conversations.append(
321-
Conversation(
322-
question_answers=questions_answers,
323-
provider=first_partial_qa.partial_questions.provider,
324-
type=first_partial_qa.partial_questions.type,
325-
chat_id=first_partial_qa.partial_questions.message_id,
326-
conversation_timestamp=first_partial_qa.partial_questions.timestamp,
327-
)
328+
conversation = Conversation(
329+
question_answers=questions_answers,
330+
provider=first_partial_qa.partial_questions.provider,
331+
type=first_partial_qa.partial_questions.type,
332+
chat_id=first_partial_qa.partial_questions.message_id,
333+
conversation_timestamp=first_partial_qa.partial_questions.timestamp,
328334
)
335+
for qa in questions_answers:
336+
map_q_id_to_conversation[qa.question.message_id] = conversation
337+
conversations.append(conversation)
329338

330-
return conversations
339+
return conversations, map_q_id_to_conversation
331340

332341

333342
async def parse_messages_in_conversations(
334343
prompts_outputs: List[GetPromptWithOutputsRow],
335-
) -> List[Conversation]:
344+
) -> Tuple[List[Conversation], Dict[str, Conversation]]:
336345
"""
337346
Get all the messages from the database and return them as a list of conversations.
338347
"""
@@ -342,31 +351,21 @@ async def parse_messages_in_conversations(
342351
tasks = [tg.create_task(_get_question_answer(row)) for row in prompts_outputs]
343352
partial_question_answers = [task.result() for task in tasks]
344353

345-
conversations = await match_conversations(partial_question_answers)
346-
return conversations
354+
conversations, map_q_id_to_conversation = await match_conversations(partial_question_answers)
355+
return conversations, map_q_id_to_conversation
347356

348357

349358
async def parse_row_alert_conversation(
350-
row: GetAlertsWithPromptAndOutputRow,
359+
row: Alert, map_q_id_to_conversation: Dict[str, Conversation]
351360
) -> Optional[AlertConversation]:
352361
"""
353362
Parse a row from the get_alerts_with_prompt_and_output query and return a Conversation
354363
355364
The row contains the raw request and output strings from the pipeline.
356365
"""
357-
partial_qa = await _get_question_answer(row)
358-
if not partial_qa:
366+
conversation = map_q_id_to_conversation.get(row.prompt_id)
367+
if conversation is None:
359368
return None
360-
361-
question_answer = _get_question_answer_from_partial(partial_qa)
362-
363-
conversation = Conversation(
364-
question_answers=[question_answer],
365-
provider=row.provider,
366-
type=row.type,
367-
chat_id=row.id,
368-
conversation_timestamp=row.timestamp,
369-
)
370369
code_snippet = json.loads(row.code_snippet) if row.code_snippet else None
371370
trigger_string = None
372371
if row.trigger_string:
@@ -387,14 +386,19 @@ async def parse_row_alert_conversation(
387386

388387

389388
async def parse_get_alert_conversation(
390-
alerts_conversations: List[GetAlertsWithPromptAndOutputRow],
389+
alerts: List[Alert],
390+
prompts_outputs: List[GetPromptWithOutputsRow],
391391
) -> List[AlertConversation]:
392392
"""
393393
Parse a list of rows from the get_alerts_with_prompt_and_output query and return a list of
394394
AlertConversation
395395
396396
The rows contain the raw request and output strings from the pipeline.
397397
"""
398+
_, map_q_id_to_conversation = await parse_messages_in_conversations(prompts_outputs)
398399
async with asyncio.TaskGroup() as tg:
399-
tasks = [tg.create_task(parse_row_alert_conversation(row)) for row in alerts_conversations]
400+
tasks = [
401+
tg.create_task(parse_row_alert_conversation(row, map_q_id_to_conversation))
402+
for row in alerts
403+
]
400404
return [task.result() for task in tasks if task.result() is not None]

src/codegate/db/connection.py

+4-15
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from codegate.db.models import (
1818
ActiveWorkspace,
1919
Alert,
20-
GetAlertsWithPromptAndOutputRow,
2120
GetPromptWithOutputsRow,
2221
GetWorkspaceByNameConditions,
2322
Output,
@@ -430,9 +429,7 @@ async def get_prompts_with_output(self, workpace_id: str) -> List[GetPromptWithO
430429
)
431430
return prompts
432431

433-
async def get_alerts_with_prompt_and_output(
434-
self, workspace_id: str
435-
) -> List[GetAlertsWithPromptAndOutputRow]:
432+
async def get_alerts_by_workspace(self, workspace_id: str) -> List[Alert]:
436433
sql = text(
437434
"""
438435
SELECT
@@ -442,24 +439,16 @@ async def get_alerts_with_prompt_and_output(
442439
a.trigger_string,
443440
a.trigger_type,
444441
a.trigger_category,
445-
a.timestamp,
446-
p.timestamp as prompt_timestamp,
447-
p.provider,
448-
p.request,
449-
p.type,
450-
o.id as output_id,
451-
o.output,
452-
o.timestamp as output_timestamp
442+
a.timestamp
453443
FROM alerts a
454-
LEFT JOIN prompts p ON p.id = a.prompt_id
455-
LEFT JOIN outputs o ON p.id = o.prompt_id
444+
INNER JOIN prompts p ON p.id = a.prompt_id
456445
WHERE p.workspace_id = :workspace_id
457446
ORDER BY a.timestamp DESC
458447
"""
459448
)
460449
conditions = {"workspace_id": workspace_id}
461450
prompts = await self._exec_select_conditions_to_pydantic(
462-
GetAlertsWithPromptAndOutputRow, sql, conditions, should_raise=True
451+
Alert, sql, conditions, should_raise=True
463452
)
464453
return prompts
465454

0 commit comments

Comments
 (0)