diff --git a/src/codegate/api/dashboard/dashboard.py b/src/codegate/api/dashboard/dashboard.py deleted file mode 100644 index fa558a91..00000000 --- a/src/codegate/api/dashboard/dashboard.py +++ /dev/null @@ -1,135 +0,0 @@ -import asyncio -from typing import AsyncGenerator, List, Optional - -import requests -import structlog -from fastapi import APIRouter, Depends, HTTPException -from fastapi.responses import StreamingResponse -from fastapi.routing import APIRoute - -from codegate import __version__ -from codegate.api.dashboard.post_processing import ( - parse_get_alert_conversation, - parse_messages_in_conversations, -) -from codegate.api.dashboard.request_models import AlertConversation, Conversation -from codegate.db.connection import DbReader, alert_queue -from codegate.workspaces import crud - -logger = structlog.get_logger("codegate") - -dashboard_router = APIRouter() -db_reader = None - -wscrud = crud.WorkspaceCrud() - - -def uniq_name(route: APIRoute): - return f"v1_{route.name}" - - -def get_db_reader(): - global db_reader - if db_reader is None: - db_reader = DbReader() - return db_reader - - -def fetch_latest_version() -> str: - url = "https://api.github.com/repos/stacklok/codegate/releases/latest" - headers = {"Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28"} - response = requests.get(url, headers=headers, timeout=5) - response.raise_for_status() - data = response.json() - return data.get("tag_name", "unknown") - - -@dashboard_router.get( - "/dashboard/messages", tags=["Dashboard"], generate_unique_id_function=uniq_name -) -def get_messages(db_reader: DbReader = Depends(get_db_reader)) -> List[Conversation]: - """ - Get all the messages from the database and return them as a list of conversations. - """ - try: - active_ws = asyncio.run(wscrud.get_active_workspace()) - prompts_outputs = asyncio.run(db_reader.get_prompts_with_output(active_ws.id)) - - return asyncio.run(parse_messages_in_conversations(prompts_outputs)) - except Exception as e: - logger.error(f"Error getting messages: {str(e)}") - raise HTTPException(status_code=500, detail="Internal server error") - - -@dashboard_router.get( - "/dashboard/alerts", tags=["Dashboard"], generate_unique_id_function=uniq_name -) -def get_alerts(db_reader: DbReader = Depends(get_db_reader)) -> List[Optional[AlertConversation]]: - """ - Get all the messages from the database and return them as a list of conversations. - """ - try: - active_ws = asyncio.run(wscrud.get_active_workspace()) - alerts_prompt_output = asyncio.run( - db_reader.get_alerts_with_prompt_and_output(active_ws.id) - ) - return asyncio.run(parse_get_alert_conversation(alerts_prompt_output)) - except Exception as e: - logger.error(f"Error getting alerts: {str(e)}") - raise HTTPException(status_code=500, detail="Internal server error") - - -async def generate_sse_events() -> AsyncGenerator[str, None]: - """ - SSE generator from queue - """ - while True: - message = await alert_queue.get() - yield f"data: {message}\n\n" - - -@dashboard_router.get( - "/dashboard/alerts_notification", tags=["Dashboard"], generate_unique_id_function=uniq_name -) -async def stream_sse(): - """ - Send alerts event - """ - return StreamingResponse(generate_sse_events(), media_type="text/event-stream") - - -@dashboard_router.get( - "/dashboard/version", tags=["Dashboard"], generate_unique_id_function=uniq_name -) -def version_check(): - try: - latest_version = fetch_latest_version() - - # normalize the versions as github will return them with a 'v' prefix - current_version = __version__.lstrip("v") - latest_version_stripped = latest_version.lstrip("v") - - is_latest: bool = latest_version_stripped == current_version - - return { - "current_version": current_version, - "latest_version": latest_version_stripped, - "is_latest": is_latest, - "error": None, - } - except requests.RequestException as e: - logger.error(f"RequestException: {str(e)}") - return { - "current_version": __version__, - "latest_version": "unknown", - "is_latest": None, - "error": "An error occurred while fetching the latest version", - } - except Exception as e: - logger.error(f"Unexpected error: {str(e)}") - return { - "current_version": __version__, - "latest_version": "unknown", - "is_latest": None, - "error": "An unexpected error occurred", - } diff --git a/src/codegate/api/dashboard/request_models.py b/src/codegate/api/dashboard/request_models.py deleted file mode 100644 index d36d9391..00000000 --- a/src/codegate/api/dashboard/request_models.py +++ /dev/null @@ -1,72 +0,0 @@ -import datetime -from typing import List, Optional, Union - -from pydantic import BaseModel - -from codegate.pipeline.base import CodeSnippet - - -class ChatMessage(BaseModel): - """ - Represents a chat message. - """ - - message: str - timestamp: datetime.datetime - message_id: str - - -class QuestionAnswer(BaseModel): - """ - Represents a question and answer pair. - """ - - question: ChatMessage - answer: Optional[ChatMessage] - - -class PartialQuestions(BaseModel): - """ - Represents all user messages obtained from a DB row. - """ - - messages: List[str] - timestamp: datetime.datetime - message_id: str - provider: Optional[str] - type: str - - -class PartialQuestionAnswer(BaseModel): - """ - Represents a partial conversation. - """ - - partial_questions: PartialQuestions - answer: Optional[ChatMessage] - - -class Conversation(BaseModel): - """ - Represents a conversation. - """ - - question_answers: List[QuestionAnswer] - provider: Optional[str] - type: str - chat_id: str - conversation_timestamp: datetime.datetime - - -class AlertConversation(BaseModel): - """ - Represents an alert with it's respective conversation. - """ - - conversation: Conversation - alert_id: str - code_snippet: Optional[CodeSnippet] - trigger_string: Optional[Union[str, dict]] - trigger_type: str - trigger_category: Optional[str] - timestamp: datetime.datetime diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index ff02dfb1..1e519698 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -1,17 +1,20 @@ from typing import List, Optional +import requests +import structlog from fastapi import APIRouter, HTTPException, Response +from fastapi.responses import StreamingResponse from fastapi.routing import APIRoute from pydantic import ValidationError -from codegate.api import v1_models -from codegate.api.dashboard import dashboard -from codegate.api.dashboard.request_models import AlertConversation, Conversation +from codegate import __version__ +from codegate.api import v1_models, v1_processing from codegate.db.connection import AlreadyExistsError, DbReader from codegate.workspaces import crud +logger = structlog.get_logger("codegate") + v1 = APIRouter() -v1.include_router(dashboard.dashboard_router) wscrud = crud.WorkspaceCrud() # This is a singleton object @@ -192,7 +195,7 @@ async def hard_delete_workspace(workspace_name: str): tags=["Workspaces"], generate_unique_id_function=uniq_name, ) -async def get_workspace_alerts(workspace_name: str) -> List[Optional[AlertConversation]]: +async def get_workspace_alerts(workspace_name: str) -> List[Optional[v1_models.AlertConversation]]: """Get alerts for a workspace.""" try: ws = await wscrud.get_workspace_by_name(workspace_name) @@ -203,7 +206,7 @@ async def get_workspace_alerts(workspace_name: str) -> List[Optional[AlertConver try: alerts = await dbreader.get_alerts_with_prompt_and_output(ws.id) - return await dashboard.parse_get_alert_conversation(alerts) + return await v1_processing.parse_get_alert_conversation(alerts) except Exception: raise HTTPException(status_code=500, detail="Internal server error") @@ -213,7 +216,7 @@ async def get_workspace_alerts(workspace_name: str) -> List[Optional[AlertConver tags=["Workspaces"], generate_unique_id_function=uniq_name, ) -async def get_workspace_messages(workspace_name: str) -> List[Conversation]: +async def get_workspace_messages(workspace_name: str) -> List[v1_models.Conversation]: """Get messages for a workspace.""" try: ws = await wscrud.get_workspace_by_name(workspace_name) @@ -224,7 +227,7 @@ async def get_workspace_messages(workspace_name: str) -> List[Conversation]: try: prompts_outputs = await dbreader.get_prompts_with_output(ws.id) - return await dashboard.parse_messages_in_conversations(prompts_outputs) + return await v1_processing.parse_messages_in_conversations(prompts_outputs) except Exception: raise HTTPException(status_code=500, detail="Internal server error") @@ -285,3 +288,46 @@ async def delete_workspace_custom_instructions(workspace_name: str): raise HTTPException(status_code=500, detail="Internal server error") return Response(status_code=204) + + +@v1.get("/alerts_notification", tags=["Dashboard"], generate_unique_id_function=uniq_name) +async def stream_sse(): + """ + Send alerts event + """ + return StreamingResponse(v1_processing.generate_sse_events(), media_type="text/event-stream") + + +@v1.get("/version", tags=["Dashboard"], generate_unique_id_function=uniq_name) +def version_check(): + try: + latest_version = v1_processing.fetch_latest_version() + + # normalize the versions as github will return them with a 'v' prefix + current_version = __version__.lstrip("v") + latest_version_stripped = latest_version.lstrip("v") + + is_latest: bool = latest_version_stripped == current_version + + return { + "current_version": current_version, + "latest_version": latest_version_stripped, + "is_latest": is_latest, + "error": None, + } + except requests.RequestException as e: + logger.error(f"RequestException: {str(e)}") + return { + "current_version": __version__, + "latest_version": "unknown", + "is_latest": None, + "error": "An error occurred while fetching the latest version", + } + except Exception as e: + logger.error(f"Unexpected error: {str(e)}") + return { + "current_version": __version__, + "latest_version": "unknown", + "is_latest": None, + "error": "An unexpected error occurred", + } diff --git a/src/codegate/api/v1_models.py b/src/codegate/api/v1_models.py index ee86c208..9b23f74e 100644 --- a/src/codegate/api/v1_models.py +++ b/src/codegate/api/v1_models.py @@ -1,8 +1,10 @@ -from typing import Any, List, Optional +import datetime +from typing import Any, List, Optional, Union import pydantic from codegate.db import models as db_models +from codegate.pipeline.base import CodeSnippet class Workspace(pydantic.BaseModel): @@ -64,3 +66,69 @@ class CreateOrRenameWorkspaceRequest(pydantic.BaseModel): class ActivateWorkspaceRequest(pydantic.BaseModel): name: str + + +class ChatMessage(pydantic.BaseModel): + """ + Represents a chat message. + """ + + message: str + timestamp: datetime.datetime + message_id: str + + +class QuestionAnswer(pydantic.BaseModel): + """ + Represents a question and answer pair. + """ + + question: ChatMessage + answer: Optional[ChatMessage] + + +class PartialQuestions(pydantic.BaseModel): + """ + Represents all user messages obtained from a DB row. + """ + + messages: List[str] + timestamp: datetime.datetime + message_id: str + provider: Optional[str] + type: str + + +class PartialQuestionAnswer(pydantic.BaseModel): + """ + Represents a partial conversation. + """ + + partial_questions: PartialQuestions + answer: Optional[ChatMessage] + + +class Conversation(pydantic.BaseModel): + """ + Represents a conversation. + """ + + question_answers: List[QuestionAnswer] + provider: Optional[str] + type: str + chat_id: str + conversation_timestamp: datetime.datetime + + +class AlertConversation(pydantic.BaseModel): + """ + Represents an alert with it's respective conversation. + """ + + conversation: Conversation + alert_id: str + code_snippet: Optional[CodeSnippet] + trigger_string: Optional[Union[str, dict]] + trigger_type: str + trigger_category: Optional[str] + timestamp: datetime.datetime diff --git a/src/codegate/api/dashboard/post_processing.py b/src/codegate/api/v1_processing.py similarity index 94% rename from src/codegate/api/dashboard/post_processing.py rename to src/codegate/api/v1_processing.py index 1e4135d2..906584b2 100644 --- a/src/codegate/api/dashboard/post_processing.py +++ b/src/codegate/api/v1_processing.py @@ -2,11 +2,12 @@ import json import re from collections import defaultdict -from typing import List, Optional, Union +from typing import AsyncGenerator, List, Optional, Union +import requests import structlog -from codegate.api.dashboard.request_models import ( +from codegate.api.v1_models import ( AlertConversation, ChatMessage, Conversation, @@ -14,6 +15,7 @@ PartialQuestions, QuestionAnswer, ) +from codegate.db.connection import alert_queue from codegate.db.models import GetAlertsWithPromptAndOutputRow, GetPromptWithOutputsRow logger = structlog.get_logger("codegate") @@ -27,6 +29,24 @@ ] +def fetch_latest_version() -> str: + url = "https://api.github.com/repos/stacklok/codegate/releases/latest" + headers = {"Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28"} + response = requests.get(url, headers=headers, timeout=5) + response.raise_for_status() + data = response.json() + return data.get("tag_name", "unknown") + + +async def generate_sse_events() -> AsyncGenerator[str, None]: + """ + SSE generator from queue + """ + while True: + message = await alert_queue.get() + yield f"data: {message}\n\n" + + async def _is_system_prompt(message: str) -> bool: """ Check if the message is a system prompt. diff --git a/src/codegate/providers/openai/provider.py b/src/codegate/providers/openai/provider.py index af42b551..8a00c68c 100644 --- a/src/codegate/providers/openai/provider.py +++ b/src/codegate/providers/openai/provider.py @@ -1,9 +1,8 @@ import json -from fastapi.responses import JSONResponse -import httpx import structlog from fastapi import Header, HTTPException, Request +from fastapi.responses import JSONResponse from codegate.config import Config from codegate.pipeline.factory import PipelineFactory diff --git a/tests/dashboard/test_post_processing.py b/tests/api/test_v1_processing.py similarity index 98% rename from tests/dashboard/test_post_processing.py rename to tests/api/test_v1_processing.py index d6359efb..20598c67 100644 --- a/tests/dashboard/test_post_processing.py +++ b/tests/api/test_v1_processing.py @@ -4,16 +4,16 @@ import pytest -from codegate.api.dashboard.post_processing import ( +from codegate.api.v1_models import ( + PartialQuestions, +) +from codegate.api.v1_processing import ( _get_question_answer, _group_partial_messages, _is_system_prompt, parse_output, parse_request, ) -from codegate.api.dashboard.request_models import ( - PartialQuestions, -) from codegate.db.models import GetPromptWithOutputsRow @@ -162,10 +162,10 @@ async def test_parse_output(output_dict, expected_str): ) async def test_get_question_answer(request_msg_list, output_msg_str, row): with patch( - "codegate.api.dashboard.post_processing.parse_request", new_callable=AsyncMock + "codegate.api.v1_processing.parse_request", new_callable=AsyncMock ) as mock_parse_request: with patch( - "codegate.api.dashboard.post_processing.parse_output", new_callable=AsyncMock + "codegate.api.v1_processing.parse_output", new_callable=AsyncMock ) as mock_parse_output: # Set return values for the mocks mock_parse_request.return_value = request_msg_list diff --git a/tests/test_server.py b/tests/test_server.py index 8e07c0ee..f7b7a12f 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -82,10 +82,10 @@ def test_health_check(test_client: TestClient) -> None: assert response.json() == {"status": "healthy"} -@patch("codegate.api.dashboard.dashboard.fetch_latest_version", return_value="foo") +@patch("codegate.api.v1_processing.fetch_latest_version", return_value="foo") def test_version_endpoint(mock_fetch_latest_version, test_client: TestClient) -> None: """Test the version endpoint.""" - response = test_client.get("/api/v1/dashboard/version") + response = test_client.get("/api/v1/version") assert response.status_code == 200 response_data = response.json() @@ -135,13 +135,13 @@ def test_pipeline_initialization(mock_pipeline_factory) -> None: assert hasattr(provider, "output_pipeline_processor") -def test_dashboard_routes(mock_pipeline_factory) -> None: +def test_workspaces_routes(mock_pipeline_factory) -> None: """Test that dashboard routes are included.""" app = init_app(mock_pipeline_factory) routes = [route.path for route in app.routes] # Verify dashboard endpoints are included - dashboard_routes = [route for route in routes if route.startswith("/api/v1/dashboard")] + dashboard_routes = [route for route in routes if route.startswith("/api/v1/workspaces")] assert len(dashboard_routes) > 0