Skip to content

Commit 34b8aa2

Browse files
feat: update messages endpoint to return a conversation summary (#1247)
* feat: update messages endpoint to return a conversation summary Modify the messages endpoint to return just a conversationsummary, that will simplify the current queries. Create a different endpoint that will return a list of conversations for a given prompt id * fixes from rebase * fix lint * changes from review * fixes from review * fix pagination * decouple alerts from question/answer * fix querying prompts without alerts * clean message in list --------- Co-authored-by: Alex McGovern <[email protected]>
1 parent 7d131ec commit 34b8aa2

File tree

6 files changed

+399
-114
lines changed

6 files changed

+399
-114
lines changed

src/codegate/api/v1.py

+129-19
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,17 @@
33

44
import requests
55
import structlog
6-
from fastapi import APIRouter, Depends, HTTPException, Response
6+
from fastapi import APIRouter, Depends, HTTPException, Query, Response
77
from fastapi.responses import StreamingResponse
88
from fastapi.routing import APIRoute
99
from pydantic import BaseModel, ValidationError
1010

11+
from codegate.config import API_DEFAULT_PAGE_SIZE, API_MAX_PAGE_SIZE
1112
import codegate.muxing.models as mux_models
1213
from codegate import __version__
1314
from codegate.api import v1_models, v1_processing
1415
from codegate.db.connection import AlreadyExistsError, DbReader
15-
from codegate.db.models import AlertSeverity, Persona, WorkspaceWithModel
16+
from codegate.db.models import AlertSeverity, AlertTriggerType, Persona, WorkspaceWithModel
1617
from codegate.muxing.persona import (
1718
PersonaDoesNotExistError,
1819
PersonaManager,
@@ -419,7 +420,9 @@ async def get_workspace_alerts(workspace_name: str) -> List[Optional[v1_models.A
419420
raise HTTPException(status_code=500, detail="Internal server error")
420421

421422
try:
422-
alerts = await dbreader.get_alerts_by_workspace(ws.id, AlertSeverity.CRITICAL.value)
423+
alerts = await dbreader.get_alerts_by_workspace_or_prompt_id(
424+
workspace_id=ws.id, trigger_category=AlertSeverity.CRITICAL.value
425+
)
423426
prompts_outputs = await dbreader.get_prompts_with_output(ws.id)
424427
return await v1_processing.parse_get_alert_conversation(alerts, prompts_outputs)
425428
except Exception:
@@ -443,11 +446,12 @@ async def get_workspace_alerts_summary(workspace_name: str) -> v1_models.AlertSu
443446
raise HTTPException(status_code=500, detail="Internal server error")
444447

445448
try:
446-
summary = await dbreader.get_alerts_summary_by_workspace(ws.id)
449+
summary = await dbreader.get_alerts_summary(workspace_id=ws.id)
447450
return v1_models.AlertSummary(
448-
malicious_packages=summary["codegate_context_retriever_count"],
449-
pii=summary["codegate_pii_count"],
450-
secrets=summary["codegate_secrets_count"],
451+
malicious_packages=summary.total_packages_count,
452+
pii=summary.total_pii_count,
453+
secrets=summary.total_secrets_count,
454+
total_alerts=summary.total_alerts,
451455
)
452456
except Exception:
453457
logger.exception("Error while getting alerts summary")
@@ -459,7 +463,13 @@ async def get_workspace_alerts_summary(workspace_name: str) -> v1_models.AlertSu
459463
tags=["Workspaces"],
460464
generate_unique_id_function=uniq_name,
461465
)
462-
async def get_workspace_messages(workspace_name: str) -> List[v1_models.Conversation]:
466+
async def get_workspace_messages(
467+
workspace_name: str,
468+
page: int = Query(1, ge=1),
469+
page_size: int = Query(API_DEFAULT_PAGE_SIZE, ge=1, le=API_MAX_PAGE_SIZE),
470+
filter_by_ids: Optional[List[str]] = Query(None),
471+
filter_by_alert_trigger_types: Optional[List[AlertTriggerType]] = Query(None),
472+
) -> v1_models.PaginatedMessagesResponse:
463473
"""Get messages for a workspace."""
464474
try:
465475
ws = await wscrud.get_workspace_by_name(workspace_name)
@@ -469,19 +479,119 @@ async def get_workspace_messages(workspace_name: str) -> List[v1_models.Conversa
469479
logger.exception("Error while getting workspace")
470480
raise HTTPException(status_code=500, detail="Internal server error")
471481

472-
try:
473-
prompts_with_output_alerts_usage = (
474-
await dbreader.get_prompts_with_output_alerts_usage_by_workspace_id(
475-
ws.id, AlertSeverity.CRITICAL.value
476-
)
477-
)
478-
conversations, _ = await v1_processing.parse_messages_in_conversations(
479-
prompts_with_output_alerts_usage
482+
offset = (page - 1) * page_size
483+
valid_conversations: List[v1_models.ConversationSummary] = []
484+
fetched_prompts = 0
485+
486+
while len(valid_conversations) < page_size:
487+
batch_size = page_size * 2 # Fetch more prompts to compensate for potential skips
488+
489+
prompts = await dbreader.get_prompts(
490+
ws.id,
491+
offset + fetched_prompts,
492+
batch_size,
493+
filter_by_ids,
494+
list([AlertSeverity.CRITICAL.value]),
495+
filter_by_alert_trigger_types,
480496
)
481-
return conversations
497+
498+
if not prompts or len(prompts) == 0:
499+
break
500+
501+
# iterate for all prompts to compose the conversation summary
502+
for prompt in prompts:
503+
fetched_prompts += 1
504+
if not prompt.request:
505+
logger.warning(f"Skipping prompt {prompt.id}. Empty request field")
506+
continue
507+
508+
messages, _ = await v1_processing.parse_request(prompt.request)
509+
if not messages or len(messages) == 0:
510+
logger.warning(f"Skipping prompt {prompt.id}. No messages found")
511+
continue
512+
513+
# message is just the first entry in the request, cleaned properly
514+
message = v1_processing.parse_question_answer(messages[0])
515+
message_obj = v1_models.ChatMessage(
516+
message=message, timestamp=prompt.timestamp, message_id=prompt.id
517+
)
518+
519+
# count total alerts for the prompt
520+
total_alerts_row = await dbreader.get_alerts_summary(prompt_id=prompt.id)
521+
522+
# get token usage for the prompt
523+
prompts_outputs = await dbreader.get_prompts_with_output(prompt_id=prompt.id)
524+
ws_token_usage = await v1_processing.parse_workspace_token_usage(prompts_outputs)
525+
526+
conversation_summary = v1_models.ConversationSummary(
527+
chat_id=prompt.id,
528+
prompt=message_obj,
529+
provider=prompt.provider,
530+
type=prompt.type,
531+
conversation_timestamp=prompt.timestamp,
532+
alerts_summary=v1_models.AlertSummary(
533+
malicious_packages=total_alerts_row.total_packages_count,
534+
pii=total_alerts_row.total_pii_count,
535+
secrets=total_alerts_row.total_secrets_count,
536+
total_alerts=total_alerts_row.total_alerts,
537+
),
538+
total_alerts=total_alerts_row.total_alerts,
539+
token_usage_agg=ws_token_usage,
540+
)
541+
542+
valid_conversations.append(conversation_summary)
543+
if len(valid_conversations) >= page_size:
544+
break
545+
546+
# Fetch total message count
547+
total_count = await dbreader.get_total_messages_count_by_workspace_id(
548+
ws.id,
549+
filter_by_ids,
550+
list([AlertSeverity.CRITICAL.value]),
551+
filter_by_alert_trigger_types,
552+
)
553+
554+
return v1_models.PaginatedMessagesResponse(
555+
data=valid_conversations,
556+
limit=page_size,
557+
offset=offset,
558+
total=total_count,
559+
)
560+
561+
562+
@v1.get(
563+
"/workspaces/{workspace_name}/messages/{prompt_id}",
564+
tags=["Workspaces"],
565+
generate_unique_id_function=uniq_name,
566+
)
567+
async def get_messages_by_prompt_id(
568+
workspace_name: str,
569+
prompt_id: str,
570+
) -> v1_models.Conversation:
571+
"""Get messages for a workspace."""
572+
try:
573+
ws = await wscrud.get_workspace_by_name(workspace_name)
574+
except crud.WorkspaceDoesNotExistError:
575+
raise HTTPException(status_code=404, detail="Workspace does not exist")
482576
except Exception:
483-
logger.exception("Error while getting messages")
577+
logger.exception("Error while getting workspace")
484578
raise HTTPException(status_code=500, detail="Internal server error")
579+
prompts_outputs = await dbreader.get_prompts_with_output(
580+
workspace_id=ws.id, prompt_id=prompt_id
581+
)
582+
583+
# get all alerts for the prompt
584+
alerts = await dbreader.get_alerts_by_workspace_or_prompt_id(
585+
workspace_id=ws.id, prompt_id=prompt_id, trigger_category=AlertSeverity.CRITICAL.value
586+
)
587+
deduped_alerts = await v1_processing.remove_duplicate_alerts(alerts)
588+
conversations, _ = await v1_processing.parse_messages_in_conversations(prompts_outputs)
589+
if not conversations:
590+
raise HTTPException(status_code=404, detail="Conversation not found")
591+
592+
conversation = conversations[0]
593+
conversation.alerts = deduped_alerts
594+
return conversation
485595

486596

487597
@v1.get(
@@ -665,7 +775,7 @@ async def get_workspace_token_usage(workspace_name: str) -> v1_models.TokenUsage
665775
raise HTTPException(status_code=500, detail="Internal server error")
666776

667777
try:
668-
prompts_outputs = await dbreader.get_prompts_with_output(ws.id)
778+
prompts_outputs = await dbreader.get_prompts_with_output(workspace_id=ws.id)
669779
ws_token_usage = await v1_processing.parse_workspace_token_usage(prompts_outputs)
670780
return ws_token_usage
671781
except Exception:

src/codegate/api/v1_models.py

+23-2
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ class AlertSummary(pydantic.BaseModel):
191191
malicious_packages: int
192192
pii: int
193193
secrets: int
194+
total_alerts: int
194195

195196

196197
class PartialQuestionAnswer(pydantic.BaseModel):
@@ -201,7 +202,6 @@ class PartialQuestionAnswer(pydantic.BaseModel):
201202
partial_questions: PartialQuestions
202203
answer: Optional[ChatMessage]
203204
model_token_usage: TokenUsageByModel
204-
alerts: List[Alert] = []
205205

206206

207207
class Conversation(pydantic.BaseModel):
@@ -215,7 +215,21 @@ class Conversation(pydantic.BaseModel):
215215
chat_id: str
216216
conversation_timestamp: datetime.datetime
217217
token_usage_agg: Optional[TokenUsageAggregate]
218-
alerts: List[Alert] = []
218+
alerts: Optional[List[Alert]] = []
219+
220+
221+
class ConversationSummary(pydantic.BaseModel):
222+
"""
223+
Represents a conversation summary.
224+
"""
225+
226+
chat_id: str
227+
prompt: ChatMessage
228+
alerts_summary: AlertSummary
229+
token_usage_agg: Optional[TokenUsageAggregate]
230+
provider: Optional[str]
231+
type: QuestionType
232+
conversation_timestamp: datetime.datetime
219233

220234

221235
class AlertConversation(pydantic.BaseModel):
@@ -333,3 +347,10 @@ class PersonaUpdateRequest(pydantic.BaseModel):
333347

334348
new_name: str
335349
new_description: str
350+
351+
352+
class PaginatedMessagesResponse(pydantic.BaseModel):
353+
data: List[ConversationSummary]
354+
limit: int
355+
offset: int
356+
total: int

src/codegate/api/v1_processing.py

+1-9
Original file line numberDiff line numberDiff line change
@@ -202,15 +202,10 @@ async def _get_partial_question_answer(
202202
model=model, token_usage=token_usage, provider_type=provider
203203
)
204204

205-
alerts: List[v1_models.Alert] = [
206-
v1_models.Alert.from_db_model(db_alert) for db_alert in row.alerts
207-
]
208-
209205
return PartialQuestionAnswer(
210206
partial_questions=request_message,
211207
answer=output_message,
212208
model_token_usage=model_token_usage,
213-
alerts=alerts,
214209
)
215210

216211

@@ -374,7 +369,7 @@ async def match_conversations(
374369
for group in grouped_partial_questions:
375370
questions_answers: List[QuestionAnswer] = []
376371
token_usage_agg = TokenUsageAggregate(tokens_by_model={}, token_usage=TokenUsage())
377-
alerts: List[v1_models.Alert] = []
372+
378373
first_partial_qa = None
379374
for partial_question in sorted(group, key=lambda x: x.timestamp):
380375
# Partial questions don't contain the answer, so we need to find the corresponding
@@ -398,8 +393,6 @@ async def match_conversations(
398393
qa = _get_question_answer_from_partial(selected_partial_qa)
399394
qa.question.message = parse_question_answer(qa.question.message)
400395
questions_answers.append(qa)
401-
deduped_alerts = await remove_duplicate_alerts(selected_partial_qa.alerts)
402-
alerts.extend(deduped_alerts)
403396
token_usage_agg.add_model_token_usage(selected_partial_qa.model_token_usage)
404397

405398
# if we have a conversation with at least one question and answer
@@ -413,7 +406,6 @@ async def match_conversations(
413406
chat_id=first_partial_qa.partial_questions.message_id,
414407
conversation_timestamp=first_partial_qa.partial_questions.timestamp,
415408
token_usage_agg=token_usage_agg,
416-
alerts=alerts,
417409
)
418410
for qa in questions_answers:
419411
map_q_id_to_conversation[qa.question.message_id] = conversation

src/codegate/config.py

+3
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525
"llamacpp": "./codegate_volume/models", # Default LlamaCpp model path
2626
}
2727

28+
API_DEFAULT_PAGE_SIZE = 50
29+
API_MAX_PAGE_SIZE = 100
30+
2831

2932
@dataclass
3033
class Config:

0 commit comments

Comments
 (0)