Skip to content

Commit 147205a

Browse files
feat: Initial migration for Workspaces and pipeline step (#600)
* feat: Initial migration for Workspaces and pipeline step Related: #454 We noticed that most incoming requests containing a code snippet only list a path relative to where the code editor was opened. This makes it difficult to accurately distinguish between repositories in Codegate. For example, a user could open 2 different Python repositories in different sessions, and both repositories could contain a `pyproject.toml`. It would be impossible for Codegate to determine the real repository of the file using only the relative path. Hence, the initial implementation of Workspaces will rely on a pipeline step that is able to take commands and process them. Some commands could be: - List workspaces - Add workspace - Switch active workspace - Delete workspace It would be up to the user to select the desired active workspace. This PR introduces an initial migration for Workspaces and the pipeline step with the `list` command. * Reformatting changes * Make workspace names unique * Introduced Sessions table and added add and activate commands * Formatting changes and unit tests * Classes separation into a different file
1 parent b4d719f commit 147205a

File tree

14 files changed

+589
-26
lines changed

14 files changed

+589
-26
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""introduce workspaces
2+
3+
Revision ID: 5c2f3eee5f90
4+
Revises: 30d0144e1a50
5+
Create Date: 2025-01-15 19:27:08.230296
6+
7+
"""
8+
9+
from typing import Sequence, Union
10+
11+
from alembic import op
12+
13+
# revision identifiers, used by Alembic.
14+
revision: str = "5c2f3eee5f90"
15+
down_revision: Union[str, None] = "30d0144e1a50"
16+
branch_labels: Union[str, Sequence[str], None] = None
17+
depends_on: Union[str, Sequence[str], None] = None
18+
19+
20+
def upgrade() -> None:
    """Apply the workspaces migration.

    Runs, in order:

    1. Create the ``workspaces`` table (unique ``name``) and seed it with the
       ``'default'`` workspace under id ``'1'``.
    2. Create the ``sessions`` table, whose ``active_workspace_id`` references
       ``workspaces(id)``.
    3. Add a ``workspace_id`` column to ``prompts`` and backfill every
       existing row with the default workspace id ``'1'``.
    4. Create lookup indexes on ``prompts.workspace_id`` and
       ``sessions.active_workspace_id``.
    """
    migration_statements = (
        # Workspaces table.
        """
        CREATE TABLE workspaces (
            id TEXT PRIMARY KEY, -- UUID stored as TEXT
            name TEXT NOT NULL,
            UNIQUE (name)
        );
        """,
        # Seed row: the default workspace that existing data is attached to.
        "INSERT INTO workspaces (id, name) VALUES ('1', 'default');",
        # Sessions table.
        """
        CREATE TABLE sessions (
            id TEXT PRIMARY KEY, -- UUID stored as TEXT
            active_workspace_id TEXT NOT NULL,
            last_update DATETIME NOT NULL,
            FOREIGN KEY (active_workspace_id) REFERENCES workspaces(id)
        );
        """,
        # Link prompts to a workspace, then backfill pre-existing prompts.
        "ALTER TABLE prompts ADD COLUMN workspace_id TEXT REFERENCES workspaces(id);",
        "UPDATE prompts SET workspace_id = '1';",
        # Index on prompts.workspace_id.
        "CREATE INDEX idx_prompts_workspace_id ON prompts (workspace_id);",
        # Index on sessions.active_workspace_id.
        "CREATE INDEX idx_sessions_workspace_id ON sessions (active_workspace_id);",
    )
    for statement in migration_statements:
        op.execute(statement)
50+
51+
52+
def downgrade() -> None:
    """Revert the workspaces migration.

    Drops the two indexes first, then the ``workspace_id`` column on
    ``prompts``, then the ``sessions`` table, and finally the ``workspaces``
    table. ``sessions`` goes before ``workspaces`` because it holds the
    foreign key pointing at ``workspaces(id)``.
    """
    rollback_statements = (
        # Indexes on prompts.workspace_id and sessions.active_workspace_id.
        "DROP INDEX IF EXISTS idx_prompts_workspace_id;",
        "DROP INDEX IF EXISTS idx_sessions_workspace_id;",
        # Column linking prompts to a workspace.
        "ALTER TABLE prompts DROP COLUMN workspace_id;",
        # Tables, referencing table first.
        "DROP TABLE IF EXISTS sessions;",
        "DROP TABLE IF EXISTS workspaces;",
    )
    for statement in rollback_statements:
        op.execute(statement)

src/codegate/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
_VERSION = "dev"
1111
_DESC = "CodeGate - A Generative AI security gateway."
1212

13+
1314
def __get_version_and_description() -> tuple[str, str]:
1415
try:
1516
version = metadata.version("codegate")
@@ -19,6 +20,7 @@ def __get_version_and_description() -> tuple[str, str]:
1920
description = _DESC
2021
return version, description
2122

23+
2224
__version__, __description__ = __get_version_and_description()
2325

2426
__all__ = ["Config", "ConfigurationError", "LogFormat", "LogLevel", "setup_logging"]

src/codegate/cli.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from codegate.ca.codegate_ca import CertificateAuthority
1515
from codegate.codegate_logging import LogFormat, LogLevel, setup_logging
1616
from codegate.config import Config, ConfigurationError
17-
from codegate.db.connection import init_db_sync
17+
from codegate.db.connection import init_db_sync, init_session_if_not_exists
1818
from codegate.pipeline.factory import PipelineFactory
1919
from codegate.pipeline.secrets.manager import SecretsManager
2020
from codegate.providers.copilot.provider import CopilotProvider
@@ -307,6 +307,7 @@ def serve(
307307
logger = structlog.get_logger("codegate").bind(origin="cli")
308308

309309
init_db_sync(cfg.db_path)
310+
init_session_if_not_exists(cfg.db_path)
310311

311312
# Check certificates and create CA if necessary
312313
logger.info("Checking certificates and creating CA if needed")

src/codegate/db/connection.py

+136-9
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,28 @@
11
import asyncio
22
import json
3+
import uuid
34
from pathlib import Path
45
from typing import List, Optional, Type
56

67
import structlog
78
from alembic import command as alembic_command
89
from alembic.config import Config as AlembicConfig
9-
from pydantic import BaseModel
10-
from sqlalchemy import TextClause, text
10+
from pydantic import BaseModel, ValidationError
11+
from sqlalchemy import CursorResult, TextClause, text
1112
from sqlalchemy.exc import OperationalError
1213
from sqlalchemy.ext.asyncio import create_async_engine
1314

1415
from codegate.db.fim_cache import FimCache
1516
from codegate.db.models import (
17+
ActiveWorkspace,
1618
Alert,
1719
GetAlertsWithPromptAndOutputRow,
1820
GetPromptWithOutputsRow,
1921
Output,
2022
Prompt,
23+
Session,
24+
Workspace,
25+
WorkspaceActive,
2126
)
2227
from codegate.pipeline.base import PipelineContext
2328

@@ -75,10 +80,14 @@ async def _execute_update_pydantic_model(
7580
async def record_request(self, prompt_params: Optional[Prompt] = None) -> Optional[Prompt]:
7681
if prompt_params is None:
7782
return None
83+
# Get the active workspace to store the request
84+
active_workspace = await DbReader().get_active_workspace()
85+
workspace_id = active_workspace.id if active_workspace else "1"
86+
prompt_params.workspace_id = workspace_id
7887
sql = text(
7988
"""
80-
INSERT INTO prompts (id, timestamp, provider, request, type)
81-
VALUES (:id, :timestamp, :provider, :request, :type)
89+
INSERT INTO prompts (id, timestamp, provider, request, type, workspace_id)
90+
VALUES (:id, :timestamp, :provider, :request, :type, :workspace_id)
8291
RETURNING *
8392
"""
8493
)
@@ -223,26 +232,78 @@ async def record_context(self, context: Optional[PipelineContext]) -> None:
223232
except Exception as e:
224233
logger.error(f"Failed to record context: {context}.", error=str(e))
225234

235+
async def add_workspace(self, workspace_name: str) -> Optional[Workspace]:
    """Create a workspace row with a freshly generated UUID as its id.

    Args:
        workspace_name: Name for the new workspace.

    Returns:
        ``None`` when building the ``Workspace`` model raises a
        ``ValidationError`` (the failure is logged); otherwise whatever
        ``_execute_update_pydantic_model`` returns for the insert.
    """
    new_id = str(uuid.uuid4())
    try:
        candidate = Workspace(id=new_id, name=workspace_name)
    except ValidationError as e:
        logger.error(f"Failed to create workspace with name: {workspace_name}: {str(e)}")
        return None

    insert_stmt = text(
        """
        INSERT INTO workspaces (id, name)
        VALUES (:id, :name)
        RETURNING *
        """
    )
    return await self._execute_update_pydantic_model(candidate, insert_stmt)
251+
252+
async def update_session(self, session: Session) -> Optional[Session]:
    """Insert the session, or update it if a row with the same id exists.

    On an id conflict, the stored ``active_workspace_id`` and ``last_update``
    are replaced with the values from ``session``.

    Args:
        session: Session model handed to ``_execute_update_pydantic_model``
            together with the upsert statement.

    Returns:
        Whatever ``_execute_update_pydantic_model`` returns for the upsert.
    """
    upsert_stmt = text(
        """
        INSERT INTO sessions (id, active_workspace_id, last_update)
        VALUES (:id, :active_workspace_id, :last_update)
        ON CONFLICT (id) DO UPDATE SET
        active_workspace_id = excluded.active_workspace_id, last_update = excluded.last_update
        WHERE id = excluded.id
        RETURNING *
        """
    )
    return await self._execute_update_pydantic_model(session, upsert_stmt)
266+
226267

227268
class DbReader(DbCodeGate):
228269

229270
def __init__(self, sqlite_path: Optional[str] = None):
230271
super().__init__(sqlite_path)
231272

273+
async def _dump_result_to_pydantic_model(
    self, model_type: Type[BaseModel], result: CursorResult
) -> Optional[List[BaseModel]]:
    """Convert every row of a query result into an instance of ``model_type``.

    Args:
        model_type: Model class instantiated once per row, with the row's
            columns passed as keyword arguments.
        result: Executed query result to read rows from.

    Returns:
        A list of model instances (falsy rows are skipped), or ``None`` when
        ``result`` is falsy or any step raises; the error is logged.
    """
    try:
        if not result:
            return None
        models = []
        for row in result.fetchall():
            if row:
                models.append(model_type(**row._asdict()))
        return models
    except Exception as e:
        logger.error(f"Failed to dump to pydantic model: {model_type}.", error=str(e))
        return None
284+
232285
async def _execute_select_pydantic_model(
    self, model_type: Type[BaseModel], sql_command: TextClause
) -> Optional[List[BaseModel]]:
    """Run a parameterless select and map its rows onto ``model_type``.

    Args:
        model_type: Model class used for each returned row.
        sql_command: Statement to execute; no bind parameters are supplied.

    Returns:
        The result of ``_dump_result_to_pydantic_model``, or ``None`` when
        executing or converting raises (the error is logged).
    """
    async with self._async_db_engine.begin() as connection:
        try:
            cursor = await connection.execute(sql_command)
            models = await self._dump_result_to_pydantic_model(model_type, cursor)
        except Exception as e:
            logger.error(f"Failed to select model: {model_type}.", error=str(e))
            return None
        return models
245295

296+
async def _exec_select_conditions_to_pydantic(
    self, model_type: Type[BaseModel], sql_command: TextClause, conditions: dict
) -> Optional[List[BaseModel]]:
    """Run a select with bind parameters and map its rows onto ``model_type``.

    Args:
        model_type: Model class used for each returned row.
        sql_command: Statement to execute.
        conditions: Values passed to ``execute`` alongside the statement,
            supplying its named parameters.

    Returns:
        The result of ``_dump_result_to_pydantic_model``, or ``None`` when
        executing or converting raises (the error is logged).
    """
    async with self._async_db_engine.begin() as connection:
        try:
            cursor = await connection.execute(sql_command, conditions)
            models = await self._dump_result_to_pydantic_model(model_type, cursor)
        except Exception as e:
            logger.error(f"Failed to select model with conditions: {model_type}.", error=str(e))
            return None
        return models
306+
246307
async def get_prompts_with_output(self) -> List[GetPromptWithOutputsRow]:
247308
sql = text(
248309
"""
@@ -286,6 +347,54 @@ async def get_alerts_with_prompt_and_output(self) -> List[GetAlertsWithPromptAnd
286347
prompts = await self._execute_select_pydantic_model(GetAlertsWithPromptAndOutputRow, sql)
287348
return prompts
288349

350+
async def get_workspaces(self) -> List[WorkspaceActive]:
    """List every workspace along with the session activity joined to it.

    The query LEFT JOINs ``sessions`` on ``active_workspace_id``, so a
    workspace that no session has active still appears, with a NULL
    ``active_workspace_id``.

    Returns:
        Rows mapped to ``WorkspaceActive`` by
        ``_execute_select_pydantic_model``; that helper returns ``None``
        when the query fails.
    """
    query = text(
        """
        SELECT
            w.id, w.name, s.active_workspace_id
        FROM workspaces w
        LEFT JOIN sessions s ON w.id = s.active_workspace_id
        """
    )
    return await self._execute_select_pydantic_model(WorkspaceActive, query)
361+
362+
async def get_workspace_by_name(self, name: str) -> Optional[Workspace]:
    """Look up a single workspace by its name.

    The return annotation is ``Optional[Workspace]`` because the method
    hands back one element (or ``None``), not the list the select helper
    produces.

    Args:
        name: Workspace name bound to the ``:name`` query parameter.

    Returns:
        The first matching ``Workspace``, or ``None`` when the query yields
        no rows or the select helper itself returns ``None``.
    """
    sql = text(
        """
        SELECT
            id, name
        FROM workspaces
        WHERE name = :name
        """
    )
    conditions = {"name": name}
    workspaces = await self._exec_select_conditions_to_pydantic(Workspace, sql, conditions)
    # The helper returns a list (or None); unwrap to the single match.
    return workspaces[0] if workspaces else None
374+
375+
async def get_sessions(self) -> List[Session]:
    """Fetch all rows of the ``sessions`` table.

    Returns:
        Rows mapped to ``Session`` by ``_execute_select_pydantic_model``;
        that helper returns ``None`` when the query fails.
    """
    query = text(
        """
        SELECT
            id, active_workspace_id, last_update
        FROM sessions
        """
    )
    return await self._execute_select_pydantic_model(Session, query)
385+
386+
async def get_active_workspace(self) -> Optional[ActiveWorkspace]:
    """Return the workspace that a session currently has active.

    Joins ``sessions`` to ``workspaces`` on ``active_workspace_id`` and takes
    the first row. The query has no ORDER BY, so if several sessions exist
    the row chosen is whichever the database returns first.

    Returns:
        The first joined row as an ``ActiveWorkspace``, or ``None`` when
        there are no rows or the select helper returns ``None``.
    """
    query = text(
        """
        SELECT
            w.id, w.name, s.id as session_id, s.last_update
        FROM sessions s
        INNER JOIN workspaces w ON w.id = s.active_workspace_id
        """
    )
    rows = await self._execute_select_pydantic_model(ActiveWorkspace, query)
    if not rows:
        return None
    return rows[0]
397+
289398

290399
def init_db_sync(db_path: Optional[str] = None):
291400
"""DB will be initialized in the constructor in case it doesn't exist."""
@@ -307,5 +416,23 @@ def init_db_sync(db_path: Optional[str] = None):
307416
logger.info("DB initialized successfully.")
308417

309418

419+
def init_session_if_not_exists(db_path: Optional[str] = None):
    """Create the initial session row when the sessions table is empty.

    If ``DbReader.get_sessions`` returns anything truthy, nothing is done.
    Otherwise a session with a random UUID id, pointing at workspace ``"1"``
    and stamped with the current UTC time, is stored through
    ``DbRecorder.update_session``.

    Args:
        db_path: Passed unchanged to both ``DbReader`` and ``DbRecorder``.
    """
    import datetime

    existing_sessions = asyncio.run(DbReader(db_path).get_sessions())
    # TODO: For the moment there's a single session. If it already exists, we don't create a new one
    if existing_sessions:
        return

    created_at = datetime.datetime.now(datetime.timezone.utc)
    initial_session = Session(
        id=str(uuid.uuid4()),
        active_workspace_id="1",
        last_update=created_at,
    )
    asyncio.run(DbRecorder(db_path).update_session(initial_session))
    logger.info("Session in DB initialized successfully.")
435+
436+
310437
if __name__ == "__main__":
311438
init_db_sync()

0 commit comments

Comments
 (0)