Skip to content

Commit 6f1c412

Browse files
committed
Use exceptions for handling workspace add error
This stops using the boolean and instead will raise exceptions if there's an issue adding a workspace. This will help us differentiate if the operation failed due to a name already being taken, or the name having invalid characters. Signed-off-by: Juan Antonio Osorio <[email protected]>
1 parent b68186c commit 6f1c412

File tree

5 files changed

+60
-38
lines changed

5 files changed

+60
-38
lines changed

src/codegate/api/v1.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from fastapi import APIRouter, Response
22
from fastapi.exceptions import HTTPException
33
from fastapi.routing import APIRoute
4+
from pydantic import ValidationError
45

56
from codegate.api import v1_models
7+
from codegate.db.connection import AlreadyExistsError
68
from codegate.workspaces.crud import WorkspaceCrud
79

810
v1 = APIRouter()
@@ -52,13 +54,17 @@ async def activate_workspace(request: v1_models.ActivateWorkspaceRequest, status
5254
async def create_workspace(request: v1_models.CreateWorkspaceRequest):
5355
"""Create a new workspace."""
5456
# Input validation is done in the model
55-
created = await wscrud.add_workspace(request.name)
56-
57-
# TODO: refactor to use a more specific exception
58-
if not created:
59-
raise HTTPException(status_code=400, detail="Failed to create workspace")
60-
61-
return v1_models.Workspace(name=request.name)
57+
try:
58+
created = await wscrud.add_workspace(request.name)
59+
except AlreadyExistsError:
60+
raise HTTPException(status_code=409, detail="Workspace already exists")
61+
except ValidationError as e:
62+
raise HTTPException(status_code=400, detail=str(e))
63+
except Exception:
64+
raise HTTPException(status_code=500, detail="Internal server error")
65+
66+
if created:
67+
return v1_models.Workspace(name=created.name)
6268

6369

6470
@v1.delete(

src/codegate/dashboard/dashboard.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
import structlog
77
from fastapi import APIRouter, Depends, FastAPI
88
from fastapi.responses import StreamingResponse
9-
from codegate import __version__
109

10+
from codegate import __version__
1111
from codegate.dashboard.post_processing import (
1212
parse_get_alert_conversation,
1313
parse_messages_in_conversations,
@@ -82,7 +82,7 @@ def version_check():
8282
latest_version_stripped = latest_version.lstrip('v')
8383

8484
is_latest: bool = latest_version_stripped == current_version
85-
85+
8686
return {
8787
"current_version": current_version,
8888
"latest_version": latest_version_stripped,

src/codegate/db/connection.py

+26-18
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
import structlog
88
from alembic import command as alembic_command
99
from alembic.config import Config as AlembicConfig
10-
from pydantic import BaseModel, ValidationError
10+
from pydantic import BaseModel
1111
from sqlalchemy import CursorResult, TextClause, text
12-
from sqlalchemy.exc import OperationalError
12+
from sqlalchemy.exc import IntegrityError, OperationalError
1313
from sqlalchemy.ext.asyncio import create_async_engine
1414

1515
from codegate.db.fim_cache import FimCache
@@ -30,6 +30,8 @@
3030
alert_queue = asyncio.Queue()
3131
fim_cache = FimCache()
3232

33+
class AlreadyExistsError(Exception):
34+
pass
3335

3436
class DbCodeGate:
3537
_instance = None
@@ -70,11 +72,11 @@ def __init__(self, sqlite_path: Optional[str] = None):
7072
super().__init__(sqlite_path)
7173

7274
async def _execute_update_pydantic_model(
73-
self, model: BaseModel, sql_command: TextClause
75+
self, model: BaseModel, sql_command: TextClause, should_raise: bool = False
7476
) -> Optional[BaseModel]:
7577
"""Execute an update or insert command for a Pydantic model."""
76-
async with self._async_db_engine.begin() as conn:
77-
try:
78+
try:
79+
async with self._async_db_engine.begin() as conn:
7880
result = await conn.execute(sql_command, model.model_dump())
7981
row = result.first()
8082
if row is None:
@@ -83,9 +85,11 @@ async def _execute_update_pydantic_model(
8385
# Get the class of the Pydantic object to create a new object
8486
model_class = model.__class__
8587
return model_class(**row._asdict())
86-
except Exception as e:
87-
logger.error(f"Failed to update model: {model}.", error=str(e))
88-
return None
88+
except Exception as e:
89+
logger.error(f"Failed to update model: {model}.", error=str(e))
90+
if should_raise:
91+
raise e
92+
return None
8993

9094
async def record_request(self, prompt_params: Optional[Prompt] = None) -> Optional[Prompt]:
9195
if prompt_params is None:
@@ -243,11 +247,14 @@ async def record_context(self, context: Optional[PipelineContext]) -> None:
243247
logger.error(f"Failed to record context: {context}.", error=str(e))
244248

245249
async def add_workspace(self, workspace_name: str) -> Optional[Workspace]:
246-
try:
247-
workspace = Workspace(id=str(uuid.uuid4()), name=workspace_name)
248-
except ValidationError as e:
249-
logger.error(f"Failed to create workspace with name: {workspace_name}: {str(e)}")
250-
return None
250+
"""Add a new workspace to the DB.
251+
252+
This handles validation and insertion of a new workspace.
253+
254+
It may raise a ValidationError if the workspace name is invalid.
255+
or a AlreadyExistsError if the workspace already exists.
256+
"""
257+
workspace = Workspace(id=str(uuid.uuid4()), name=workspace_name)
251258

252259
sql = text(
253260
"""
@@ -256,12 +263,13 @@ async def add_workspace(self, workspace_name: str) -> Optional[Workspace]:
256263
RETURNING *
257264
"""
258265
)
259-
try:
260-
added_workspace = await self._execute_update_pydantic_model(workspace, sql)
261-
except Exception as e:
262-
logger.error(f"Failed to add workspace: {workspace_name}.", error=str(e))
263-
return None
264266

267+
try:
268+
added_workspace = await self._execute_update_pydantic_model(
269+
workspace, sql, should_raise=True)
270+
except IntegrityError as e:
271+
logger.debug(f"Exception type: {type(e)}")
272+
raise AlreadyExistsError(f"Workspace {workspace_name} already exists.")
265273
return added_workspace
266274

267275
async def update_session(self, session: Session) -> Optional[Session]:

src/codegate/pipeline/cli/commands.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from abc import ABC, abstractmethod
22
from typing import List
33

4+
from pydantic import ValidationError
5+
46
from codegate import __version__
7+
from codegate.db.connection import AlreadyExistsError
58
from codegate.workspaces.crud import WorkspaceCrud
69

710

@@ -69,13 +72,15 @@ async def _add_workspace(self, args: List[str]) -> str:
6972
if not new_workspace_name:
7073
return "Please provide a name. Use `codegate workspace add your_workspace_name`"
7174

72-
workspace_created = await self.workspace_crud.add_workspace(new_workspace_name)
73-
if not workspace_created:
74-
return (
75-
"Something went wrong. Workspace could not be added.\n"
76-
"1. Check if the name is alphanumeric and only contains dashes, and underscores.\n"
77-
"2. Check if the workspace already exists."
78-
)
75+
try:
76+
_ = await self.workspace_crud.add_workspace(new_workspace_name)
77+
except ValidationError as e:
78+
return f"Invalid workspace name: {e}"
79+
except AlreadyExistsError:
80+
return f"Workspace **{new_workspace_name}** already exists"
81+
except Exception:
82+
return "An error occurred while adding the workspace"
83+
7984
return f"Workspace **{new_workspace_name}** has been added"
8085

8186
async def _activate_workspace(self, args: List[str]) -> str:

src/codegate/workspaces/crud.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
import datetime
2-
from typing import Optional, Tuple, List
2+
from typing import List, Optional, Tuple
33

44
from codegate.db.connection import DbReader, DbRecorder
5-
from codegate.db.models import Session, Workspace, WorkspaceActive, ActiveWorkspace
5+
from codegate.db.models import ActiveWorkspace, Session, Workspace, WorkspaceActive
66

77

8+
class WorkspaceCrudError(Exception):
9+
pass
10+
811
class WorkspaceCrud:
912

1013
def __init__(self):
1114
self._db_reader = DbReader()
1215

13-
async def add_workspace(self, new_workspace_name: str) -> bool:
16+
async def add_workspace(self, new_workspace_name: str) -> Workspace:
1417
"""
1518
Add a workspace
1619
@@ -19,7 +22,7 @@ async def add_workspace(self, new_workspace_name: str) -> bool:
1922
"""
2023
db_recorder = DbRecorder()
2124
workspace_created = await db_recorder.add_workspace(new_workspace_name)
22-
return bool(workspace_created)
25+
return workspace_created
2326

2427
async def get_workspaces(self)-> List[WorkspaceActive]:
2528
"""

0 commit comments

Comments
 (0)