diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py
index 33efea33..5db3ae54 100644
--- a/src/codegate/api/v1.py
+++ b/src/codegate/api/v1.py
@@ -3,16 +3,17 @@
 import requests
 import structlog
-from fastapi import APIRouter, Depends, HTTPException, Response
+from fastapi import APIRouter, Depends, HTTPException, Query, Response
 from fastapi.responses import StreamingResponse
 from fastapi.routing import APIRoute
 from pydantic import BaseModel, ValidationError
 
+from codegate.config import API_DEFAULT_PAGE_SIZE, API_MAX_PAGE_SIZE
 import codegate.muxing.models as mux_models
 from codegate import __version__
 from codegate.api import v1_models, v1_processing
 from codegate.db.connection import AlreadyExistsError, DbReader
-from codegate.db.models import AlertSeverity, Persona, WorkspaceWithModel
+from codegate.db.models import AlertSeverity, AlertTriggerType, Persona, WorkspaceWithModel
 from codegate.muxing.persona import (
     PersonaDoesNotExistError,
     PersonaManager,
@@ -419,7 +420,9 @@ async def get_workspace_alerts(workspace_name: str) -> List[Optional[v1_models.A
         raise HTTPException(status_code=500, detail="Internal server error")
 
     try:
-        alerts = await dbreader.get_alerts_by_workspace(ws.id, AlertSeverity.CRITICAL.value)
+        alerts = await dbreader.get_alerts_by_workspace_or_prompt_id(
+            workspace_id=ws.id, trigger_category=AlertSeverity.CRITICAL.value
+        )
         prompts_outputs = await dbreader.get_prompts_with_output(ws.id)
         return await v1_processing.parse_get_alert_conversation(alerts, prompts_outputs)
     except Exception:
@@ -443,11 +446,12 @@ async def get_workspace_alerts_summary(workspace_name: str) -> v1_models.AlertSu
         raise HTTPException(status_code=500, detail="Internal server error")
 
     try:
-        summary = await dbreader.get_alerts_summary_by_workspace(ws.id)
+        summary = await dbreader.get_alerts_summary(workspace_id=ws.id)
         return v1_models.AlertSummary(
-            malicious_packages=summary["codegate_context_retriever_count"],
-            pii=summary["codegate_pii_count"],
-            secrets=summary["codegate_secrets_count"],
+            malicious_packages=summary.total_packages_count,
+            pii=summary.total_pii_count,
+            secrets=summary.total_secrets_count,
+            total_alerts=summary.total_alerts,
         )
     except Exception:
         logger.exception("Error while getting alerts summary")
@@ -459,7 +463,13 @@ async def get_workspace_alerts_summary(workspace_name: str) -> v1_models.AlertSu
     tags=["Workspaces"],
     generate_unique_id_function=uniq_name,
 )
-async def get_workspace_messages(workspace_name: str) -> List[v1_models.Conversation]:
+async def get_workspace_messages(
+    workspace_name: str,
+    page: int = Query(1, ge=1),
+    page_size: int = Query(API_DEFAULT_PAGE_SIZE, ge=1, le=API_MAX_PAGE_SIZE),
+    filter_by_ids: Optional[List[str]] = Query(None),
+    filter_by_alert_trigger_types: Optional[List[AlertTriggerType]] = Query(None),
+) -> v1_models.PaginatedMessagesResponse:
     """Get messages for a workspace."""
     try:
         ws = await wscrud.get_workspace_by_name(workspace_name)
@@ -469,19 +479,119 @@ async def get_workspace_messages(workspace_name: str) -> List[v1_models.Conversa
         logger.exception("Error while getting workspace")
         raise HTTPException(status_code=500, detail="Internal server error")
 
-    try:
-        prompts_with_output_alerts_usage = (
-            await dbreader.get_prompts_with_output_alerts_usage_by_workspace_id(
-                ws.id, AlertSeverity.CRITICAL.value
-            )
-        )
-        conversations, _ = await v1_processing.parse_messages_in_conversations(
-            prompts_with_output_alerts_usage
+    offset = (page - 1) * page_size
+    valid_conversations: List[v1_models.ConversationSummary] = []
+    fetched_prompts = 0
+
+    while len(valid_conversations) < page_size:
+        batch_size = page_size * 2  # Fetch more prompts to compensate for potential skips
+
+        prompts = await dbreader.get_prompts(
+            ws.id,
+            offset + fetched_prompts,
+            batch_size,
+            filter_by_ids,
+            [AlertSeverity.CRITICAL.value],
+            filter_by_alert_trigger_types,
         )
-        return conversations
+
+        if not prompts:
+            break
+
+        # Iterate over the prompts to compose the conversation summaries
+        for prompt in prompts:
+            fetched_prompts += 1
+            if not prompt.request:
+                logger.warning(f"Skipping prompt {prompt.id}. Empty request field")
+                continue
+
+            messages, _ = await v1_processing.parse_request(prompt.request)
+            if not messages:
+                logger.warning(f"Skipping prompt {prompt.id}. No messages found")
+                continue
+
+            # The summary message is the first entry in the request, cleaned up for display
+            message = v1_processing.parse_question_answer(messages[0])
+            message_obj = v1_models.ChatMessage(
+                message=message, timestamp=prompt.timestamp, message_id=prompt.id
+            )
+
+            # Count total alerts for the prompt
+            total_alerts_row = await dbreader.get_alerts_summary(prompt_id=prompt.id)
+
+            # Get token usage for the prompt
+            prompts_outputs = await dbreader.get_prompts_with_output(prompt_id=prompt.id)
+            ws_token_usage = await v1_processing.parse_workspace_token_usage(prompts_outputs)
+
+            conversation_summary = v1_models.ConversationSummary(
+                chat_id=prompt.id,
+                prompt=message_obj,
+                provider=prompt.provider,
+                type=prompt.type,
+                conversation_timestamp=prompt.timestamp,
+                alerts_summary=v1_models.AlertSummary(
+                    malicious_packages=total_alerts_row.total_packages_count,
+                    pii=total_alerts_row.total_pii_count,
+                    secrets=total_alerts_row.total_secrets_count,
+                    total_alerts=total_alerts_row.total_alerts,
+                ),
+                total_alerts=total_alerts_row.total_alerts,
+                token_usage_agg=ws_token_usage,
+            )
+
+            valid_conversations.append(conversation_summary)
+            if len(valid_conversations) >= page_size:
+                break
+
+    # Fetch total message count
+    total_count = await dbreader.get_total_messages_count_by_workspace_id(
+        ws.id,
+        filter_by_ids,
+        [AlertSeverity.CRITICAL.value],
+        filter_by_alert_trigger_types,
+    )
+
+    return v1_models.PaginatedMessagesResponse(
+        data=valid_conversations,
+        limit=page_size,
+        offset=offset,
+        total=total_count,
+    )
+
+
+@v1.get(
+    "/workspaces/{workspace_name}/messages/{prompt_id}",
+    tags=["Workspaces"],
+    generate_unique_id_function=uniq_name,
+)
+async def get_messages_by_prompt_id(
+    workspace_name: str,
+    prompt_id: str,
+) -> v1_models.Conversation:
+    """Get a single conversation in a workspace by prompt ID."""
+    try:
+        ws = await wscrud.get_workspace_by_name(workspace_name)
+    except crud.WorkspaceDoesNotExistError:
+        raise HTTPException(status_code=404, detail="Workspace does not exist")
     except Exception:
-        logger.exception("Error while getting messages")
+        logger.exception("Error while getting workspace")
         raise HTTPException(status_code=500, detail="Internal server error")
 
+    prompts_outputs = await dbreader.get_prompts_with_output(
+        workspace_id=ws.id, prompt_id=prompt_id
+    )
+
+    # Get all alerts for the prompt
+    alerts = await dbreader.get_alerts_by_workspace_or_prompt_id(
+        workspace_id=ws.id, prompt_id=prompt_id, trigger_category=AlertSeverity.CRITICAL.value
+    )
+    deduped_alerts = await v1_processing.remove_duplicate_alerts(alerts)
+    conversations, _ = await v1_processing.parse_messages_in_conversations(prompts_outputs)
+    if not conversations:
+        raise HTTPException(status_code=404, detail="Conversation not found")
+
+    conversation = conversations[0]
+    conversation.alerts = deduped_alerts
+    return conversation
 
 
 @v1.get(
@@ -665,7 +775,7 @@ async def get_workspace_token_usage(workspace_name: str) -> v1_models.TokenUsage
         raise HTTPException(status_code=500, detail="Internal server error")
 
     try:
-        prompts_outputs = await dbreader.get_prompts_with_output(ws.id)
+        prompts_outputs = await dbreader.get_prompts_with_output(workspace_id=ws.id)
         ws_token_usage = await v1_processing.parse_workspace_token_usage(prompts_outputs)
         return ws_token_usage
     except Exception:
diff --git a/src/codegate/api/v1_models.py b/src/codegate/api/v1_models.py
index dff26489..6489f96d 100644
--- a/src/codegate/api/v1_models.py
+++ b/src/codegate/api/v1_models.py
@@ -191,6 +191,7 @@ class AlertSummary(pydantic.BaseModel):
     malicious_packages: int
     pii: int
     secrets: int
+    total_alerts: int
 
 
 class PartialQuestionAnswer(pydantic.BaseModel):
@@ -201,7 +202,6 @@ class PartialQuestionAnswer(pydantic.BaseModel):
     partial_questions: PartialQuestions
     answer: Optional[ChatMessage]
     model_token_usage: TokenUsageByModel
-    alerts: List[Alert] = []
 
 
 class Conversation(pydantic.BaseModel):
@@ -215,7 +215,21 @@ class Conversation(pydantic.BaseModel):
     chat_id: str
     conversation_timestamp: datetime.datetime
     token_usage_agg: Optional[TokenUsageAggregate]
-    alerts: List[Alert] = []
+    alerts: Optional[List[Alert]] = []
+
+
+class ConversationSummary(pydantic.BaseModel):
+    """
+    Represents a conversation summary.
+    """
+
+    chat_id: str
+    prompt: ChatMessage
+    alerts_summary: AlertSummary
+    token_usage_agg: Optional[TokenUsageAggregate]
+    provider: Optional[str]
+    type: QuestionType
+    conversation_timestamp: datetime.datetime
 
 
 class AlertConversation(pydantic.BaseModel):
@@ -333,3 +347,10 @@ class PersonaUpdateRequest(pydantic.BaseModel):
 
     new_name: str
     new_description: str
+
+
+class PaginatedMessagesResponse(pydantic.BaseModel):
+    data: List[ConversationSummary]
+    limit: int
+    offset: int
+    total: int
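Reviewer note: a minimal sketch of how the new response models compose, with made-up field values (the real data is assembled from the DB reader in v1.py). This assumes pydantic v2's `model_dump_json` and that `QuestionType` exposes a `chat` member, as the existing module does.

```python
import datetime

from codegate.api import v1_models

now = datetime.datetime.now(datetime.timezone.utc)

# Hypothetical values, purely for illustration.
summary = v1_models.ConversationSummary(
    chat_id="chat-123",
    prompt=v1_models.ChatMessage(
        message="How do I parse JSON in Python?", timestamp=now, message_id="chat-123"
    ),
    alerts_summary=v1_models.AlertSummary(
        malicious_packages=0, pii=1, secrets=0, total_alerts=1
    ),
    token_usage_agg=None,
    provider="openai",
    type=v1_models.QuestionType.chat,
    conversation_timestamp=now,
)

# One page of results, as the endpoint would return it
page = v1_models.PaginatedMessagesResponse(data=[summary], limit=50, offset=0, total=1)
print(page.model_dump_json(indent=2))
```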
diff --git a/src/codegate/api/v1_processing.py b/src/codegate/api/v1_processing.py
index 10f42075..07e5acaa 100644
--- a/src/codegate/api/v1_processing.py
+++ b/src/codegate/api/v1_processing.py
@@ -202,15 +202,10 @@ async def _get_partial_question_answer(
         model=model, token_usage=token_usage, provider_type=provider
     )
 
-    alerts: List[v1_models.Alert] = [
-        v1_models.Alert.from_db_model(db_alert) for db_alert in row.alerts
-    ]
-
     return PartialQuestionAnswer(
         partial_questions=request_message,
         answer=output_message,
         model_token_usage=model_token_usage,
-        alerts=alerts,
     )
 
 
@@ -374,7 +369,7 @@ async def match_conversations(
     for group in grouped_partial_questions:
         questions_answers: List[QuestionAnswer] = []
         token_usage_agg = TokenUsageAggregate(tokens_by_model={}, token_usage=TokenUsage())
-        alerts: List[v1_models.Alert] = []
+
         first_partial_qa = None
         for partial_question in sorted(group, key=lambda x: x.timestamp):
             # Partial questions don't contain the answer, so we need to find the corresponding
@@ -398,8 +393,6 @@ async def match_conversations(
             qa = _get_question_answer_from_partial(selected_partial_qa)
             qa.question.message = parse_question_answer(qa.question.message)
             questions_answers.append(qa)
-            deduped_alerts = await remove_duplicate_alerts(selected_partial_qa.alerts)
-            alerts.extend(deduped_alerts)
             token_usage_agg.add_model_token_usage(selected_partial_qa.model_token_usage)
 
         # if we have a conversation with at least one question and answer
@@ -413,7 +406,6 @@ async def match_conversations(
                 chat_id=first_partial_qa.partial_questions.message_id,
                 conversation_timestamp=first_partial_qa.partial_questions.timestamp,
                 token_usage_agg=token_usage_agg,
-                alerts=alerts,
             )
             for qa in questions_answers:
                 map_q_id_to_conversation[qa.question.message_id] = conversation
diff --git a/src/codegate/config.py b/src/codegate/config.py
index 179ec4d3..ee5cb168 100644
--- a/src/codegate/config.py
+++ b/src/codegate/config.py
@@ -25,6 +25,9 @@
     "llamacpp": "./codegate_volume/models",  # Default LlamaCpp model path
 }
 
+API_DEFAULT_PAGE_SIZE = 50
+API_MAX_PAGE_SIZE = 100
+
 
 @dataclass
 class Config:
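The paging contract these constants support is offset-based: the endpoint computes `offset = (page - 1) * page_size`, with both inputs clamped by FastAPI's `Query(..., ge=1, le=API_MAX_PAGE_SIZE)` validators. A small self-contained sketch of the same arithmetic (the `page_window` helper is hypothetical, not part of this PR):

```python
API_DEFAULT_PAGE_SIZE = 50  # mirrors src/codegate/config.py
API_MAX_PAGE_SIZE = 100


def page_window(page: int, page_size: int = API_DEFAULT_PAGE_SIZE) -> tuple[int, int]:
    """Translate a 1-based page number into an (offset, limit) pair."""
    if page < 1:
        raise ValueError("page must be >= 1")
    if not 1 <= page_size <= API_MAX_PAGE_SIZE:
        raise ValueError(f"page_size must be between 1 and {API_MAX_PAGE_SIZE}")
    return (page - 1) * page_size, page_size


# Page 3 at the default size skips the first 100 records and reads 50.
assert page_window(3) == (100, 50)
```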
diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py
index 3f439aea..915c4251 100644
--- a/src/codegate/db/connection.py
+++ b/src/codegate/db/connection.py
@@ -4,7 +4,7 @@
 import sqlite3
 import uuid
 from pathlib import Path
-from typing import Dict, List, Optional, Type
+from typing import List, Optional, Tuple, Type
 
 import numpy as np
 import sqlite_vec_sl_tmp
@@ -12,20 +12,22 @@
 from alembic import command as alembic_command
 from alembic.config import Config as AlembicConfig
 from pydantic import BaseModel
-from sqlalchemy import CursorResult, TextClause, event, text
+from sqlalchemy import CursorResult, TextClause, bindparam, event, text
 from sqlalchemy.engine import Engine
 from sqlalchemy.exc import IntegrityError, OperationalError
 from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
 from sqlalchemy.orm import sessionmaker
 
+from codegate.config import API_DEFAULT_PAGE_SIZE
 from codegate.db.fim_cache import FimCache
 from codegate.db.models import (
     ActiveWorkspace,
     Alert,
-    GetPromptWithOutputsRow,
+    AlertSummaryRow,
+    AlertTriggerType,
+    GetMessagesRow,
     GetWorkspaceByNameConditions,
     Instance,
-    IntermediatePromptWithOutputUsageAlerts,
     MuxRule,
     Output,
     Persona,
@@ -685,7 +687,11 @@ async def _exec_vec_db_query_to_pydantic(
             conn.close()
         return results
 
-    async def get_prompts_with_output(self, workpace_id: str) -> List[GetPromptWithOutputsRow]:
+    async def get_prompts_with_output(
+        self, workspace_id: Optional[str] = None, prompt_id: Optional[str] = None
+    ) -> List[GetMessagesRow]:
+        if not workspace_id and not prompt_id:
+            raise ValueError("Either workspace_id or prompt_id must be provided.")
         sql = text(
             """
             SELECT
@@ -699,80 +705,183 @@ async def get_prompts_with_output(self, workpace_id: str) -> List[GetPromptWithO
                 o.output_cost
             FROM prompts p
             LEFT JOIN outputs o ON p.id = o.prompt_id
-            WHERE p.workspace_id = :workspace_id
+            WHERE (:workspace_id IS NULL OR p.workspace_id = :workspace_id)
+            AND (:prompt_id IS NULL OR p.id = :prompt_id)
             ORDER BY o.timestamp DESC
             """
         )
-        conditions = {"workspace_id": workpace_id}
+        conditions = {"workspace_id": workspace_id, "prompt_id": prompt_id}
         prompts = await self._exec_select_conditions_to_pydantic(
-            GetPromptWithOutputsRow, sql, conditions, should_raise=True
+            GetMessagesRow, sql, conditions, should_raise=True
         )
         return prompts
 
-    async def get_prompts_with_output_alerts_usage_by_workspace_id(
-        self, workspace_id: str, trigger_category: Optional[str] = None
-    ) -> List[GetPromptWithOutputsRow]:
+    def _build_prompt_query(
+        self,
+        base_query: str,
+        workspace_id: str,
+        filter_by_ids: Optional[List[str]] = None,
+        filter_by_alert_trigger_categories: Optional[List[str]] = None,
+        filter_by_alert_trigger_types: Optional[List[str]] = None,
+        offset: Optional[int] = None,
+        page_size: Optional[int] = None,
+    ) -> Tuple[str, dict]:
        """
-        Get all prompts with their outputs, alerts and token usage by workspace_id.
+        Helper method to construct SQL query and conditions for prompts based on filters.
+
+        Args:
+            base_query: The base SQL query string with a placeholder for filter conditions.
+            workspace_id: The ID of the workspace to fetch prompts from.
+            filter_by_ids: Optional list of prompt IDs to filter by.
+            filter_by_alert_trigger_categories: Optional list of alert categories to filter by.
+            filter_by_alert_trigger_types: Optional list of alert trigger types to filter by.
+            offset: Number of records to skip (for pagination).
+            page_size: Number of records per page.
+
+        Returns:
+            A tuple containing the formatted SQL query string and a dictionary of conditions.
         """
+        conditions = {"workspace_id": workspace_id}
+        filter_conditions = []
 
-        sql = text(
-            """
-            SELECT
-                p.id as prompt_id, p.timestamp as prompt_timestamp, p.provider, p.request, p.type,
-                o.id as output_id, o.output, o.timestamp as output_timestamp, o.input_tokens, o.output_tokens, o.input_cost, o.output_cost,
-                a.id as alert_id, a.code_snippet, a.trigger_string, a.trigger_type, a.trigger_category, a.timestamp as alert_timestamp
+        if filter_by_alert_trigger_categories:
+            filter_conditions.append(
+                """AND (a.trigger_category IN :filter_by_alert_trigger_categories
+                OR a.trigger_category IS NULL)"""
+            )
+            conditions["filter_by_alert_trigger_categories"] = filter_by_alert_trigger_categories
+
+        if filter_by_alert_trigger_types:
+            filter_conditions.append(
+                """AND EXISTS (SELECT 1 FROM alerts a2 WHERE
+                a2.prompt_id = p.id AND a2.trigger_type IN :filter_by_alert_trigger_types)"""
+            )
+            conditions["filter_by_alert_trigger_types"] = filter_by_alert_trigger_types
+
+        if filter_by_ids:
+            filter_conditions.append("AND p.id IN :filter_by_ids")
+            conditions["filter_by_ids"] = filter_by_ids
+
+        if offset is not None:
+            conditions["offset"] = offset
+
+        if page_size is not None:
+            conditions["page_size"] = page_size
+
+        filter_clause = " ".join(filter_conditions)
+        query = base_query.format(filter_conditions=filter_clause)
+
+        return query, conditions
+
+    async def get_prompts(
+        self,
+        workspace_id: str,
+        offset: int = 0,
+        page_size: int = API_DEFAULT_PAGE_SIZE,
+        filter_by_ids: Optional[List[str]] = None,
+        filter_by_alert_trigger_categories: Optional[List[str]] = None,
+        filter_by_alert_trigger_types: Optional[List[str]] = None,
+    ) -> List[Prompt]:
+        """
+        Retrieve prompts with filtering and pagination.
+
+        Args:
+            workspace_id: The ID of the workspace to fetch prompts from
+            offset: Number of records to skip (for pagination)
+            page_size: Number of records per page
+            filter_by_ids: Optional list of prompt IDs to filter by
+            filter_by_alert_trigger_categories: Optional list of alert categories to filter by
+            filter_by_alert_trigger_types: Optional list of alert trigger types to filter by
+
+        Returns:
+            A list of Prompt objects containing the prompt details
+        """
+        # Build base query
+        base_query = """
+            SELECT DISTINCT p.id, p.timestamp, p.provider, p.request, p.type,
+            p.workspace_id FROM prompts p
+            LEFT JOIN alerts a ON p.id = a.prompt_id
+            WHERE p.workspace_id = :workspace_id
+            {filter_conditions}
+            ORDER BY p.timestamp DESC
+            LIMIT :page_size OFFSET :offset
+        """
+
+        query, conditions = self._build_prompt_query(
+            base_query,
+            workspace_id,
+            filter_by_ids,
+            filter_by_alert_trigger_categories,
+            filter_by_alert_trigger_types,
+            offset,
+            page_size,
+        )
+        sql = text(query)
+
+        # Bind optional params
+        if filter_by_alert_trigger_categories:
+            sql = sql.bindparams(bindparam("filter_by_alert_trigger_categories", expanding=True))
+        if filter_by_alert_trigger_types:
+            sql = sql.bindparams(bindparam("filter_by_alert_trigger_types", expanding=True))
+        if filter_by_ids:
+            sql = sql.bindparams(bindparam("filter_by_ids", expanding=True))
+
+        # Execute query
+        rows = await self._exec_select_conditions_to_pydantic(
+            Prompt, sql, conditions, should_raise=True
+        )
+        return rows
+
+    async def get_total_messages_count_by_workspace_id(
+        self,
+        workspace_id: str,
+        filter_by_ids: Optional[List[str]] = None,
+        filter_by_alert_trigger_categories: Optional[List[str]] = None,
+        filter_by_alert_trigger_types: Optional[List[str]] = None,
+    ) -> int:
+        """
+        Get the total count of unique messages for a given workspace_id,
+        applying the same alert filters used for pagination.
+ """ + base_query = """ + SELECT COUNT(DISTINCT p.id) FROM prompts p - LEFT JOIN outputs o ON p.id = o.prompt_id LEFT JOIN alerts a ON p.id = a.prompt_id WHERE p.workspace_id = :workspace_id - AND (a.trigger_category = :trigger_category OR a.trigger_category is NULL) - ORDER BY o.timestamp DESC, a.timestamp DESC - """ # noqa: E501 - ) - # If trigger category is None we want to get all alerts - trigger_category = trigger_category if trigger_category else "%" - conditions = {"workspace_id": workspace_id, "trigger_category": trigger_category} - rows: List[IntermediatePromptWithOutputUsageAlerts] = ( - await self._exec_select_conditions_to_pydantic( - IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True - ) + {filter_conditions} + """ + + query, conditions = self._build_prompt_query( + base_query, + workspace_id, + filter_by_ids, + filter_by_alert_trigger_categories, + filter_by_alert_trigger_types, ) - prompts_dict: Dict[str, GetPromptWithOutputsRow] = {} - for row in rows: - prompt_id = row.prompt_id - if prompt_id not in prompts_dict: - prompts_dict[prompt_id] = GetPromptWithOutputsRow( - id=row.prompt_id, - timestamp=row.prompt_timestamp, - provider=row.provider, - request=row.request, - type=row.type, - output_id=row.output_id, - output=row.output, - output_timestamp=row.output_timestamp, - input_tokens=row.input_tokens, - output_tokens=row.output_tokens, - input_cost=row.input_cost, - output_cost=row.output_cost, - alerts=[], - ) - if row.alert_id: - alert = Alert( - id=row.alert_id, - prompt_id=row.prompt_id, - code_snippet=row.code_snippet, - trigger_string=row.trigger_string, - trigger_type=row.trigger_type, - trigger_category=row.trigger_category, - timestamp=row.alert_timestamp, - ) - prompts_dict[prompt_id].alerts.append(alert) + sql = text(query) - return list(prompts_dict.values()) + # Bind optional params + if filter_by_alert_trigger_categories: + sql = sql.bindparams(bindparam("filter_by_alert_trigger_categories", expanding=True)) + if filter_by_alert_trigger_types: + sql = sql.bindparams(bindparam("filter_by_alert_trigger_types", expanding=True)) + if filter_by_ids: + sql = sql.bindparams(bindparam("filter_by_ids", expanding=True)) - async def get_alerts_by_workspace( - self, workspace_id: str, trigger_category: Optional[str] = None + async with self._async_db_engine.begin() as conn: + try: + result = await conn.execute(sql, conditions) + count = result.scalar() # Fetches the integer result directly + return count or 0 # Ensure it returns an integer + except Exception as e: + logger.error(f"Failed to fetch message count. 
Error: {e}") + return 0 # Return 0 in case of failure + + async def get_alerts_by_workspace_or_prompt_id( + self, + workspace_id: str, + prompt_id: Optional[str] = None, + trigger_category: Optional[str] = None, ) -> List[Alert]: sql = text( """ @@ -791,6 +900,10 @@ async def get_alerts_by_workspace( ) conditions = {"workspace_id": workspace_id} + if prompt_id: + sql = text(sql.text + " AND a.prompt_id = :prompt_id") + conditions["prompt_id"] = prompt_id + if trigger_category: sql = text(sql.text + " AND a.trigger_category = :trigger_category") conditions["trigger_category"] = trigger_category @@ -802,37 +915,53 @@ async def get_alerts_by_workspace( ) return prompts - async def get_alerts_summary_by_workspace(self, workspace_id: str) -> dict: - """Get aggregated alert summary counts for a given workspace_id.""" + async def get_alerts_summary( + self, workspace_id: str = None, prompt_id: str = None + ) -> AlertSummaryRow: + """Get aggregated alert summary counts for a given workspace_id or prompt id.""" + if not workspace_id and not prompt_id: + raise ValueError("Either workspace_id or prompt_id must be provided.") + + filters = [] + conditions = {} + + if workspace_id: + filters.append("p.workspace_id = :workspace_id") + conditions["workspace_id"] = workspace_id + + if prompt_id: + filters.append("a.prompt_id = :prompt_id") + conditions["prompt_id"] = prompt_id + + filter_clause = " AND ".join(filters) + sql = text( - """ + f""" SELECT - COUNT(*) AS total_alerts, - SUM(CASE WHEN a.trigger_type = 'codegate-secrets' THEN 1 ELSE 0 END) - AS codegate_secrets_count, - SUM(CASE WHEN a.trigger_type = 'codegate-context-retriever' THEN 1 ELSE 0 END) - AS codegate_context_retriever_count, - SUM(CASE WHEN a.trigger_type = 'codegate-pii' THEN 1 ELSE 0 END) - AS codegate_pii_count + COUNT(*) AS total_alerts, + SUM(CASE WHEN a.trigger_type = '{AlertTriggerType.CODEGATE_SECRETS.value}' THEN 1 ELSE 0 END) + AS codegate_secrets_count, + SUM(CASE WHEN a.trigger_type = '{AlertTriggerType.CODEGATE_CONTEXT_RETRIEVER.value}' THEN 1 ELSE 0 END) + AS codegate_context_retriever_count, + SUM(CASE WHEN a.trigger_type = '{AlertTriggerType.CODEGATE_PII.value}' THEN 1 ELSE 0 END) + AS codegate_pii_count FROM alerts a INNER JOIN prompts p ON p.id = a.prompt_id - WHERE p.workspace_id = :workspace_id - """ + WHERE {filter_clause} + """ # noqa: E501 # nosec ) - conditions = {"workspace_id": workspace_id} - async with self._async_db_engine.begin() as conn: result = await conn.execute(sql, conditions) row = result.fetchone() # Return a dictionary with counts (handling None values safely) - return { - "codegate_secrets_count": row.codegate_secrets_count or 0 if row else 0, - "codegate_context_retriever_count": ( - row.codegate_context_retriever_count or 0 if row else 0 - ), - "codegate_pii_count": row.codegate_pii_count or 0 if row else 0, - } + + return AlertSummaryRow( + total_alerts=row.total_alerts or 0 if row else 0, + total_secrets_count=row.codegate_secrets_count or 0 if row else 0, + total_packages_count=row.codegate_context_retriever_count or 0 if row else 0, + total_pii_count=row.codegate_pii_count or 0 if row else 0, + ) async def get_workspaces(self) -> List[WorkspaceWithSessionInfo]: sql = text( diff --git a/src/codegate/db/models.py b/src/codegate/db/models.py index 07c4c8ed..7f8ef434 100644 --- a/src/codegate/db/models.py +++ b/src/codegate/db/models.py @@ -115,6 +115,21 @@ class WorkspaceRow(BaseModel): custom_instructions: Optional[str] +class AlertSummaryRow(BaseModel): + """An alert summary row entry""" + + 
diff --git a/src/codegate/db/models.py b/src/codegate/db/models.py
index 07c4c8ed..7f8ef434 100644
--- a/src/codegate/db/models.py
+++ b/src/codegate/db/models.py
@@ -115,6 +115,21 @@ class WorkspaceRow(BaseModel):
     custom_instructions: Optional[str]
 
 
+class AlertSummaryRow(BaseModel):
+    """An alert summary row entry"""
+
+    total_alerts: int
+    total_secrets_count: int
+    total_packages_count: int
+    total_pii_count: int
+
+
+class AlertTriggerType(str, Enum):
+    CODEGATE_PII = "codegate-pii"
+    CODEGATE_CONTEXT_RETRIEVER = "codegate-context-retriever"
+    CODEGATE_SECRETS = "codegate-secrets"
+
+
 class GetWorkspaceByNameConditions(BaseModel):
     name: WorkspaceNameStr
 
@@ -322,3 +337,18 @@ class PersonaDistance(Persona):
     """
 
     distance: float
+
+
+class GetMessagesRow(BaseModel):
+    id: Any
+    timestamp: Any
+    provider: Optional[Any]
+    request: Any
+    type: Any
+    output_id: Optional[Any]
+    output: Optional[Any]
+    output_timestamp: Optional[Any]
+    input_tokens: Optional[int]
+    output_tokens: Optional[int]
+    input_cost: Optional[float]
+    output_cost: Optional[float]
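End to end, the new endpoints can be exercised roughly as below. The base URL and `/api/v1` prefix are assumptions about a local CodeGate instance, so adjust them to your deployment; `httpx` is used only as a convenient client.

```python
import httpx

BASE = "http://localhost:8989/api/v1"  # assumed local instance

params = {
    "page": 1,
    "page_size": 25,
    # Repeated query parameters arrive server-side as a list
    "filter_by_alert_trigger_types": ["codegate-secrets", "codegate-pii"],
}
resp = httpx.get(f"{BASE}/workspaces/default/messages", params=params)
resp.raise_for_status()
body = resp.json()
print(f"{body['total']} conversations total, {len(body['data'])} on this page")

# Drill into a single conversation with the companion endpoint
if body["data"]:
    chat_id = body["data"][0]["chat_id"]
    detail = httpx.get(f"{BASE}/workspaces/default/messages/{chat_id}")
    detail.raise_for_status()
    print(detail.json()["chat_id"])
```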