2
2
import json
3
3
import re
4
4
from collections import defaultdict
5
- from typing import List , Optional , Union
5
+ from typing import Dict , List , Optional , Tuple
6
6
7
7
import structlog
8
8
14
14
PartialQuestions ,
15
15
QuestionAnswer ,
16
16
)
17
- from codegate .db .models import GetAlertsWithPromptAndOutputRow , GetPromptWithOutputsRow
17
+ from codegate .db .models import Alert , GetPromptWithOutputsRow
18
18
19
19
logger = structlog .get_logger ("codegate" )
20
20
@@ -124,9 +124,7 @@ def _parse_single_output(single_output: dict) -> str:
124
124
return full_output_message
125
125
126
126
127
- async def _get_question_answer (
128
- row : Union [GetPromptWithOutputsRow , GetAlertsWithPromptAndOutputRow ]
129
- ) -> Optional [PartialQuestionAnswer ]:
127
+ async def _get_question_answer (row : GetPromptWithOutputsRow ) -> Optional [PartialQuestionAnswer ]:
130
128
"""
131
129
Parse a row from the get_prompt_with_outputs query and return a PartialQuestionAnswer
132
130
@@ -175,6 +173,11 @@ def parse_question_answer(input_text: str) -> str:
175
173
return input_text
176
174
177
175
176
+ def _clean_secrets_from_message (message : str ) -> str :
177
+ pattern = re .compile (r"REDACTED<(\$?[^>]+)>" )
178
+ return pattern .sub ("REDACTED_SECRET" , message )
179
+
180
+
178
181
def _group_partial_messages (pq_list : List [PartialQuestions ]) -> List [List [PartialQuestions ]]:
179
182
"""
180
183
A PartialQuestion is an object that contains several user messages provided from a
@@ -190,6 +193,10 @@ def _group_partial_messages(pq_list: List[PartialQuestions]) -> List[List[Partia
190
193
- Leave any unpaired pq by itself.
191
194
- Finally, sort the resulting groups by the earliest timestamp in each group.
192
195
"""
196
+ # 0) Clean secrets from messages
197
+ for pq in pq_list :
198
+ pq .messages = [_clean_secrets_from_message (msg ) for msg in pq .messages ]
199
+
193
200
# 1) Sort by length of messages descending (largest/most-complete first),
194
201
# then by timestamp ascending for stable processing.
195
202
pq_list_sorted = sorted (pq_list , key = lambda x : (- len (x .messages ), x .timestamp ))
@@ -204,7 +211,7 @@ def _group_partial_messages(pq_list: List[PartialQuestions]) -> List[List[Partia
204
211
205
212
# Find all potential subsets of 'sup' that are not yet used
206
213
# (If sup's messages == sub's messages, that also counts, because sub ⊆ sup)
207
- possible_subsets = []
214
+ possible_subsets : List [ PartialQuestions ] = []
208
215
for sub in pq_list_sorted :
209
216
if sub .message_id == sup .message_id :
210
217
continue
@@ -261,7 +268,7 @@ def _get_question_answer_from_partial(
261
268
262
269
async def match_conversations (
263
270
partial_question_answers : List [Optional [PartialQuestionAnswer ]],
264
- ) -> List [Conversation ]:
271
+ ) -> Tuple [ List [Conversation ], Dict [ str , Conversation ] ]:
265
272
"""
266
273
Match partial conversations to form a complete conversation.
267
274
"""
@@ -274,45 +281,47 @@ async def match_conversations(
274
281
275
282
# Create the conversation objects
276
283
conversations = []
284
+ map_q_id_to_conversation = {}
277
285
for group in grouped_partial_questions :
278
- questions_answers = []
286
+ questions_answers : List [ QuestionAnswer ] = []
279
287
first_partial_qa = None
280
288
for partial_question in sorted (group , key = lambda x : x .timestamp ):
281
289
# Partial questions don't contain the answer, so we need to find the corresponding
290
+ # valid partial question answer
282
291
selected_partial_qa = None
283
292
for partial_qa in valid_partial_qas :
284
293
if partial_question .message_id == partial_qa .partial_questions .message_id :
285
294
selected_partial_qa = partial_qa
286
295
break
287
296
288
- # check if we have an answer, otherwise do not add it
289
- if selected_partial_qa .answer is not None :
290
- # if we don't have a first question, set it
297
+ # check if we have a question and answer, otherwise do not add it
298
+ if selected_partial_qa and selected_partial_qa .answer is not None :
299
+ # if we don't have a first question, set it. We will use it
300
+ # to set the conversation timestamp and provider
291
301
first_partial_qa = first_partial_qa or selected_partial_qa
292
- question_answer = _get_question_answer_from_partial (selected_partial_qa )
293
- question_answer .question .message = parse_question_answer (
294
- question_answer .question .message
295
- )
296
- questions_answers .append (question_answer )
302
+ qa = _get_question_answer_from_partial (selected_partial_qa )
303
+ qa .question .message = parse_question_answer (qa .question .message )
304
+ questions_answers .append (qa )
297
305
298
306
# only add conversation if we have some answers
299
307
if len (questions_answers ) > 0 and first_partial_qa is not None :
300
- conversations .append (
301
- Conversation (
302
- question_answers = questions_answers ,
303
- provider = first_partial_qa .partial_questions .provider ,
304
- type = first_partial_qa .partial_questions .type ,
305
- chat_id = first_partial_qa .partial_questions .message_id ,
306
- conversation_timestamp = first_partial_qa .partial_questions .timestamp ,
307
- )
308
+ conversation = Conversation (
309
+ question_answers = questions_answers ,
310
+ provider = first_partial_qa .partial_questions .provider ,
311
+ type = first_partial_qa .partial_questions .type ,
312
+ chat_id = first_partial_qa .partial_questions .message_id ,
313
+ conversation_timestamp = first_partial_qa .partial_questions .timestamp ,
308
314
)
315
+ for qa in questions_answers :
316
+ map_q_id_to_conversation [qa .question .message_id ] = conversation
317
+ conversations .append (conversation )
309
318
310
- return conversations
319
+ return conversations , map_q_id_to_conversation
311
320
312
321
313
322
async def parse_messages_in_conversations (
314
323
prompts_outputs : List [GetPromptWithOutputsRow ],
315
- ) -> List [Conversation ]:
324
+ ) -> Tuple [ List [Conversation ], Dict [ str , Conversation ] ]:
316
325
"""
317
326
Get all the messages from the database and return them as a list of conversations.
318
327
"""
@@ -322,31 +331,21 @@ async def parse_messages_in_conversations(
322
331
tasks = [tg .create_task (_get_question_answer (row )) for row in prompts_outputs ]
323
332
partial_question_answers = [task .result () for task in tasks ]
324
333
325
- conversations = await match_conversations (partial_question_answers )
326
- return conversations
334
+ conversations , map_q_id_to_conversation = await match_conversations (partial_question_answers )
335
+ return conversations , map_q_id_to_conversation
327
336
328
337
329
338
async def parse_row_alert_conversation (
330
- row : GetAlertsWithPromptAndOutputRow ,
339
+ row : Alert , map_q_id_to_conversation : Dict [ str , Conversation ]
331
340
) -> Optional [AlertConversation ]:
332
341
"""
333
342
Parse an Alert row and return the matching AlertConversation (None if no conversation is found)
334
343
335
344
The row contains the raw request and output strings from the pipeline.
336
345
"""
337
- partial_qa = await _get_question_answer (row )
338
- if not partial_qa :
346
+ conversation = map_q_id_to_conversation . get (row . prompt_id )
347
+ if conversation is None :
339
348
return None
340
-
341
- question_answer = _get_question_answer_from_partial (partial_qa )
342
-
343
- conversation = Conversation (
344
- question_answers = [question_answer ],
345
- provider = row .provider ,
346
- type = row .type ,
347
- chat_id = row .id ,
348
- conversation_timestamp = row .timestamp ,
349
- )
350
349
code_snippet = json .loads (row .code_snippet ) if row .code_snippet else None
351
350
trigger_string = None
352
351
if row .trigger_string :
@@ -367,14 +366,19 @@ async def parse_row_alert_conversation(
367
366
368
367
369
368
async def parse_get_alert_conversation(
    alerts: List[Alert],
    prompts_outputs: List[GetPromptWithOutputsRow],
) -> List[AlertConversation]:
    """
    Build the list of AlertConversation objects for the given alerts.

    The prompt/output rows are first parsed with parse_messages_in_conversations;
    only the returned question-id -> Conversation mapping is used here. Each alert
    is then handed to parse_row_alert_conversation together with that mapping, one
    task per alert inside an asyncio.TaskGroup.

    Args:
        alerts: the alert rows to convert.
        prompts_outputs: the prompt/output rows from which the conversations
            referenced by the alerts are built.

    Returns:
        The AlertConversation produced for each alert, in the order of `alerts`.
        Alerts for which parse_row_alert_conversation returns None are dropped.
    """
    _, map_q_id_to_conversation = await parse_messages_in_conversations(prompts_outputs)
    async with asyncio.TaskGroup() as tg:
        tasks = [
            tg.create_task(parse_row_alert_conversation(row, map_q_id_to_conversation))
            for row in alerts
        ]
    # All tasks are finished once the TaskGroup block exits; read each result
    # exactly once and filter out the alerts that produced nothing.
    results = (task.result() for task in tasks)
    return [alert_conversation for alert_conversation in results if alert_conversation is not None]
0 commit comments