diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index ebd9be79..00fe45c2 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -3,7 +3,7 @@ import requests import structlog -from fastapi import APIRouter, Depends, HTTPException, Response +from fastapi import APIRouter, Depends, HTTPException, Query, Response from fastapi.responses import StreamingResponse from fastapi.routing import APIRoute from pydantic import BaseModel, ValidationError @@ -11,6 +11,7 @@ import codegate.muxing.models as mux_models from codegate import __version__ from codegate.api import v1_models, v1_processing +from codegate.config import API_DEFAULT_PAGE_SIZE, API_MAX_PAGE_SIZE from codegate.db.connection import AlreadyExistsError, DbReader from codegate.db.models import AlertSeverity, WorkspaceWithModel from codegate.providers import crud as provendcrud @@ -402,7 +403,12 @@ async def get_workspace_alerts(workspace_name: str) -> List[Optional[v1_models.A tags=["Workspaces"], generate_unique_id_function=uniq_name, ) -async def get_workspace_messages(workspace_name: str) -> List[v1_models.Conversation]: +async def get_workspace_messages( + workspace_name: str, + page: int = Query(1, ge=1), + page_size: int = Query(API_DEFAULT_PAGE_SIZE, ge=1, le=API_MAX_PAGE_SIZE), + filter_by_ids: Optional[List[str]] = Query(None), +) -> v1_models.PaginatedMessagesResponse: """Get messages for a workspace.""" try: ws = await wscrud.get_workspace_by_name(workspace_name) @@ -412,19 +418,33 @@ async def get_workspace_messages(workspace_name: str) -> List[v1_models.Conversa logger.exception("Error while getting workspace") raise HTTPException(status_code=500, detail="Internal server error") - try: - prompts_with_output_alerts_usage = ( - await dbreader.get_prompts_with_output_alerts_usage_by_workspace_id( - ws.id, AlertSeverity.CRITICAL.value - ) + offset = (page - 1) * page_size + fetched_messages = [] + + while len(fetched_messages) < page_size: + messages_batch = await dbreader.get_prompts_with_output_alerts_usage_by_workspace_id( + ws.id, AlertSeverity.CRITICAL.value, page_size, offset, filter_by_ids ) - conversations, _ = await v1_processing.parse_messages_in_conversations( - prompts_with_output_alerts_usage + if not messages_batch: + break + parsed_conversations, _ = await v1_processing.parse_messages_in_conversations( + messages_batch ) - return conversations - except Exception: - logger.exception("Error while getting messages") - raise HTTPException(status_code=500, detail="Internal server error") + fetched_messages.extend(parsed_conversations) + offset += page_size + + final_messages = fetched_messages[:page_size] + + # Fetch total message count + total_count = await dbreader.get_total_messages_count_by_workspace_id( + ws.id, AlertSeverity.CRITICAL.value + ) + return v1_models.PaginatedMessagesResponse( + data=final_messages, + limit=page_size, + offset=(page - 1) * page_size, + total=total_count, + ) @v1.get( diff --git a/src/codegate/api/v1_models.py b/src/codegate/api/v1_models.py index c608484c..8ce9e2bc 100644 --- a/src/codegate/api/v1_models.py +++ b/src/codegate/api/v1_models.py @@ -312,3 +312,10 @@ class ModelByProvider(pydantic.BaseModel): def __str__(self): return f"{self.provider_name} / {self.name}" + + +class PaginatedMessagesResponse(pydantic.BaseModel): + data: List[Conversation] + limit: int + offset: int + total: int diff --git a/src/codegate/config.py b/src/codegate/config.py index 11cd96bf..8f9a15c5 100644 --- a/src/codegate/config.py +++ b/src/codegate/config.py @@ -25,6 +25,9 @@ "llamacpp": "./codegate_volume/models", # Default LlamaCpp model path } +API_DEFAULT_PAGE_SIZE = 50 +API_MAX_PAGE_SIZE = 100 + @dataclass class Config: diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 2d56fccd..9dc67b8b 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -8,11 +8,12 @@ from alembic import command as alembic_command from alembic.config import Config as AlembicConfig from pydantic import BaseModel -from sqlalchemy import CursorResult, TextClause, event, text +from sqlalchemy import CursorResult, TextClause, bindparam, event, text from sqlalchemy.engine import Engine from sqlalchemy.exc import IntegrityError, OperationalError from sqlalchemy.ext.asyncio import create_async_engine +from codegate.config import API_DEFAULT_PAGE_SIZE from codegate.db.fim_cache import FimCache from codegate.db.models import ( ActiveWorkspace, @@ -594,36 +595,58 @@ async def get_prompts_with_output(self, workpace_id: str) -> List[GetPromptWithO return prompts async def get_prompts_with_output_alerts_usage_by_workspace_id( - self, workspace_id: str, trigger_category: Optional[str] = None + self, + workspace_id: str, + trigger_category: Optional[str] = None, + limit: int = API_DEFAULT_PAGE_SIZE, + offset: int = 0, + filter_by_ids: Optional[List[str]] = None, ) -> List[GetPromptWithOutputsRow]: """ Get all prompts with their outputs, alerts and token usage by workspace_id. """ - sql = text( - """ + base_query = """ SELECT - p.id as prompt_id, p.timestamp as prompt_timestamp, p.provider, p.request, p.type, - o.id as output_id, o.output, o.timestamp as output_timestamp, o.input_tokens, o.output_tokens, o.input_cost, o.output_cost, - a.id as alert_id, a.code_snippet, a.trigger_string, a.trigger_type, a.trigger_category, a.timestamp as alert_timestamp + p.id as prompt_id, p.timestamp as prompt_timestamp, p.provider, p.request, p.type, + o.id as output_id, o.output, o.timestamp as output_timestamp, o.input_tokens, o.output_tokens, o.input_cost, o.output_cost, + a.id as alert_id, a.code_snippet, a.trigger_string, a.trigger_type, a.trigger_category, a.timestamp as alert_timestamp FROM prompts p LEFT JOIN outputs o ON p.id = o.prompt_id LEFT JOIN alerts a ON p.id = a.prompt_id WHERE p.workspace_id = :workspace_id - AND (a.trigger_category = :trigger_category OR a.trigger_category is NULL) + AND (:trigger_category IS NULL OR a.trigger_category = :trigger_category) + """ # noqa: E501 + + if filter_by_ids: + base_query += " AND p.id IN :filter_ids" + + base_query += """ ORDER BY o.timestamp DESC, a.timestamp DESC - """ # noqa: E501 - ) - # If trigger category is None we want to get all alerts - trigger_category = trigger_category if trigger_category else "%" - conditions = {"workspace_id": workspace_id, "trigger_category": trigger_category} - rows: List[IntermediatePromptWithOutputUsageAlerts] = ( - await self._exec_select_conditions_to_pydantic( - IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True - ) + LIMIT :limit OFFSET :offset + """ + + sql = text(base_query) + + conditions = { + "workspace_id": workspace_id, + "trigger_category": trigger_category, + "limit": limit, + "offset": offset, + } + + if filter_by_ids: + sql = sql.bindparams(bindparam("filter_ids", expanding=True)) + conditions["filter_ids"] = filter_by_ids + + fetched_rows: List[ + IntermediatePromptWithOutputUsageAlerts + ] = await self._exec_select_conditions_to_pydantic( + IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True ) + prompts_dict: Dict[str, GetPromptWithOutputsRow] = {} - for row in rows: + for row in fetched_rows: prompt_id = row.prompt_id if prompt_id not in prompts_dict: prompts_dict[prompt_id] = GetPromptWithOutputsRow( @@ -655,6 +678,33 @@ async def get_prompts_with_output_alerts_usage_by_workspace_id( return list(prompts_dict.values()) + async def get_total_messages_count_by_workspace_id( + self, workspace_id: str, trigger_category: Optional[str] = None + ) -> int: + """Get total count of messages for a given workspace_id, considering trigger_category.""" + sql = text( + """ + SELECT COUNT(*) + FROM prompts p + LEFT JOIN alerts a ON p.id = a.prompt_id + WHERE p.workspace_id = :workspace_id + """ + ) + conditions = {"workspace_id": workspace_id} + + if trigger_category: + sql = text(sql.text + " AND a.trigger_category = :trigger_category") + conditions["trigger_category"] = trigger_category + + async with self._async_db_engine.begin() as conn: + try: + result = await conn.execute(sql, conditions) + count = result.scalar() # Fetches the integer result directly + return count or 0 # Ensure it returns an integer + except Exception as e: + logger.error(f"Failed to fetch message count. Error: {e}") + return 0 # Return 0 in case of failure + async def get_alerts_by_workspace( self, workspace_id: str, trigger_category: Optional[str] = None ) -> List[Alert]: