1
1
import asyncio
2
2
import json
3
+ import uuid
3
4
from pathlib import Path
4
5
from typing import List , Optional , Type
5
6
6
7
import structlog
7
8
from alembic import command as alembic_command
8
9
from alembic .config import Config as AlembicConfig
9
10
from pydantic import BaseModel
10
- from sqlalchemy import TextClause , text
11
+ from sqlalchemy import CursorResult , TextClause , text
11
12
from sqlalchemy .exc import OperationalError
12
13
from sqlalchemy .ext .asyncio import create_async_engine
13
14
14
15
from codegate .db .fim_cache import FimCache
15
16
from codegate .db .models import (
17
+ ActiveWorkspace ,
16
18
Alert ,
17
19
GetAlertsWithPromptAndOutputRow ,
18
20
GetPromptWithOutputsRow ,
19
21
Output ,
20
22
Prompt ,
23
+ Session ,
21
24
Workspace ,
25
+ WorkspaceActive ,
22
26
)
23
27
from codegate .pipeline .base import PipelineContext
24
28
@@ -76,10 +80,14 @@ async def _execute_update_pydantic_model(
76
80
async def record_request (self , prompt_params : Optional [Prompt ] = None ) -> Optional [Prompt ]:
77
81
if prompt_params is None :
78
82
return None
83
+ # Get the active workspace to store the request
84
+ active_workspace = await DbReader ().get_active_workspace ()
85
+ workspace_id = active_workspace .id if active_workspace else "1"
86
+ prompt_params .workspace_id = workspace_id
79
87
sql = text (
80
88
"""
81
- INSERT INTO prompts (id, timestamp, provider, request, type)
82
- VALUES (:id, :timestamp, :provider, :request, :type)
89
+ INSERT INTO prompts (id, timestamp, provider, request, type, workspace_id )
90
+ VALUES (:id, :timestamp, :provider, :request, :type, :workspace_id )
83
91
RETURNING *
84
92
"""
85
93
)
@@ -224,26 +232,73 @@ async def record_context(self, context: Optional[PipelineContext]) -> None:
224
232
except Exception as e :
225
233
logger .error (f"Failed to record context: { context } ." , error = str (e ))
226
234
235
+ async def add_workspace (self , workspace_name : str ) -> Optional [Workspace ]:
236
+ workspace = Workspace (id = str (uuid .uuid4 ()), name = workspace_name )
237
+ sql = text (
238
+ """
239
+ INSERT INTO workspaces (id, name)
240
+ VALUES (:id, :name)
241
+ RETURNING *
242
+ """
243
+ )
244
+ added_workspace = await self ._execute_update_pydantic_model (workspace , sql )
245
+ return added_workspace
246
+
247
+ async def update_session (self , session : Session ) -> Optional [Session ]:
248
+ sql = text (
249
+ """
250
+ INSERT INTO sessions (id, active_workspace_id, last_update)
251
+ VALUES (:id, :active_workspace_id, :last_update)
252
+ ON CONFLICT (id) DO UPDATE SET
253
+ active_workspace_id = excluded.active_workspace_id, last_update = excluded.last_update
254
+ WHERE id = excluded.id
255
+ RETURNING *
256
+ """
257
+ )
258
+ # We only pass an object to respect the signature of the function
259
+ active_session = await self ._execute_update_pydantic_model (session , sql )
260
+ return active_session
261
+
227
262
228
263
class DbReader (DbCodeGate ):
229
264
230
265
def __init__ (self , sqlite_path : Optional [str ] = None ):
231
266
super ().__init__ (sqlite_path )
232
267
268
+ async def _dump_result_to_pydantic_model (
269
+ self , model_type : Type [BaseModel ], result : CursorResult
270
+ ) -> Optional [List [BaseModel ]]:
271
+ try :
272
+ if not result :
273
+ return None
274
+ rows = [model_type (** row ._asdict ()) for row in result .fetchall () if row ]
275
+ return rows
276
+ except Exception as e :
277
+ logger .error (f"Failed to dump to pydantic model: { model_type } ." , error = str (e ))
278
+ return None
279
+
233
280
async def _execute_select_pydantic_model (
234
281
self , model_type : Type [BaseModel ], sql_command : TextClause
235
- ) -> Optional [BaseModel ]:
282
+ ) -> Optional [List [ BaseModel ] ]:
236
283
async with self ._async_db_engine .begin () as conn :
237
284
try :
238
285
result = await conn .execute (sql_command )
239
- if not result :
240
- return None
241
- rows = [model_type (** row ._asdict ()) for row in result .fetchall () if row ]
242
- return rows
286
+ return await self ._dump_result_to_pydantic_model (model_type , result )
243
287
except Exception as e :
244
288
logger .error (f"Failed to select model: { model_type } ." , error = str (e ))
245
289
return None
246
290
291
+ async def _exec_select_conditions_to_pydantic (
292
+ self , model_type : Type [BaseModel ], sql_command : TextClause , conditions : dict
293
+ ) -> Optional [List [BaseModel ]]:
294
+ async with self ._async_db_engine .begin () as conn :
295
+ try :
296
+ result = await conn .execute (sql_command , conditions )
297
+ return await self ._dump_result_to_pydantic_model (model_type , result )
298
+ except Exception as e :
299
+ logger .error (f"Failed to select model with conditions: { model_type } ." , error = str (e ))
300
+ return None
301
+
247
302
async def get_prompts_with_output (self ) -> List [GetPromptWithOutputsRow ]:
248
303
sql = text (
249
304
"""
@@ -287,18 +342,54 @@ async def get_alerts_with_prompt_and_output(self) -> List[GetAlertsWithPromptAnd
287
342
prompts = await self ._execute_select_pydantic_model (GetAlertsWithPromptAndOutputRow , sql )
288
343
return prompts
289
344
290
- async def get_workspaces (self ) -> List [Workspace ]:
345
+ async def get_workspaces (self ) -> List [WorkspaceActive ]:
291
346
sql = text (
292
347
"""
293
348
SELECT
294
- id, name, is_active
349
+ w.id, w.name, s.active_workspace_id
350
+ FROM workspaces w
351
+ LEFT JOIN sessions s ON w.id = s.active_workspace_id
352
+ """
353
+ )
354
+ workspaces = await self ._execute_select_pydantic_model (WorkspaceActive , sql )
355
+ return workspaces
356
+
357
+ async def get_workspace_by_name (self , name : str ) -> List [Workspace ]:
358
+ sql = text (
359
+ """
360
+ SELECT
361
+ id, name
295
362
FROM workspaces
296
- ORDER BY is_active DESC
363
+ WHERE name = :name
297
364
"""
298
365
)
299
- workspaces = await self ._execute_select_pydantic_model (Workspace , sql )
366
+ conditions = {"name" : name }
367
+ workspaces = await self ._exec_select_conditions_to_pydantic (Workspace , sql , conditions )
300
368
return workspaces
301
369
370
+ async def get_sessions (self ) -> List [Session ]:
371
+ sql = text (
372
+ """
373
+ SELECT
374
+ id, active_workspace_id, last_update
375
+ FROM sessions
376
+ """
377
+ )
378
+ sessions = await self ._execute_select_pydantic_model (Session , sql )
379
+ return sessions
380
+
381
+ async def get_active_workspace (self ) -> Optional [ActiveWorkspace ]:
382
+ sql = text (
383
+ """
384
+ SELECT
385
+ w.id, w.name, s.id as session_id, s.last_update
386
+ FROM sessions s
387
+ INNER JOIN workspaces w ON w.id = s.active_workspace_id
388
+ """
389
+ )
390
+ active_workspace = await self ._execute_select_pydantic_model (ActiveWorkspace , sql )
391
+ return active_workspace [0 ] if active_workspace else None
392
+
302
393
303
394
def init_db_sync (db_path : Optional [str ] = None ):
304
395
"""DB will be initialized in the constructor in case it doesn't exist."""
@@ -320,5 +411,22 @@ def init_db_sync(db_path: Optional[str] = None):
320
411
logger .info ("DB initialized successfully." )
321
412
322
413
414
+ def init_session_if_not_exists (db_path : Optional [str ] = None ):
415
+ import datetime
416
+ db_reader = DbReader (db_path )
417
+ sessions = asyncio .run (db_reader .get_sessions ())
418
+ # If there are no sessions, create a new one
419
+ # TODO: For the moment there's a single session. If it already exists, we don't create a new one
420
+ if not sessions :
421
+ session = Session (
422
+ id = str (uuid .uuid4 ()),
423
+ active_workspace_id = "1" ,
424
+ last_update = datetime .datetime .now (datetime .timezone .utc )
425
+ )
426
+ db_recorder = DbRecorder (db_path )
427
+ asyncio .run (db_recorder .update_session (session ))
428
+ logger .info ("Session in DB initialized successfully." )
429
+
430
+
323
431
if __name__ == "__main__" :
324
432
init_db_sync ()
0 commit comments