diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index 96672615..00fe45c2 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -407,6 +407,7 @@ async def get_workspace_messages( workspace_name: str, page: int = Query(1, ge=1), page_size: int = Query(API_DEFAULT_PAGE_SIZE, ge=1, le=API_MAX_PAGE_SIZE), + filter_by_ids: Optional[List[str]] = Query(None), ) -> v1_models.PaginatedMessagesResponse: """Get messages for a workspace.""" try: @@ -422,7 +423,7 @@ async def get_workspace_messages( while len(fetched_messages) < page_size: messages_batch = await dbreader.get_prompts_with_output_alerts_usage_by_workspace_id( - ws.id, AlertSeverity.CRITICAL.value, page_size, offset + ws.id, AlertSeverity.CRITICAL.value, page_size, offset, filter_by_ids ) if not messages_batch: break diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 32d6e24a..9dc67b8b 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -8,7 +8,7 @@ from alembic import command as alembic_command from alembic.config import Config as AlembicConfig from pydantic import BaseModel -from sqlalchemy import CursorResult, TextClause, event, text +from sqlalchemy import CursorResult, TextClause, bindparam, event, text from sqlalchemy.engine import Engine from sqlalchemy.exc import IntegrityError, OperationalError from sqlalchemy.ext.asyncio import create_async_engine @@ -600,38 +600,51 @@ async def get_prompts_with_output_alerts_usage_by_workspace_id( trigger_category: Optional[str] = None, limit: int = API_DEFAULT_PAGE_SIZE, offset: int = 0, + filter_by_ids: Optional[List[str]] = None, ) -> List[GetPromptWithOutputsRow]: """ Get all prompts with their outputs, alerts and token usage by workspace_id. """ - sql = text( - """ + + base_query = """ SELECT - p.id as prompt_id, p.timestamp as prompt_timestamp, p.provider, p.request, p.type, - o.id as output_id, o.output, o.timestamp as output_timestamp, o.input_tokens, o.output_tokens, o.input_cost, o.output_cost, - a.id as alert_id, a.code_snippet, a.trigger_string, a.trigger_type, a.trigger_category, a.timestamp as alert_timestamp + p.id as prompt_id, p.timestamp as prompt_timestamp, p.provider, p.request, p.type, + o.id as output_id, o.output, o.timestamp as output_timestamp, o.input_tokens, o.output_tokens, o.input_cost, o.output_cost, + a.id as alert_id, a.code_snippet, a.trigger_string, a.trigger_type, a.trigger_category, a.timestamp as alert_timestamp FROM prompts p LEFT JOIN outputs o ON p.id = o.prompt_id LEFT JOIN alerts a ON p.id = a.prompt_id WHERE p.workspace_id = :workspace_id - """ # noqa: E501 - ) - conditions = {"workspace_id": workspace_id} - if trigger_category: - sql = text(sql.text + " AND a.trigger_category = :trigger_category") - conditions["trigger_category"] = trigger_category + AND (:trigger_category IS NULL OR a.trigger_category = :trigger_category) + """ # noqa: E501 - sql = text( - sql.text + " ORDER BY o.timestamp DESC, a.timestamp DESC LIMIT :limit OFFSET :offset" - ) - conditions["limit"] = limit - conditions["offset"] = offset + if filter_by_ids: + base_query += " AND p.id IN :filter_ids" - fetched_rows: List[IntermediatePromptWithOutputUsageAlerts] = ( - await self._exec_select_conditions_to_pydantic( - IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True - ) + base_query += """ + ORDER BY o.timestamp DESC, a.timestamp DESC + LIMIT :limit OFFSET :offset + """ + + sql = text(base_query) + + conditions = { + "workspace_id": workspace_id, + "trigger_category": trigger_category, + "limit": limit, + "offset": offset, + } + + if filter_by_ids: + sql = sql.bindparams(bindparam("filter_ids", expanding=True)) + conditions["filter_ids"] = filter_by_ids + + fetched_rows: List[ + IntermediatePromptWithOutputUsageAlerts + ] = await self._exec_select_conditions_to_pydantic( + IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True ) + prompts_dict: Dict[str, GetPromptWithOutputsRow] = {} for row in fetched_rows: prompt_id = row.prompt_id