Skip to content

feat: update messages endpoint to return a conversation summary #1247

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Mar 11, 2025
148 changes: 129 additions & 19 deletions src/codegate/api/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@

import requests
import structlog
from fastapi import APIRouter, Depends, HTTPException, Response
from fastapi import APIRouter, Depends, HTTPException, Query, Response
from fastapi.responses import StreamingResponse
from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError

from codegate.config import API_DEFAULT_PAGE_SIZE, API_MAX_PAGE_SIZE
import codegate.muxing.models as mux_models
from codegate import __version__
from codegate.api import v1_models, v1_processing
from codegate.db.connection import AlreadyExistsError, DbReader
from codegate.db.models import AlertSeverity, Persona, WorkspaceWithModel
from codegate.db.models import AlertSeverity, AlertTriggerType, Persona, WorkspaceWithModel
from codegate.muxing.persona import (
PersonaDoesNotExistError,
PersonaManager,
Expand Down Expand Up @@ -419,7 +420,9 @@ async def get_workspace_alerts(workspace_name: str) -> List[Optional[v1_models.A
raise HTTPException(status_code=500, detail="Internal server error")

try:
alerts = await dbreader.get_alerts_by_workspace(ws.id, AlertSeverity.CRITICAL.value)
alerts = await dbreader.get_alerts_by_workspace_or_prompt_id(
workspace_id=ws.id, trigger_category=AlertSeverity.CRITICAL.value
)
prompts_outputs = await dbreader.get_prompts_with_output(ws.id)
return await v1_processing.parse_get_alert_conversation(alerts, prompts_outputs)
except Exception:
Expand All @@ -443,11 +446,12 @@ async def get_workspace_alerts_summary(workspace_name: str) -> v1_models.AlertSu
raise HTTPException(status_code=500, detail="Internal server error")

try:
summary = await dbreader.get_alerts_summary_by_workspace(ws.id)
summary = await dbreader.get_alerts_summary(workspace_id=ws.id)
return v1_models.AlertSummary(
malicious_packages=summary["codegate_context_retriever_count"],
pii=summary["codegate_pii_count"],
secrets=summary["codegate_secrets_count"],
malicious_packages=summary.total_packages_count,
pii=summary.total_pii_count,
secrets=summary.total_secrets_count,
total_alerts=summary.total_alerts,
)
except Exception:
logger.exception("Error while getting alerts summary")
Expand All @@ -459,7 +463,13 @@ async def get_workspace_alerts_summary(workspace_name: str) -> v1_models.AlertSu
tags=["Workspaces"],
generate_unique_id_function=uniq_name,
)
async def get_workspace_messages(
    workspace_name: str,
    page: int = Query(1, ge=1),
    page_size: int = Query(API_DEFAULT_PAGE_SIZE, ge=1, le=API_MAX_PAGE_SIZE),
    filter_by_ids: Optional[List[str]] = Query(None),
    filter_by_alert_trigger_types: Optional[List[AlertTriggerType]] = Query(None),
) -> v1_models.PaginatedMessagesResponse:
    """Get a paginated list of conversation summaries for a workspace.

    Prompts whose request is empty or contains no parseable messages are
    skipped, so each DB batch over-fetches (2x the page size) to compensate
    and still fill a full page where possible.

    Raises:
        HTTPException: 404 if the workspace does not exist, 500 on DB errors.
    """
    try:
        ws = await wscrud.get_workspace_by_name(workspace_name)
    except crud.WorkspaceDoesNotExistError:
        raise HTTPException(status_code=404, detail="Workspace does not exist")
    except Exception:
        logger.exception("Error while getting workspace")
        raise HTTPException(status_code=500, detail="Internal server error")

    offset = (page - 1) * page_size
    valid_conversations: List[v1_models.ConversationSummary] = []
    fetched_prompts = 0
    # Fetch more prompts than one page to compensate for potential skips.
    batch_size = page_size * 2

    while len(valid_conversations) < page_size:
        prompts = await dbreader.get_prompts(
            ws.id,
            offset + fetched_prompts,
            batch_size,
            filter_by_ids,
            [AlertSeverity.CRITICAL.value],
            filter_by_alert_trigger_types,
        )

        if not prompts:
            break

        # iterate over all fetched prompts to compose the conversation summaries
        for prompt in prompts:
            fetched_prompts += 1
            if not prompt.request:
                logger.warning(f"Skipping prompt {prompt.id}. Empty request field")
                continue

            messages, _ = await v1_processing.parse_request(prompt.request)
            if not messages:
                logger.warning(f"Skipping prompt {prompt.id}. No messages found")
                continue

            # message is just the first entry in the request, cleaned properly
            message = v1_processing.parse_question_answer(messages[0])
            message_obj = v1_models.ChatMessage(
                message=message, timestamp=prompt.timestamp, message_id=prompt.id
            )

            # count total alerts for the prompt
            total_alerts_row = await dbreader.get_alerts_summary(prompt_id=prompt.id)

            # get token usage for the prompt
            prompts_outputs = await dbreader.get_prompts_with_output(prompt_id=prompt.id)
            ws_token_usage = await v1_processing.parse_workspace_token_usage(prompts_outputs)

            conversation_summary = v1_models.ConversationSummary(
                chat_id=prompt.id,
                prompt=message_obj,
                provider=prompt.provider,
                type=prompt.type,
                conversation_timestamp=prompt.timestamp,
                alerts_summary=v1_models.AlertSummary(
                    malicious_packages=total_alerts_row.total_packages_count,
                    pii=total_alerts_row.total_pii_count,
                    secrets=total_alerts_row.total_secrets_count,
                    total_alerts=total_alerts_row.total_alerts,
                ),
                total_alerts=total_alerts_row.total_alerts,
                token_usage_agg=ws_token_usage,
            )

            valid_conversations.append(conversation_summary)
            if len(valid_conversations) >= page_size:
                break

    # Fetch the total message count using the same filters as the page query
    # so the pagination metadata stays consistent with the returned data.
    total_count = await dbreader.get_total_messages_count_by_workspace_id(
        ws.id,
        filter_by_ids,
        [AlertSeverity.CRITICAL.value],
        filter_by_alert_trigger_types,
    )

    return v1_models.PaginatedMessagesResponse(
        data=valid_conversations,
        limit=page_size,
        offset=offset,
        total=total_count,
    )


@v1.get(
"/workspaces/{workspace_name}/messages/{prompt_id}",
tags=["Workspaces"],
generate_unique_id_function=uniq_name,
)
async def get_messages_by_prompt_id(
    workspace_name: str,
    prompt_id: str,
) -> v1_models.Conversation:
    """Get a single conversation for a workspace by its prompt ID.

    Returns the parsed conversation with its de-duplicated critical alerts
    attached.

    Raises:
        HTTPException: 404 if the workspace or the conversation does not
            exist, 500 on DB errors while resolving the workspace.
    """
    try:
        ws = await wscrud.get_workspace_by_name(workspace_name)
    except crud.WorkspaceDoesNotExistError:
        raise HTTPException(status_code=404, detail="Workspace does not exist")
    except Exception:
        logger.exception("Error while getting workspace")
        raise HTTPException(status_code=500, detail="Internal server error")

    prompts_outputs = await dbreader.get_prompts_with_output(
        workspace_id=ws.id, prompt_id=prompt_id
    )

    # get all critical alerts for the prompt
    alerts = await dbreader.get_alerts_by_workspace_or_prompt_id(
        workspace_id=ws.id, prompt_id=prompt_id, trigger_category=AlertSeverity.CRITICAL.value
    )
    deduped_alerts = await v1_processing.remove_duplicate_alerts(alerts)
    conversations, _ = await v1_processing.parse_messages_in_conversations(prompts_outputs)
    if not conversations:
        raise HTTPException(status_code=404, detail="Conversation not found")

    conversation = conversations[0]
    conversation.alerts = deduped_alerts
    return conversation


@v1.get(
Expand Down Expand Up @@ -665,7 +775,7 @@ async def get_workspace_token_usage(workspace_name: str) -> v1_models.TokenUsage
raise HTTPException(status_code=500, detail="Internal server error")

try:
prompts_outputs = await dbreader.get_prompts_with_output(ws.id)
prompts_outputs = await dbreader.get_prompts_with_output(workspace_id=ws.id)
ws_token_usage = await v1_processing.parse_workspace_token_usage(prompts_outputs)
return ws_token_usage
except Exception:
Expand Down
25 changes: 23 additions & 2 deletions src/codegate/api/v1_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ class AlertSummary(pydantic.BaseModel):
malicious_packages: int
pii: int
secrets: int
total_alerts: int


class PartialQuestionAnswer(pydantic.BaseModel):
Expand All @@ -201,7 +202,6 @@ class PartialQuestionAnswer(pydantic.BaseModel):
partial_questions: PartialQuestions
answer: Optional[ChatMessage]
model_token_usage: TokenUsageByModel
alerts: List[Alert] = []


class Conversation(pydantic.BaseModel):
Expand All @@ -215,7 +215,21 @@ class Conversation(pydantic.BaseModel):
chat_id: str
conversation_timestamp: datetime.datetime
token_usage_agg: Optional[TokenUsageAggregate]
alerts: List[Alert] = []
alerts: Optional[List[Alert]] = []


class ConversationSummary(pydantic.BaseModel):
    """
    Represents a conversation summary.

    A lightweight, single-message view of a conversation used by the
    paginated messages listing, as opposed to the full Conversation model.
    """

    # ID of the conversation (the prompt ID of its first message).
    chat_id: str
    # The first message of the conversation, cleaned for display.
    prompt: ChatMessage
    # Aggregated counts of alerts (malicious packages, PII, secrets) raised
    # for this conversation.
    alerts_summary: AlertSummary
    # Aggregated token usage across the conversation; None when unavailable.
    token_usage_agg: Optional[TokenUsageAggregate]
    # Provider that served the conversation (e.g. model provider name);
    # presumably None when the provider could not be determined — verify
    # against the DB model.
    provider: Optional[str]
    # Kind of interaction (chat vs. FIM etc., per QuestionType).
    type: QuestionType
    conversation_timestamp: datetime.datetime


class AlertConversation(pydantic.BaseModel):
Expand Down Expand Up @@ -333,3 +347,10 @@ class PersonaUpdateRequest(pydantic.BaseModel):

new_name: str
new_description: str


class PaginatedMessagesResponse(pydantic.BaseModel):
    """
    Page of conversation summaries plus the pagination metadata
    (limit/offset/total) needed by clients to request further pages.
    """

    # The conversation summaries for the requested page.
    data: List[ConversationSummary]
    # Page size used for this response.
    limit: int
    # Number of items skipped before this page.
    offset: int
    # Total number of items matching the query across all pages.
    total: int
10 changes: 1 addition & 9 deletions src/codegate/api/v1_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,15 +202,10 @@ async def _get_partial_question_answer(
model=model, token_usage=token_usage, provider_type=provider
)

alerts: List[v1_models.Alert] = [
v1_models.Alert.from_db_model(db_alert) for db_alert in row.alerts
]

return PartialQuestionAnswer(
partial_questions=request_message,
answer=output_message,
model_token_usage=model_token_usage,
alerts=alerts,
)


Expand Down Expand Up @@ -374,7 +369,7 @@ async def match_conversations(
for group in grouped_partial_questions:
questions_answers: List[QuestionAnswer] = []
token_usage_agg = TokenUsageAggregate(tokens_by_model={}, token_usage=TokenUsage())
alerts: List[v1_models.Alert] = []

first_partial_qa = None
for partial_question in sorted(group, key=lambda x: x.timestamp):
# Partial questions don't contain the answer, so we need to find the corresponding
Expand All @@ -398,8 +393,6 @@ async def match_conversations(
qa = _get_question_answer_from_partial(selected_partial_qa)
qa.question.message = parse_question_answer(qa.question.message)
questions_answers.append(qa)
deduped_alerts = await remove_duplicate_alerts(selected_partial_qa.alerts)
alerts.extend(deduped_alerts)
token_usage_agg.add_model_token_usage(selected_partial_qa.model_token_usage)

# if we have a conversation with at least one question and answer
Expand All @@ -413,7 +406,6 @@ async def match_conversations(
chat_id=first_partial_qa.partial_questions.message_id,
conversation_timestamp=first_partial_qa.partial_questions.timestamp,
token_usage_agg=token_usage_agg,
alerts=alerts,
)
for qa in questions_answers:
map_q_id_to_conversation[qa.question.message_id] = conversation
Expand Down
3 changes: 3 additions & 0 deletions src/codegate/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
"llamacpp": "./codegate_volume/models", # Default LlamaCpp model path
}

# Default number of items returned by paginated API endpoints when the
# client does not specify a page size.
API_DEFAULT_PAGE_SIZE = 50
# Upper bound accepted for the page-size query parameter.
API_MAX_PAGE_SIZE = 100


@dataclass
class Config:
Expand Down
Loading
Loading