2
2
import json
3
3
import re
4
4
from collections import defaultdict
5
- from typing import AsyncGenerator , List , Optional , Union
5
+ from typing import AsyncGenerator , Dict , List , Optional , Tuple
6
6
7
7
import requests
8
8
import structlog
16
16
QuestionAnswer ,
17
17
)
18
18
from codegate .db .connection import alert_queue
19
- from codegate .db .models import GetAlertsWithPromptAndOutputRow , GetPromptWithOutputsRow
19
+ from codegate .db .models import Alert , GetPromptWithOutputsRow
20
20
21
21
logger = structlog .get_logger ("codegate" )
22
22
@@ -144,9 +144,7 @@ def _parse_single_output(single_output: dict) -> str:
144
144
return full_output_message
145
145
146
146
147
- async def _get_question_answer (
148
- row : Union [GetPromptWithOutputsRow , GetAlertsWithPromptAndOutputRow ]
149
- ) -> Optional [PartialQuestionAnswer ]:
147
+ async def _get_question_answer (row : GetPromptWithOutputsRow ) -> Optional [PartialQuestionAnswer ]:
150
148
"""
151
149
Parse a row from the get_prompt_with_outputs query and return a PartialConversation
152
150
@@ -195,6 +193,11 @@ def parse_question_answer(input_text: str) -> str:
195
193
return input_text
196
194
197
195
196
+ def _clean_secrets_from_message (message : str ) -> str :
197
+ pattern = re .compile (r"REDACTED<(\$?[^>]+)>" )
198
+ return pattern .sub ("REDACTED_SECRET" , message )
199
+
200
+
198
201
def _group_partial_messages (pq_list : List [PartialQuestions ]) -> List [List [PartialQuestions ]]:
199
202
"""
200
203
A PartialQuestion is an object that contains several user messages provided from a
@@ -210,6 +213,10 @@ def _group_partial_messages(pq_list: List[PartialQuestions]) -> List[List[Partia
210
213
- Leave any unpaired pq by itself.
211
214
- Finally, sort the resulting groups by the earliest timestamp in each group.
212
215
"""
216
+ # 0) Clean secrets from messages
217
+ for pq in pq_list :
218
+ pq .messages = [_clean_secrets_from_message (msg ) for msg in pq .messages ]
219
+
213
220
# 1) Sort by length of messages descending (largest/most-complete first),
214
221
# then by timestamp ascending for stable processing.
215
222
pq_list_sorted = sorted (pq_list , key = lambda x : (- len (x .messages ), x .timestamp ))
@@ -224,7 +231,7 @@ def _group_partial_messages(pq_list: List[PartialQuestions]) -> List[List[Partia
224
231
225
232
# Find all potential subsets of 'sup' that are not yet used
226
233
# (If sup's messages == sub's messages, that also counts, because sub ⊆ sup)
227
- possible_subsets = []
234
+ possible_subsets : List [ PartialQuestions ] = []
228
235
for sub in pq_list_sorted :
229
236
if sub .message_id == sup .message_id :
230
237
continue
@@ -281,7 +288,7 @@ def _get_question_answer_from_partial(
281
288
282
289
async def match_conversations (
283
290
partial_question_answers : List [Optional [PartialQuestionAnswer ]],
284
- ) -> List [Conversation ]:
291
+ ) -> Tuple [ List [Conversation ], Dict [ str , Conversation ] ]:
285
292
"""
286
293
Match partial conversations to form a complete conversation.
287
294
"""
@@ -294,45 +301,47 @@ async def match_conversations(
294
301
295
302
# Create the conversation objects
296
303
conversations = []
304
+ map_q_id_to_conversation = {}
297
305
for group in grouped_partial_questions :
298
- questions_answers = []
306
+ questions_answers : List [ QuestionAnswer ] = []
299
307
first_partial_qa = None
300
308
for partial_question in sorted (group , key = lambda x : x .timestamp ):
301
309
# Partial questions don't contain the answer, so we need to find the corresponding
310
+ # valid partial question answer
302
311
selected_partial_qa = None
303
312
for partial_qa in valid_partial_qas :
304
313
if partial_question .message_id == partial_qa .partial_questions .message_id :
305
314
selected_partial_qa = partial_qa
306
315
break
307
316
308
- # check if we have an answer, otherwise do not add it
309
- if selected_partial_qa .answer is not None :
310
- # if we don't have a first question, set it
317
+ # check if we have a question and answer, otherwise do not add it
318
+ if selected_partial_qa and selected_partial_qa .answer is not None :
319
+ # if we don't have a first question, set it. We will use it
320
+ # to set the conversation timestamp and provider
311
321
first_partial_qa = first_partial_qa or selected_partial_qa
312
- question_answer = _get_question_answer_from_partial (selected_partial_qa )
313
- question_answer .question .message = parse_question_answer (
314
- question_answer .question .message
315
- )
316
- questions_answers .append (question_answer )
322
+ qa = _get_question_answer_from_partial (selected_partial_qa )
323
+ qa .question .message = parse_question_answer (qa .question .message )
324
+ questions_answers .append (qa )
317
325
318
326
# only add conversation if we have some answers
319
327
if len (questions_answers ) > 0 and first_partial_qa is not None :
320
- conversations .append (
321
- Conversation (
322
- question_answers = questions_answers ,
323
- provider = first_partial_qa .partial_questions .provider ,
324
- type = first_partial_qa .partial_questions .type ,
325
- chat_id = first_partial_qa .partial_questions .message_id ,
326
- conversation_timestamp = first_partial_qa .partial_questions .timestamp ,
327
- )
328
+ conversation = Conversation (
329
+ question_answers = questions_answers ,
330
+ provider = first_partial_qa .partial_questions .provider ,
331
+ type = first_partial_qa .partial_questions .type ,
332
+ chat_id = first_partial_qa .partial_questions .message_id ,
333
+ conversation_timestamp = first_partial_qa .partial_questions .timestamp ,
328
334
)
335
+ for qa in questions_answers :
336
+ map_q_id_to_conversation [qa .question .message_id ] = conversation
337
+ conversations .append (conversation )
329
338
330
- return conversations
339
+ return conversations , map_q_id_to_conversation
331
340
332
341
333
342
async def parse_messages_in_conversations (
334
343
prompts_outputs : List [GetPromptWithOutputsRow ],
335
- ) -> List [Conversation ]:
344
+ ) -> Tuple [ List [Conversation ], Dict [ str , Conversation ] ]:
336
345
"""
337
346
Get all the messages from the database and return them as a list of conversations.
338
347
"""
@@ -342,31 +351,21 @@ async def parse_messages_in_conversations(
342
351
tasks = [tg .create_task (_get_question_answer (row )) for row in prompts_outputs ]
343
352
partial_question_answers = [task .result () for task in tasks ]
344
353
345
- conversations = await match_conversations (partial_question_answers )
346
- return conversations
354
+ conversations , map_q_id_to_conversation = await match_conversations (partial_question_answers )
355
+ return conversations , map_q_id_to_conversation
347
356
348
357
349
358
async def parse_row_alert_conversation (
350
- row : GetAlertsWithPromptAndOutputRow ,
359
+ row : Alert , map_q_id_to_conversation : Dict [ str , Conversation ]
351
360
) -> Optional [AlertConversation ]:
352
361
"""
353
362
Parse a row from the get_alerts_with_prompt_and_output query and return a Conversation
354
363
355
364
The row contains the raw request and output strings from the pipeline.
356
365
"""
357
- partial_qa = await _get_question_answer (row )
358
- if not partial_qa :
366
+ conversation = map_q_id_to_conversation . get (row . prompt_id )
367
+ if conversation is None :
359
368
return None
360
-
361
- question_answer = _get_question_answer_from_partial (partial_qa )
362
-
363
- conversation = Conversation (
364
- question_answers = [question_answer ],
365
- provider = row .provider ,
366
- type = row .type ,
367
- chat_id = row .id ,
368
- conversation_timestamp = row .timestamp ,
369
- )
370
369
code_snippet = json .loads (row .code_snippet ) if row .code_snippet else None
371
370
trigger_string = None
372
371
if row .trigger_string :
@@ -387,14 +386,19 @@ async def parse_row_alert_conversation(
387
386
388
387
389
388
async def parse_get_alert_conversation (
390
- alerts_conversations : List [GetAlertsWithPromptAndOutputRow ],
389
+ alerts : List [Alert ],
390
+ prompts_outputs : List [GetPromptWithOutputsRow ],
391
391
) -> List [AlertConversation ]:
392
392
"""
393
393
Parse a list of rows from the get_alerts_with_prompt_and_output query and return a list of
394
394
AlertConversation
395
395
396
396
The rows contain the raw request and output strings from the pipeline.
397
397
"""
398
+ _ , map_q_id_to_conversation = await parse_messages_in_conversations (prompts_outputs )
398
399
async with asyncio .TaskGroup () as tg :
399
- tasks = [tg .create_task (parse_row_alert_conversation (row )) for row in alerts_conversations ]
400
+ tasks = [
401
+ tg .create_task (parse_row_alert_conversation (row , map_q_id_to_conversation ))
402
+ for row in alerts
403
+ ]
400
404
return [task .result () for task in tasks if task .result () is not None ]
0 commit comments