7
7
import structlog
8
8
from alembic import command as alembic_command
9
9
from alembic .config import Config as AlembicConfig
10
- from pydantic import BaseModel , ValidationError
10
+ from pydantic import BaseModel
11
11
from sqlalchemy import CursorResult , TextClause , text
12
- from sqlalchemy .exc import OperationalError
12
+ from sqlalchemy .exc import IntegrityError , OperationalError
13
13
from sqlalchemy .ext .asyncio import create_async_engine
14
14
15
15
from codegate .db .fim_cache import FimCache
30
30
alert_queue = asyncio .Queue ()
31
31
fim_cache = FimCache ()
32
32
33
+ class AlreadyExistsError (Exception ):
34
+ pass
33
35
34
36
class DbCodeGate :
35
37
_instance = None
@@ -70,11 +72,11 @@ def __init__(self, sqlite_path: Optional[str] = None):
70
72
super ().__init__ (sqlite_path )
71
73
72
74
async def _execute_update_pydantic_model (
73
- self , model : BaseModel , sql_command : TextClause
75
+ self , model : BaseModel , sql_command : TextClause , should_raise : bool = False
74
76
) -> Optional [BaseModel ]:
75
77
"""Execute an update or insert command for a Pydantic model."""
76
- async with self . _async_db_engine . begin () as conn :
77
- try :
78
+ try :
79
+ async with self . _async_db_engine . begin () as conn :
78
80
result = await conn .execute (sql_command , model .model_dump ())
79
81
row = result .first ()
80
82
if row is None :
@@ -83,9 +85,11 @@ async def _execute_update_pydantic_model(
83
85
# Get the class of the Pydantic object to create a new object
84
86
model_class = model .__class__
85
87
return model_class (** row ._asdict ())
86
- except Exception as e :
87
- logger .error (f"Failed to update model: { model } ." , error = str (e ))
88
- return None
88
+ except Exception as e :
89
+ logger .error (f"Failed to update model: { model } ." , error = str (e ))
90
+ if should_raise :
91
+ raise e
92
+ return None
89
93
90
94
async def record_request (self , prompt_params : Optional [Prompt ] = None ) -> Optional [Prompt ]:
91
95
if prompt_params is None :
@@ -243,11 +247,14 @@ async def record_context(self, context: Optional[PipelineContext]) -> None:
243
247
logger .error (f"Failed to record context: { context } ." , error = str (e ))
244
248
245
249
async def add_workspace (self , workspace_name : str ) -> Optional [Workspace ]:
246
- try :
247
- workspace = Workspace (id = str (uuid .uuid4 ()), name = workspace_name )
248
- except ValidationError as e :
249
- logger .error (f"Failed to create workspace with name: { workspace_name } : { str (e )} " )
250
- return None
250
+ """Add a new workspace to the DB.
251
+
252
+ This handles validation and insertion of a new workspace.
253
+
254
+ It may raise a ValidationError if the workspace name is invalid.
255
+ or a AlreadyExistsError if the workspace already exists.
256
+ """
257
+ workspace = Workspace (id = str (uuid .uuid4 ()), name = workspace_name )
251
258
252
259
sql = text (
253
260
"""
@@ -256,12 +263,13 @@ async def add_workspace(self, workspace_name: str) -> Optional[Workspace]:
256
263
RETURNING *
257
264
"""
258
265
)
259
- try :
260
- added_workspace = await self ._execute_update_pydantic_model (workspace , sql )
261
- except Exception as e :
262
- logger .error (f"Failed to add workspace: { workspace_name } ." , error = str (e ))
263
- return None
264
266
267
+ try :
268
+ added_workspace = await self ._execute_update_pydantic_model (
269
+ workspace , sql , should_raise = True )
270
+ except IntegrityError as e :
271
+ logger .debug (f"Exception type: { type (e )} " )
272
+ raise AlreadyExistsError (f"Workspace { workspace_name } already exists." )
265
273
return added_workspace
266
274
267
275
async def update_session (self , session : Session ) -> Optional [Session ]:
0 commit comments