1
1
import asyncio
2
2
import json
3
+ import uuid
3
4
from pathlib import Path
4
5
from typing import List , Optional , Type
5
6
6
7
import structlog
7
8
from alembic import command as alembic_command
8
9
from alembic .config import Config as AlembicConfig
9
- from pydantic import BaseModel
10
- from sqlalchemy import TextClause , text
10
+ from pydantic import BaseModel , ValidationError
11
+ from sqlalchemy import CursorResult , TextClause , text
11
12
from sqlalchemy .exc import OperationalError
12
13
from sqlalchemy .ext .asyncio import create_async_engine
13
14
14
15
from codegate .db .fim_cache import FimCache
15
16
from codegate .db .models import (
17
+ ActiveWorkspace ,
16
18
Alert ,
17
19
GetAlertsWithPromptAndOutputRow ,
18
20
GetPromptWithOutputsRow ,
19
21
Output ,
20
22
Prompt ,
23
+ Session ,
24
+ Workspace ,
25
+ WorkspaceActive ,
21
26
)
22
27
from codegate .pipeline .base import PipelineContext
23
28
@@ -75,10 +80,14 @@ async def _execute_update_pydantic_model(
75
80
async def record_request (self , prompt_params : Optional [Prompt ] = None ) -> Optional [Prompt ]:
76
81
if prompt_params is None :
77
82
return None
83
+ # Get the active workspace to store the request
84
+ active_workspace = await DbReader ().get_active_workspace ()
85
+ workspace_id = active_workspace .id if active_workspace else "1"
86
+ prompt_params .workspace_id = workspace_id
78
87
sql = text (
79
88
"""
80
- INSERT INTO prompts (id, timestamp, provider, request, type)
81
- VALUES (:id, :timestamp, :provider, :request, :type)
89
+ INSERT INTO prompts (id, timestamp, provider, request, type, workspace_id )
90
+ VALUES (:id, :timestamp, :provider, :request, :type, :workspace_id )
82
91
RETURNING *
83
92
"""
84
93
)
@@ -223,26 +232,78 @@ async def record_context(self, context: Optional[PipelineContext]) -> None:
223
232
except Exception as e :
224
233
logger .error (f"Failed to record context: { context } ." , error = str (e ))
225
234
235
+ async def add_workspace (self , workspace_name : str ) -> Optional [Workspace ]:
236
+ try :
237
+ workspace = Workspace (id = str (uuid .uuid4 ()), name = workspace_name )
238
+ except ValidationError as e :
239
+ logger .error (f"Failed to create workspace with name: { workspace_name } : { str (e )} " )
240
+ return None
241
+
242
+ sql = text (
243
+ """
244
+ INSERT INTO workspaces (id, name)
245
+ VALUES (:id, :name)
246
+ RETURNING *
247
+ """
248
+ )
249
+ added_workspace = await self ._execute_update_pydantic_model (workspace , sql )
250
+ return added_workspace
251
+
252
+ async def update_session (self , session : Session ) -> Optional [Session ]:
253
+ sql = text (
254
+ """
255
+ INSERT INTO sessions (id, active_workspace_id, last_update)
256
+ VALUES (:id, :active_workspace_id, :last_update)
257
+ ON CONFLICT (id) DO UPDATE SET
258
+ active_workspace_id = excluded.active_workspace_id, last_update = excluded.last_update
259
+ WHERE id = excluded.id
260
+ RETURNING *
261
+ """
262
+ )
263
+ # We only pass an object to respect the signature of the function
264
+ active_session = await self ._execute_update_pydantic_model (session , sql )
265
+ return active_session
266
+
226
267
227
268
class DbReader (DbCodeGate ):
228
269
229
270
def __init__ (self , sqlite_path : Optional [str ] = None ):
230
271
super ().__init__ (sqlite_path )
231
272
273
+ async def _dump_result_to_pydantic_model (
274
+ self , model_type : Type [BaseModel ], result : CursorResult
275
+ ) -> Optional [List [BaseModel ]]:
276
+ try :
277
+ if not result :
278
+ return None
279
+ rows = [model_type (** row ._asdict ()) for row in result .fetchall () if row ]
280
+ return rows
281
+ except Exception as e :
282
+ logger .error (f"Failed to dump to pydantic model: { model_type } ." , error = str (e ))
283
+ return None
284
+
232
285
async def _execute_select_pydantic_model (
233
286
self , model_type : Type [BaseModel ], sql_command : TextClause
234
- ) -> Optional [BaseModel ]:
287
+ ) -> Optional [List [ BaseModel ] ]:
235
288
async with self ._async_db_engine .begin () as conn :
236
289
try :
237
290
result = await conn .execute (sql_command )
238
- if not result :
239
- return None
240
- rows = [model_type (** row ._asdict ()) for row in result .fetchall () if row ]
241
- return rows
291
+ return await self ._dump_result_to_pydantic_model (model_type , result )
242
292
except Exception as e :
243
293
logger .error (f"Failed to select model: { model_type } ." , error = str (e ))
244
294
return None
245
295
296
+ async def _exec_select_conditions_to_pydantic (
297
+ self , model_type : Type [BaseModel ], sql_command : TextClause , conditions : dict
298
+ ) -> Optional [List [BaseModel ]]:
299
+ async with self ._async_db_engine .begin () as conn :
300
+ try :
301
+ result = await conn .execute (sql_command , conditions )
302
+ return await self ._dump_result_to_pydantic_model (model_type , result )
303
+ except Exception as e :
304
+ logger .error (f"Failed to select model with conditions: { model_type } ." , error = str (e ))
305
+ return None
306
+
246
307
async def get_prompts_with_output (self ) -> List [GetPromptWithOutputsRow ]:
247
308
sql = text (
248
309
"""
@@ -286,6 +347,54 @@ async def get_alerts_with_prompt_and_output(self) -> List[GetAlertsWithPromptAnd
286
347
prompts = await self ._execute_select_pydantic_model (GetAlertsWithPromptAndOutputRow , sql )
287
348
return prompts
288
349
350
+ async def get_workspaces (self ) -> List [WorkspaceActive ]:
351
+ sql = text (
352
+ """
353
+ SELECT
354
+ w.id, w.name, s.active_workspace_id
355
+ FROM workspaces w
356
+ LEFT JOIN sessions s ON w.id = s.active_workspace_id
357
+ """
358
+ )
359
+ workspaces = await self ._execute_select_pydantic_model (WorkspaceActive , sql )
360
+ return workspaces
361
+
362
+ async def get_workspace_by_name (self , name : str ) -> List [Workspace ]:
363
+ sql = text (
364
+ """
365
+ SELECT
366
+ id, name
367
+ FROM workspaces
368
+ WHERE name = :name
369
+ """
370
+ )
371
+ conditions = {"name" : name }
372
+ workspaces = await self ._exec_select_conditions_to_pydantic (Workspace , sql , conditions )
373
+ return workspaces [0 ] if workspaces else None
374
+
375
+ async def get_sessions (self ) -> List [Session ]:
376
+ sql = text (
377
+ """
378
+ SELECT
379
+ id, active_workspace_id, last_update
380
+ FROM sessions
381
+ """
382
+ )
383
+ sessions = await self ._execute_select_pydantic_model (Session , sql )
384
+ return sessions
385
+
386
+ async def get_active_workspace (self ) -> Optional [ActiveWorkspace ]:
387
+ sql = text (
388
+ """
389
+ SELECT
390
+ w.id, w.name, s.id as session_id, s.last_update
391
+ FROM sessions s
392
+ INNER JOIN workspaces w ON w.id = s.active_workspace_id
393
+ """
394
+ )
395
+ active_workspace = await self ._execute_select_pydantic_model (ActiveWorkspace , sql )
396
+ return active_workspace [0 ] if active_workspace else None
397
+
289
398
290
399
def init_db_sync (db_path : Optional [str ] = None ):
291
400
"""DB will be initialized in the constructor in case it doesn't exist."""
@@ -307,5 +416,23 @@ def init_db_sync(db_path: Optional[str] = None):
307
416
logger .info ("DB initialized successfully." )
308
417
309
418
419
+ def init_session_if_not_exists (db_path : Optional [str ] = None ):
420
+ import datetime
421
+
422
+ db_reader = DbReader (db_path )
423
+ sessions = asyncio .run (db_reader .get_sessions ())
424
+ # If there are no sessions, create a new one
425
+ # TODO: For the moment there's a single session. If it already exists, we don't create a new one
426
+ if not sessions :
427
+ session = Session (
428
+ id = str (uuid .uuid4 ()),
429
+ active_workspace_id = "1" ,
430
+ last_update = datetime .datetime .now (datetime .timezone .utc ),
431
+ )
432
+ db_recorder = DbRecorder (db_path )
433
+ asyncio .run (db_recorder .update_session (session ))
434
+ logger .info ("Session in DB initialized successfully." )
435
+
436
+
310
437
if __name__ == "__main__" :
311
438
init_db_sync ()
0 commit comments