From c1b9c13bd1d363123bd5356b82f7e070de110966 Mon Sep 17 00:00:00 2001 From: alex-mcgovern Date: Tue, 4 Mar 2025 12:03:59 +0000 Subject: [PATCH 1/3] feat: filter messages by ID --- src/codegate/api/v1.py | 3 ++- src/codegate/db/connection.py | 17 ++++++++++++----- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index 96672615..00fe45c2 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -407,6 +407,7 @@ async def get_workspace_messages( workspace_name: str, page: int = Query(1, ge=1), page_size: int = Query(API_DEFAULT_PAGE_SIZE, ge=1, le=API_MAX_PAGE_SIZE), + filter_by_ids: Optional[List[str]] = Query(None), ) -> v1_models.PaginatedMessagesResponse: """Get messages for a workspace.""" try: @@ -422,7 +423,7 @@ async def get_workspace_messages( while len(fetched_messages) < page_size: messages_batch = await dbreader.get_prompts_with_output_alerts_usage_by_workspace_id( - ws.id, AlertSeverity.CRITICAL.value, page_size, offset + ws.id, AlertSeverity.CRITICAL.value, page_size, offset, filter_by_ids ) if not messages_batch: break diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 32d6e24a..6cab37c2 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -8,7 +8,7 @@ from alembic import command as alembic_command from alembic.config import Config as AlembicConfig from pydantic import BaseModel -from sqlalchemy import CursorResult, TextClause, event, text +from sqlalchemy import CursorResult, TextClause, bindparam, event, text from sqlalchemy.engine import Engine from sqlalchemy.exc import IntegrityError, OperationalError from sqlalchemy.ext.asyncio import create_async_engine @@ -600,6 +600,7 @@ async def get_prompts_with_output_alerts_usage_by_workspace_id( trigger_category: Optional[str] = None, limit: int = API_DEFAULT_PAGE_SIZE, offset: int = 0, + filter_by_ids: Optional[List[str]] = None, ) -> List[GetPromptWithOutputsRow]: """ Get all prompts with their outputs, alerts and token usage by workspace_id. @@ -621,16 +622,22 @@ async def get_prompts_with_output_alerts_usage_by_workspace_id( sql = text(sql.text + " AND a.trigger_category = :trigger_category") conditions["trigger_category"] = trigger_category + if filter_by_ids: + placeholders = ", ".join([":filter_by_id_" + str(i) for i in range(len(filter_by_ids))]) + sql = text(sql.text + f" AND p.id IN ({placeholders})") + for i, filter_id in enumerate(filter_by_ids): + conditions[f"filter_by_id_{i}"] = filter_id + sql = text( sql.text + " ORDER BY o.timestamp DESC, a.timestamp DESC LIMIT :limit OFFSET :offset" ) conditions["limit"] = limit conditions["offset"] = offset - fetched_rows: List[IntermediatePromptWithOutputUsageAlerts] = ( - await self._exec_select_conditions_to_pydantic( - IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True - ) + fetched_rows: List[ + IntermediatePromptWithOutputUsageAlerts + ] = await self._exec_select_conditions_to_pydantic( + IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True ) prompts_dict: Dict[str, GetPromptWithOutputsRow] = {} for row in fetched_rows: From 84f247adce437c16b7b8d21de3b649a533e133a8 Mon Sep 17 00:00:00 2001 From: alex-mcgovern Date: Tue, 4 Mar 2025 12:08:22 +0000 Subject: [PATCH 2/3] lint fix --- src/codegate/db/connection.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 6cab37c2..a16140b8 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -8,7 +8,7 @@ from alembic import command as alembic_command from alembic.config import Config as AlembicConfig from pydantic import BaseModel -from sqlalchemy import CursorResult, TextClause, bindparam, event, text +from sqlalchemy import CursorResult, TextClause, event, text from sqlalchemy.engine import Engine from sqlalchemy.exc import IntegrityError, OperationalError from sqlalchemy.ext.asyncio import create_async_engine @@ -634,10 +634,10 @@ async def get_prompts_with_output_alerts_usage_by_workspace_id( conditions["limit"] = limit conditions["offset"] = offset - fetched_rows: List[ - IntermediatePromptWithOutputUsageAlerts - ] = await self._exec_select_conditions_to_pydantic( - IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True + fetched_rows: List[IntermediatePromptWithOutputUsageAlerts] = ( + await self._exec_select_conditions_to_pydantic( + IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True + ) ) prompts_dict: Dict[str, GetPromptWithOutputsRow] = {} for row in fetched_rows: From c4b994e17c1408287be66ffe3deaaa3dbd83156f Mon Sep 17 00:00:00 2001 From: alex-mcgovern Date: Tue, 4 Mar 2025 14:40:13 +0000 Subject: [PATCH 3/3] fix: use `.bindparams` for `filter_by_ids` --- src/codegate/db/connection.py | 56 +++++++++++++++++++---------------- 1 file changed, 31 insertions(+), 25 deletions(-) diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index a16140b8..9dc67b8b 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -8,7 +8,7 @@ from alembic import command as alembic_command from alembic.config import Config as AlembicConfig from pydantic import BaseModel -from sqlalchemy import CursorResult, TextClause, event, text +from sqlalchemy import CursorResult, TextClause, bindparam, event, text from sqlalchemy.engine import Engine from sqlalchemy.exc import IntegrityError, OperationalError from sqlalchemy.ext.asyncio import create_async_engine @@ -605,40 +605,46 @@ async def get_prompts_with_output_alerts_usage_by_workspace_id( """ Get all prompts with their outputs, alerts and token usage by workspace_id. """ - sql = text( - """ + + base_query = """ SELECT - p.id as prompt_id, p.timestamp as prompt_timestamp, p.provider, p.request, p.type, - o.id as output_id, o.output, o.timestamp as output_timestamp, o.input_tokens, o.output_tokens, o.input_cost, o.output_cost, - a.id as alert_id, a.code_snippet, a.trigger_string, a.trigger_type, a.trigger_category, a.timestamp as alert_timestamp + p.id as prompt_id, p.timestamp as prompt_timestamp, p.provider, p.request, p.type, + o.id as output_id, o.output, o.timestamp as output_timestamp, o.input_tokens, o.output_tokens, o.input_cost, o.output_cost, + a.id as alert_id, a.code_snippet, a.trigger_string, a.trigger_type, a.trigger_category, a.timestamp as alert_timestamp FROM prompts p LEFT JOIN outputs o ON p.id = o.prompt_id LEFT JOIN alerts a ON p.id = a.prompt_id WHERE p.workspace_id = :workspace_id - """ # noqa: E501 - ) - conditions = {"workspace_id": workspace_id} - if trigger_category: - sql = text(sql.text + " AND a.trigger_category = :trigger_category") - conditions["trigger_category"] = trigger_category + AND (:trigger_category IS NULL OR a.trigger_category = :trigger_category) + """ # noqa: E501 if filter_by_ids: - placeholders = ", ".join([":filter_by_id_" + str(i) for i in range(len(filter_by_ids))]) - sql = text(sql.text + f" AND p.id IN ({placeholders})") - for i, filter_id in enumerate(filter_by_ids): - conditions[f"filter_by_id_{i}"] = filter_id + base_query += " AND p.id IN :filter_ids" - sql = text( - sql.text + " ORDER BY o.timestamp DESC, a.timestamp DESC LIMIT :limit OFFSET :offset" - ) - conditions["limit"] = limit - conditions["offset"] = offset + base_query += """ + ORDER BY o.timestamp DESC, a.timestamp DESC + LIMIT :limit OFFSET :offset + """ - fetched_rows: List[IntermediatePromptWithOutputUsageAlerts] = ( - await self._exec_select_conditions_to_pydantic( - IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True - ) + sql = text(base_query) + + conditions = { + "workspace_id": workspace_id, + "trigger_category": trigger_category, + "limit": limit, + "offset": offset, + } + + if filter_by_ids: + sql = sql.bindparams(bindparam("filter_ids", expanding=True)) + conditions["filter_ids"] = filter_by_ids + + fetched_rows: List[ + IntermediatePromptWithOutputUsageAlerts + ] = await self._exec_select_conditions_to_pydantic( + IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True ) + prompts_dict: Dict[str, GetPromptWithOutputsRow] = {} for row in fetched_rows: prompt_id = row.prompt_id