Skip to content

feat: add pagination to alerts and messages endpoints #1186

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 33 additions & 13 deletions src/codegate/api/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@

import requests
import structlog
from fastapi import APIRouter, Depends, HTTPException, Response
from fastapi import APIRouter, Depends, HTTPException, Query, Response
from fastapi.responses import StreamingResponse
from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError

import codegate.muxing.models as mux_models
from codegate import __version__
from codegate.api import v1_models, v1_processing
from codegate.config import API_DEFAULT_PAGE_SIZE, API_MAX_PAGE_SIZE
from codegate.db.connection import AlreadyExistsError, DbReader
from codegate.db.models import AlertSeverity, WorkspaceWithModel
from codegate.providers import crud as provendcrud
Expand Down Expand Up @@ -402,7 +403,12 @@ async def get_workspace_alerts(workspace_name: str) -> List[Optional[v1_models.A
tags=["Workspaces"],
generate_unique_id_function=uniq_name,
)
async def get_workspace_messages(workspace_name: str) -> List[v1_models.Conversation]:
async def get_workspace_messages(
workspace_name: str,
page: int = Query(1, ge=1),
page_size: int = Query(API_DEFAULT_PAGE_SIZE, ge=1, le=API_MAX_PAGE_SIZE),
filter_by_ids: Optional[List[str]] = Query(None),
) -> v1_models.PaginatedMessagesResponse:
"""Get messages for a workspace."""
try:
ws = await wscrud.get_workspace_by_name(workspace_name)
Expand All @@ -412,19 +418,33 @@ async def get_workspace_messages(workspace_name: str) -> List[v1_models.Conversa
logger.exception("Error while getting workspace")
raise HTTPException(status_code=500, detail="Internal server error")

try:
prompts_with_output_alerts_usage = (
await dbreader.get_prompts_with_output_alerts_usage_by_workspace_id(
ws.id, AlertSeverity.CRITICAL.value
)
offset = (page - 1) * page_size
fetched_messages = []

while len(fetched_messages) < page_size:
messages_batch = await dbreader.get_prompts_with_output_alerts_usage_by_workspace_id(
ws.id, AlertSeverity.CRITICAL.value, page_size, offset, filter_by_ids
)
conversations, _ = await v1_processing.parse_messages_in_conversations(
prompts_with_output_alerts_usage
if not messages_batch:
break
parsed_conversations, _ = await v1_processing.parse_messages_in_conversations(
messages_batch
)
return conversations
except Exception:
logger.exception("Error while getting messages")
raise HTTPException(status_code=500, detail="Internal server error")
fetched_messages.extend(parsed_conversations)
offset += page_size

final_messages = fetched_messages[:page_size]

# Fetch total message count
total_count = await dbreader.get_total_messages_count_by_workspace_id(
ws.id, AlertSeverity.CRITICAL.value
)
return v1_models.PaginatedMessagesResponse(
data=final_messages,
limit=page_size,
offset=(page - 1) * page_size,
total=total_count,
)


@v1.get(
Expand Down
7 changes: 7 additions & 0 deletions src/codegate/api/v1_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,3 +312,10 @@ class ModelByProvider(pydantic.BaseModel):

def __str__(self):
return f"{self.provider_name} / {self.name}"


class PaginatedMessagesResponse(pydantic.BaseModel):
    """Offset-paginated envelope for a workspace's conversation listing.

    Returned by the workspace messages endpoint instead of a bare list so
    that API clients receive pagination metadata alongside the page data.
    """

    # The conversations belonging to the requested page.
    data: List[Conversation]
    # Requested page size, i.e. the maximum number of items in ``data``.
    limit: int
    # Number of items skipped before this page: (page - 1) * page_size.
    offset: int
    # Total number of matching messages across all pages, as reported by
    # the database count query.
    total: int
3 changes: 3 additions & 0 deletions src/codegate/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
"llamacpp": "./codegate_volume/models", # Default LlamaCpp model path
}

API_DEFAULT_PAGE_SIZE = 50
API_MAX_PAGE_SIZE = 100


@dataclass
class Config:
Expand Down
86 changes: 68 additions & 18 deletions src/codegate/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
from alembic import command as alembic_command
from alembic.config import Config as AlembicConfig
from pydantic import BaseModel
from sqlalchemy import CursorResult, TextClause, event, text
from sqlalchemy import CursorResult, TextClause, bindparam, event, text
from sqlalchemy.engine import Engine
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.ext.asyncio import create_async_engine

from codegate.config import API_DEFAULT_PAGE_SIZE
from codegate.db.fim_cache import FimCache
from codegate.db.models import (
ActiveWorkspace,
Expand Down Expand Up @@ -594,36 +595,58 @@ async def get_prompts_with_output(self, workpace_id: str) -> List[GetPromptWithO
return prompts

async def get_prompts_with_output_alerts_usage_by_workspace_id(
self, workspace_id: str, trigger_category: Optional[str] = None
self,
workspace_id: str,
trigger_category: Optional[str] = None,
limit: int = API_DEFAULT_PAGE_SIZE,
offset: int = 0,
filter_by_ids: Optional[List[str]] = None,
) -> List[GetPromptWithOutputsRow]:
"""
Get all prompts with their outputs, alerts and token usage by workspace_id.
"""

sql = text(
"""
base_query = """
SELECT
p.id as prompt_id, p.timestamp as prompt_timestamp, p.provider, p.request, p.type,
o.id as output_id, o.output, o.timestamp as output_timestamp, o.input_tokens, o.output_tokens, o.input_cost, o.output_cost,
a.id as alert_id, a.code_snippet, a.trigger_string, a.trigger_type, a.trigger_category, a.timestamp as alert_timestamp
p.id as prompt_id, p.timestamp as prompt_timestamp, p.provider, p.request, p.type,
o.id as output_id, o.output, o.timestamp as output_timestamp, o.input_tokens, o.output_tokens, o.input_cost, o.output_cost,
a.id as alert_id, a.code_snippet, a.trigger_string, a.trigger_type, a.trigger_category, a.timestamp as alert_timestamp
FROM prompts p
LEFT JOIN outputs o ON p.id = o.prompt_id
LEFT JOIN alerts a ON p.id = a.prompt_id
WHERE p.workspace_id = :workspace_id
AND (a.trigger_category = :trigger_category OR a.trigger_category is NULL)
AND (:trigger_category IS NULL OR a.trigger_category = :trigger_category)
""" # noqa: E501

if filter_by_ids:
base_query += " AND p.id IN :filter_ids"

base_query += """
ORDER BY o.timestamp DESC, a.timestamp DESC
""" # noqa: E501
)
# If trigger category is None we want to get all alerts
trigger_category = trigger_category if trigger_category else "%"
conditions = {"workspace_id": workspace_id, "trigger_category": trigger_category}
rows: List[IntermediatePromptWithOutputUsageAlerts] = (
await self._exec_select_conditions_to_pydantic(
IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True
)
LIMIT :limit OFFSET :offset
"""

sql = text(base_query)

conditions = {
"workspace_id": workspace_id,
"trigger_category": trigger_category,
"limit": limit,
"offset": offset,
}

if filter_by_ids:
sql = sql.bindparams(bindparam("filter_ids", expanding=True))
conditions["filter_ids"] = filter_by_ids

fetched_rows: List[
IntermediatePromptWithOutputUsageAlerts
] = await self._exec_select_conditions_to_pydantic(
IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True
)

prompts_dict: Dict[str, GetPromptWithOutputsRow] = {}
for row in rows:
for row in fetched_rows:
prompt_id = row.prompt_id
if prompt_id not in prompts_dict:
prompts_dict[prompt_id] = GetPromptWithOutputsRow(
Expand Down Expand Up @@ -655,6 +678,33 @@ async def get_prompts_with_output_alerts_usage_by_workspace_id(

return list(prompts_dict.values())

async def get_total_messages_count_by_workspace_id(
    self, workspace_id: str, trigger_category: Optional[str] = None
) -> int:
    """Return the total number of distinct prompts for a workspace.

    Used to populate the ``total`` field of paginated message responses.

    Args:
        workspace_id: Workspace whose prompts are counted.
        trigger_category: If given, only prompts that have at least one
            alert with this trigger category are counted.

    Returns:
        The number of matching prompts, or 0 on query failure (best-effort:
        a broken count should not fail the whole messages endpoint).
    """
    # COUNT(DISTINCT p.id), not COUNT(*): the LEFT JOIN produces one row
    # per (prompt, alert) pair, so a prompt with several alerts would be
    # counted multiple times and inflate the pagination total. The listing
    # query deduplicates rows by prompt id, so the count must as well.
    sql = text(
        """
        SELECT COUNT(DISTINCT p.id)
        FROM prompts p
        LEFT JOIN alerts a ON p.id = a.prompt_id
        WHERE p.workspace_id = :workspace_id
        """
    )
    conditions = {"workspace_id": workspace_id}

    if trigger_category:
        sql = text(sql.text + " AND a.trigger_category = :trigger_category")
        conditions["trigger_category"] = trigger_category

    async with self._async_db_engine.begin() as conn:
        try:
            result = await conn.execute(sql, conditions)
            count = result.scalar()  # COUNT always yields a single row
            return count or 0  # normalize NULL/None to an int
        except Exception as e:
            logger.error(f"Failed to fetch message count. Error: {e}")
            return 0  # best-effort: degrade to 0 rather than raise

async def get_alerts_by_workspace(
self, workspace_id: str, trigger_category: Optional[str] = None
) -> List[Alert]:
Expand Down
Loading