Skip to content

Commit e9f273a

Browse files
committed
feat: update messages endpoint to return a conversation summary
Modify the messages endpoint to return just a conversationsummary, that will simplify the current queries. Create a different endpoint that will return a list of conversations for a given prompt id
1 parent 96aa48d commit e9f273a

File tree

5 files changed

+254
-99
lines changed

5 files changed

+254
-99
lines changed

src/codegate/api/v1.py

+93-17
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,17 @@
33

44
import requests
55
import structlog
6-
from fastapi import APIRouter, Depends, HTTPException, Response
6+
from fastapi import APIRouter, Depends, HTTPException, Query, Response
77
from fastapi.responses import StreamingResponse
88
from fastapi.routing import APIRoute
99
from pydantic import BaseModel, ValidationError
1010

11+
from codegate.config import API_DEFAULT_PAGE_SIZE, API_MAX_PAGE_SIZE
1112
import codegate.muxing.models as mux_models
1213
from codegate import __version__
1314
from codegate.api import v1_models, v1_processing
1415
from codegate.db.connection import AlreadyExistsError, DbReader
15-
from codegate.db.models import AlertSeverity, Persona, WorkspaceWithModel
16+
from codegate.db.models import AlertSeverity, AlertTriggerType, Persona, WorkspaceWithModel
1617
from codegate.muxing.persona import (
1718
PersonaDoesNotExistError,
1819
PersonaManager,
@@ -443,11 +444,11 @@ async def get_workspace_alerts_summary(workspace_name: str) -> v1_models.AlertSu
443444
raise HTTPException(status_code=500, detail="Internal server error")
444445

445446
try:
446-
summary = await dbreader.get_alerts_summary_by_workspace(ws.id)
447+
summary = await dbreader.get_alerts_summary(workspace_id=ws.id)
447448
return v1_models.AlertSummary(
448-
malicious_packages=summary["codegate_context_retriever_count"],
449-
pii=summary["codegate_pii_count"],
450-
secrets=summary["codegate_secrets_count"],
449+
malicious_packages=summary.total_packages_count,
450+
pii=summary.total_pii_count,
451+
secrets=summary.total_secrets_count,
451452
)
452453
except Exception:
453454
logger.exception("Error while getting alerts summary")
@@ -459,7 +460,13 @@ async def get_workspace_alerts_summary(workspace_name: str) -> v1_models.AlertSu
459460
tags=["Workspaces"],
460461
generate_unique_id_function=uniq_name,
461462
)
462-
async def get_workspace_messages(workspace_name: str) -> List[v1_models.Conversation]:
463+
async def get_workspace_messages(
464+
workspace_name: str,
465+
page: int = Query(1, ge=1),
466+
page_size: int = Query(API_DEFAULT_PAGE_SIZE, ge=1, le=API_MAX_PAGE_SIZE),
467+
filter_by_ids: Optional[List[str]] = Query(None),
468+
filter_by_alert_trigger_types: Optional[List[AlertTriggerType]] = Query(None),
469+
) -> v1_models.PaginatedMessagesResponse:
463470
"""Get messages for a workspace."""
464471
try:
465472
ws = await wscrud.get_workspace_by_name(workspace_name)
@@ -469,20 +476,89 @@ async def get_workspace_messages(workspace_name: str) -> List[v1_models.Conversa
469476
logger.exception("Error while getting workspace")
470477
raise HTTPException(status_code=500, detail="Internal server error")
471478

472-
try:
473-
prompts_with_output_alerts_usage = (
474-
await dbreader.get_prompts_with_output_alerts_usage_by_workspace_id(
475-
ws.id, AlertSeverity.CRITICAL.value
476-
)
479+
offset = (page - 1) * page_size
480+
481+
prompts = await dbreader.get_prompts(
482+
ws.id,
483+
offset,
484+
page_size,
485+
filter_by_ids,
486+
list([AlertSeverity.CRITICAL.value]), # TODO: Configurable severity
487+
filter_by_alert_trigger_types,
488+
)
489+
# Fetch total message count
490+
total_count = await dbreader.get_total_messages_count_by_workspace_id(
491+
ws.id, AlertSeverity.CRITICAL.value
492+
)
493+
494+
# iterate for all prompts to compose the conversation summary
495+
conversation_summaries: List[v1_models.ConversationSummary] = []
496+
for prompt in prompts:
497+
if not prompt.request:
498+
logger.warning(f"Skipping prompt {prompt.id}. Empty request field")
499+
continue
500+
501+
messages, _ = await v1_processing.parse_request(prompt.request)
502+
if not messages or len(messages) == 0:
503+
logger.warning(f"Skipping prompt {prompt.id}. No messages found")
504+
continue
505+
506+
# message is just the first entry in the request
507+
message_obj = v1_models.ChatMessage(
508+
message=messages[0], timestamp=prompt.timestamp, message_id=prompt.id
477509
)
478-
conversations, _ = await v1_processing.parse_messages_in_conversations(
479-
prompts_with_output_alerts_usage
510+
511+
# count total alerts for the prompt
512+
total_alerts_row = await dbreader.get_alerts_summary(prompt_id=prompt.id)
513+
514+
# get token usage for the prompt
515+
prompts_outputs = await dbreader.get_prompts_with_output(prompt_id=prompt.id)
516+
ws_token_usage = await v1_processing.parse_workspace_token_usage(prompts_outputs)
517+
518+
conversation_summary = v1_models.ConversationSummary(
519+
chat_id=prompt.id,
520+
prompt=message_obj,
521+
provider=prompt.provider,
522+
type=prompt.type,
523+
conversation_timestamp=prompt.timestamp,
524+
total_alerts=total_alerts_row.total_alerts,
525+
token_usage_agg=ws_token_usage,
480526
)
481-
return conversations
527+
528+
conversation_summaries.append(conversation_summary)
529+
530+
return v1_models.PaginatedMessagesResponse(
531+
data=conversation_summaries,
532+
limit=page_size,
533+
offset=(page - 1) * page_size,
534+
total=total_count,
535+
)
536+
537+
538+
@v1.get(
539+
"/workspaces/{workspace_name}/messages/{prompt_id}",
540+
tags=["Workspaces"],
541+
generate_unique_id_function=uniq_name,
542+
)
543+
async def get_messages_by_prompt_id(
544+
workspace_name: str,
545+
prompt_id: str,
546+
) -> List[v1_models.Conversation]:
547+
"""Get messages for a workspace."""
548+
try:
549+
ws = await wscrud.get_workspace_by_name(workspace_name)
550+
except crud.WorkspaceDoesNotExistError:
551+
raise HTTPException(status_code=404, detail="Workspace does not exist")
482552
except Exception:
483-
logger.exception("Error while getting messages")
553+
logger.exception("Error while getting workspace")
484554
raise HTTPException(status_code=500, detail="Internal server error")
485555

556+
prompts_outputs = await dbreader.get_prompts_with_output(
557+
workspace_id=ws.id, prompt_id=prompt_id
558+
)
559+
conversations, _ = await v1_processing.parse_messages_in_conversations(prompts_outputs)
560+
return conversations
561+
486562

487563
@v1.get(
488564
"/workspaces/{workspace_name}/custom-instructions",
@@ -665,7 +741,7 @@ async def get_workspace_token_usage(workspace_name: str) -> v1_models.TokenUsage
665741
raise HTTPException(status_code=500, detail="Internal server error")
666742

667743
try:
668-
prompts_outputs = await dbreader.get_prompts_with_output(ws.id)
744+
prompts_outputs = await dbreader.get_prompts_with_output(worskpace_id=ws.id)
669745
ws_token_usage = await v1_processing.parse_workspace_token_usage(prompts_outputs)
670746
return ws_token_usage
671747
except Exception:

src/codegate/api/v1_models.py

+21
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,20 @@ class Conversation(pydantic.BaseModel):
218218
alerts: List[Alert] = []
219219

220220

221+
class ConversationSummary(pydantic.BaseModel):
222+
"""
223+
Represents a conversation summary.
224+
"""
225+
226+
chat_id: str
227+
prompt: ChatMessage
228+
total_alerts: int
229+
token_usage_agg: Optional[TokenUsageAggregate]
230+
provider: Optional[str]
231+
type: QuestionType
232+
conversation_timestamp: datetime.datetime
233+
234+
221235
class AlertConversation(pydantic.BaseModel):
222236
"""
223237
Represents an alert with it's respective conversation.
@@ -333,3 +347,10 @@ class PersonaUpdateRequest(pydantic.BaseModel):
333347

334348
new_name: str
335349
new_description: str
350+
351+
352+
class PaginatedMessagesResponse(pydantic.BaseModel):
353+
data: List[ConversationSummary]
354+
limit: int
355+
offset: int
356+
total: int

src/codegate/config.py

+3
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525
"llamacpp": "./codegate_volume/models", # Default LlamaCpp model path
2626
}
2727

28+
API_DEFAULT_PAGE_SIZE = 50
29+
API_MAX_PAGE_SIZE = 100
30+
2831

2932
@dataclass
3033
class Config:

0 commit comments

Comments
 (0)