Skip to content

Commit 370231e

Browse files
Introduced Sessions table and added add and activate commands
1 parent 2f6dfd9 commit 370231e

File tree

6 files changed

+223
-24
lines changed

6 files changed

+223
-24
lines changed

migrations/versions/5c2f3eee5f90_introduce_workspaces.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -24,23 +24,38 @@ def upgrade() -> None:
2424
CREATE TABLE workspaces (
2525
id TEXT PRIMARY KEY, -- UUID stored as TEXT
2626
name TEXT NOT NULL,
27-
is_active BOOLEAN NOT NULL DEFAULT 0,
2827
UNIQUE (name)
2928
);
3029
"""
3130
)
32-
op.execute("INSERT INTO workspaces (id, name, is_active) VALUES ('1', 'default', 1);")
31+
op.execute("INSERT INTO workspaces (id, name) VALUES ('1', 'default');")
32+
# Sessions table
33+
op.execute(
34+
"""
35+
CREATE TABLE sessions (
36+
id TEXT PRIMARY KEY, -- UUID stored as TEXT
37+
active_workspace_id TEXT NOT NULL,
38+
last_update DATETIME NOT NULL,
39+
FOREIGN KEY (active_workspace_id) REFERENCES workspaces(id)
40+
);
41+
"""
42+
)
3343
# Alter table prompts
3444
op.execute("ALTER TABLE prompts ADD COLUMN workspace_id TEXT REFERENCES workspaces(id);")
3545
op.execute("UPDATE prompts SET workspace_id = '1';")
3646
# Create index for workspace_id
3747
op.execute("CREATE INDEX idx_prompts_workspace_id ON prompts (workspace_id);")
48+
# Create index for session_id
49+
op.execute("CREATE INDEX idx_sessions_workspace_id ON sessions (active_workspace_id);")
3850

3951

4052
def downgrade() -> None:
4153
# Drop the index for workspace_id
4254
op.execute("DROP INDEX IF EXISTS idx_prompts_workspace_id;")
55+
op.execute("DROP INDEX IF EXISTS idx_sessions_workspace_id;")
4356
# Remove the workspace_id column from prompts table
4457
op.execute("ALTER TABLE prompts DROP COLUMN workspace_id;")
58+
# Drop the sessions table
59+
op.execute("DROP TABLE IF EXISTS sessions;")
4560
# Drop the workspaces table
4661
op.execute("DROP TABLE IF EXISTS workspaces;")

src/codegate/cli.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from codegate.ca.codegate_ca import CertificateAuthority
1515
from codegate.codegate_logging import LogFormat, LogLevel, setup_logging
1616
from codegate.config import Config, ConfigurationError
17-
from codegate.db.connection import init_db_sync
17+
from codegate.db.connection import init_db_sync, init_session_if_not_exists
1818
from codegate.pipeline.factory import PipelineFactory
1919
from codegate.pipeline.secrets.manager import SecretsManager
2020
from codegate.providers.copilot.provider import CopilotProvider
@@ -307,6 +307,7 @@ def serve(
307307
logger = structlog.get_logger("codegate").bind(origin="cli")
308308

309309
init_db_sync(cfg.db_path)
310+
init_session_if_not_exists(cfg.db_path)
310311

311312
# Check certificates and create CA if necessary
312313
logger.info("Checking certificates and creating CA if needed")

src/codegate/db/connection.py

+120-12
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,28 @@
11
import asyncio
22
import json
3+
import uuid
34
from pathlib import Path
45
from typing import List, Optional, Type
56

67
import structlog
78
from alembic import command as alembic_command
89
from alembic.config import Config as AlembicConfig
910
from pydantic import BaseModel
10-
from sqlalchemy import TextClause, text
11+
from sqlalchemy import CursorResult, TextClause, text
1112
from sqlalchemy.exc import OperationalError
1213
from sqlalchemy.ext.asyncio import create_async_engine
1314

1415
from codegate.db.fim_cache import FimCache
1516
from codegate.db.models import (
17+
ActiveWorkspace,
1618
Alert,
1719
GetAlertsWithPromptAndOutputRow,
1820
GetPromptWithOutputsRow,
1921
Output,
2022
Prompt,
23+
Session,
2124
Workspace,
25+
WorkspaceActive,
2226
)
2327
from codegate.pipeline.base import PipelineContext
2428

@@ -76,10 +80,14 @@ async def _execute_update_pydantic_model(
7680
async def record_request(self, prompt_params: Optional[Prompt] = None) -> Optional[Prompt]:
7781
if prompt_params is None:
7882
return None
83+
# Get the active workspace to store the request
84+
active_workspace = await DbReader().get_active_workspace()
85+
workspace_id = active_workspace.id if active_workspace else "1"
86+
prompt_params.workspace_id = workspace_id
7987
sql = text(
8088
"""
81-
INSERT INTO prompts (id, timestamp, provider, request, type)
82-
VALUES (:id, :timestamp, :provider, :request, :type)
89+
INSERT INTO prompts (id, timestamp, provider, request, type, workspace_id)
90+
VALUES (:id, :timestamp, :provider, :request, :type, :workspace_id)
8391
RETURNING *
8492
"""
8593
)
@@ -224,26 +232,73 @@ async def record_context(self, context: Optional[PipelineContext]) -> None:
224232
except Exception as e:
225233
logger.error(f"Failed to record context: {context}.", error=str(e))
226234

235+
async def add_workspace(self, workspace_name: str) -> Optional[Workspace]:
236+
workspace = Workspace(id=str(uuid.uuid4()), name=workspace_name)
237+
sql = text(
238+
"""
239+
INSERT INTO workspaces (id, name)
240+
VALUES (:id, :name)
241+
RETURNING *
242+
"""
243+
)
244+
added_workspace = await self._execute_update_pydantic_model(workspace, sql)
245+
return added_workspace
246+
247+
async def update_session(self, session: Session) -> Optional[Session]:
248+
sql = text(
249+
"""
250+
INSERT INTO sessions (id, active_workspace_id, last_update)
251+
VALUES (:id, :active_workspace_id, :last_update)
252+
ON CONFLICT (id) DO UPDATE SET
253+
active_workspace_id = excluded.active_workspace_id, last_update = excluded.last_update
254+
WHERE id = excluded.id
255+
RETURNING *
256+
"""
257+
)
258+
# We only pass an object to respect the signature of the function
259+
active_session = await self._execute_update_pydantic_model(session, sql)
260+
return active_session
261+
227262

228263
class DbReader(DbCodeGate):
229264

230265
def __init__(self, sqlite_path: Optional[str] = None):
231266
super().__init__(sqlite_path)
232267

268+
async def _dump_result_to_pydantic_model(
269+
self, model_type: Type[BaseModel], result: CursorResult
270+
) -> Optional[List[BaseModel]]:
271+
try:
272+
if not result:
273+
return None
274+
rows = [model_type(**row._asdict()) for row in result.fetchall() if row]
275+
return rows
276+
except Exception as e:
277+
logger.error(f"Failed to dump to pydantic model: {model_type}.", error=str(e))
278+
return None
279+
233280
async def _execute_select_pydantic_model(
234281
self, model_type: Type[BaseModel], sql_command: TextClause
235-
) -> Optional[BaseModel]:
282+
) -> Optional[List[BaseModel]]:
236283
async with self._async_db_engine.begin() as conn:
237284
try:
238285
result = await conn.execute(sql_command)
239-
if not result:
240-
return None
241-
rows = [model_type(**row._asdict()) for row in result.fetchall() if row]
242-
return rows
286+
return await self._dump_result_to_pydantic_model(model_type, result)
243287
except Exception as e:
244288
logger.error(f"Failed to select model: {model_type}.", error=str(e))
245289
return None
246290

291+
async def _exec_select_conditions_to_pydantic(
292+
self, model_type: Type[BaseModel], sql_command: TextClause, conditions: dict
293+
) -> Optional[List[BaseModel]]:
294+
async with self._async_db_engine.begin() as conn:
295+
try:
296+
result = await conn.execute(sql_command, conditions)
297+
return await self._dump_result_to_pydantic_model(model_type, result)
298+
except Exception as e:
299+
logger.error(f"Failed to select model with conditions: {model_type}.", error=str(e))
300+
return None
301+
247302
async def get_prompts_with_output(self) -> List[GetPromptWithOutputsRow]:
248303
sql = text(
249304
"""
@@ -287,18 +342,54 @@ async def get_alerts_with_prompt_and_output(self) -> List[GetAlertsWithPromptAnd
287342
prompts = await self._execute_select_pydantic_model(GetAlertsWithPromptAndOutputRow, sql)
288343
return prompts
289344

290-
async def get_workspaces(self) -> List[Workspace]:
345+
async def get_workspaces(self) -> List[WorkspaceActive]:
291346
sql = text(
292347
"""
293348
SELECT
294-
id, name, is_active
349+
w.id, w.name, s.active_workspace_id
350+
FROM workspaces w
351+
LEFT JOIN sessions s ON w.id = s.active_workspace_id
352+
"""
353+
)
354+
workspaces = await self._execute_select_pydantic_model(WorkspaceActive, sql)
355+
return workspaces
356+
357+
async def get_workspace_by_name(self, name: str) -> List[Workspace]:
358+
sql = text(
359+
"""
360+
SELECT
361+
id, name
295362
FROM workspaces
296-
ORDER BY is_active DESC
363+
WHERE name = :name
297364
"""
298365
)
299-
workspaces = await self._execute_select_pydantic_model(Workspace, sql)
366+
conditions = {"name": name}
367+
workspaces = await self._exec_select_conditions_to_pydantic(Workspace, sql, conditions)
300368
return workspaces
301369

370+
async def get_sessions(self) -> List[Session]:
371+
sql = text(
372+
"""
373+
SELECT
374+
id, active_workspace_id, last_update
375+
FROM sessions
376+
"""
377+
)
378+
sessions = await self._execute_select_pydantic_model(Session, sql)
379+
return sessions
380+
381+
async def get_active_workspace(self) -> Optional[ActiveWorkspace]:
382+
sql = text(
383+
"""
384+
SELECT
385+
w.id, w.name, s.id as session_id, s.last_update
386+
FROM sessions s
387+
INNER JOIN workspaces w ON w.id = s.active_workspace_id
388+
"""
389+
)
390+
active_workspace = await self._execute_select_pydantic_model(ActiveWorkspace, sql)
391+
return active_workspace[0] if active_workspace else None
392+
302393

303394
def init_db_sync(db_path: Optional[str] = None):
304395
"""DB will be initialized in the constructor in case it doesn't exist."""
@@ -320,5 +411,22 @@ def init_db_sync(db_path: Optional[str] = None):
320411
logger.info("DB initialized successfully.")
321412

322413

414+
def init_session_if_not_exists(db_path: Optional[str] = None):
415+
import datetime
416+
db_reader = DbReader(db_path)
417+
sessions = asyncio.run(db_reader.get_sessions())
418+
# If there are no sessions, create a new one
419+
# TODO: For the moment there's a single session. If it already exists, we don't create a new one
420+
if not sessions:
421+
session = Session(
422+
id=str(uuid.uuid4()),
423+
active_workspace_id="1",
424+
last_update=datetime.datetime.now(datetime.timezone.utc)
425+
)
426+
db_recorder = DbRecorder(db_path)
427+
asyncio.run(db_recorder.update_session(session))
428+
logger.info("Session in DB initialized successfully.")
429+
430+
323431
if __name__ == "__main__":
324432
init_db_sync()

src/codegate/db/models.py

+22-3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import datetime
12
from typing import Any, Optional
23

34
import pydantic
@@ -26,7 +27,7 @@ class Prompt(pydantic.BaseModel):
2627
provider: Optional[Any]
2728
request: Any
2829
type: Any
29-
workspace_id: Optional[Any]
30+
workspace_id: Optional[str]
3031

3132

3233
class Setting(pydantic.BaseModel):
@@ -39,9 +40,14 @@ class Setting(pydantic.BaseModel):
3940

4041

4142
class Workspace(pydantic.BaseModel):
42-
id: Any
43+
id: str
4344
name: str
44-
is_active: bool = False
45+
46+
47+
class Session(pydantic.BaseModel):
48+
id: str
49+
active_workspace_id: str
50+
last_update: datetime.datetime
4551

4652

4753
# Models for select queries
@@ -73,3 +79,16 @@ class GetPromptWithOutputsRow(pydantic.BaseModel):
7379
output_id: Optional[Any]
7480
output: Optional[Any]
7581
output_timestamp: Optional[Any]
82+
83+
84+
class WorkspaceActive(pydantic.BaseModel):
85+
id: str
86+
name: str
87+
active_workspace_id: Optional[str]
88+
89+
90+
class ActiveWorkspace(pydantic.BaseModel):
91+
id: str
92+
name: str
93+
session_id: str
94+
last_update: datetime.datetime

src/codegate/pipeline/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def add_input_request(
135135
provider=provider,
136136
type="fim" if is_fim_request else "chat",
137137
request=request_str,
138-
workspace_id="1", # TODO: This is a placeholder for now, using default workspace
138+
workspace_id=None,
139139
)
140140
# Uncomment the below to debug the input
141141
# logger.debug(f"Added input request to context: {self.input_request}")

0 commit comments

Comments
 (0)