diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index 243e78ee..ff02dfb1 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -27,7 +27,7 @@ async def list_workspaces() -> v1_models.ListWorkspacesResponse: """List all workspaces.""" wslist = await wscrud.get_workspaces() - resp = v1_models.ListWorkspacesResponse.from_db_workspaces_active(wslist) + resp = v1_models.ListWorkspacesResponse.from_db_workspaces_with_sessioninfo(wslist) return resp diff --git a/src/codegate/api/v1_models.py b/src/codegate/api/v1_models.py index efdc35ae..ee86c208 100644 --- a/src/codegate/api/v1_models.py +++ b/src/codegate/api/v1_models.py @@ -23,19 +23,18 @@ class ListWorkspacesResponse(pydantic.BaseModel): workspaces: list[Workspace] @classmethod - def from_db_workspaces_active( - cls, db_workspaces: List[db_models.WorkspaceActive] + def from_db_workspaces_with_sessioninfo( + cls, db_workspaces: List[db_models.WorkspaceWithSessionInfo] ) -> "ListWorkspacesResponse": return cls( workspaces=[ - Workspace(name=ws.name, is_active=ws.active_workspace_id is not None) - for ws in db_workspaces + Workspace(name=ws.name, is_active=ws.session_id is not None) for ws in db_workspaces ] ) @classmethod def from_db_workspaces( - cls, db_workspaces: List[db_models.Workspace] + cls, db_workspaces: List[db_models.WorkspaceRow] ) -> "ListWorkspacesResponse": return cls(workspaces=[Workspace(name=ws.name, is_active=False) for ws in db_workspaces]) diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 3ca9bd0a..51b47a09 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -19,11 +19,12 @@ Alert, GetAlertsWithPromptAndOutputRow, GetPromptWithOutputsRow, + GetWorkspaceByNameConditions, Output, Prompt, Session, - Workspace, - WorkspaceActive, + WorkspaceRow, + WorkspaceWithSessionInfo, ) from codegate.pipeline.base import PipelineContext @@ -263,7 +264,7 @@ async def record_context(self, context: Optional[PipelineContext]) -> None: except Exception as e: logger.error(f"Failed to record context: {context}.", error=str(e)) - async def add_workspace(self, workspace_name: str) -> Workspace: + async def add_workspace(self, workspace_name: str) -> WorkspaceRow: """Add a new workspace to the DB. This handles validation and insertion of a new workspace. @@ -271,7 +272,9 @@ async def add_workspace(self, workspace_name: str) -> Workspace: It may raise a ValidationError if the workspace name is invalid. or a AlreadyExistsError if the workspace already exists. """ - workspace = Workspace(id=str(uuid.uuid4()), name=workspace_name, custom_instructions=None) + workspace = WorkspaceRow( + id=str(uuid.uuid4()), name=workspace_name, custom_instructions=None + ) sql = text( """ INSERT INTO workspaces (id, name) @@ -289,7 +292,7 @@ async def add_workspace(self, workspace_name: str) -> Workspace: raise AlreadyExistsError(f"Workspace {workspace_name} already exists.") return added_workspace - async def update_workspace(self, workspace: Workspace) -> Workspace: + async def update_workspace(self, workspace: WorkspaceRow) -> WorkspaceRow: sql = text( """ UPDATE workspaces SET @@ -319,7 +322,7 @@ async def update_session(self, session: Session) -> Optional[Session]: active_session = await self._execute_update_pydantic_model(session, sql, should_raise=True) return active_session - async def soft_delete_workspace(self, workspace: Workspace) -> Optional[Workspace]: + async def soft_delete_workspace(self, workspace: WorkspaceRow) -> Optional[WorkspaceRow]: sql = text( """ UPDATE workspaces @@ -333,7 +336,7 @@ async def soft_delete_workspace(self, workspace: Workspace) -> Optional[Workspac ) return deleted_workspace - async def hard_delete_workspace(self, workspace: Workspace) -> Optional[Workspace]: + async def hard_delete_workspace(self, workspace: WorkspaceRow) -> Optional[WorkspaceRow]: sql = text( """ DELETE FROM workspaces @@ -346,7 +349,7 @@ async def hard_delete_workspace(self, workspace: Workspace) -> Optional[Workspac ) return deleted_workspace - async def recover_workspace(self, workspace: Workspace) -> Optional[Workspace]: + async def recover_workspace(self, workspace: WorkspaceRow) -> Optional[WorkspaceRow]: sql = text( """ UPDATE workspaces @@ -460,20 +463,20 @@ async def get_alerts_with_prompt_and_output( ) return prompts - async def get_workspaces(self) -> List[WorkspaceActive]: + async def get_workspaces(self) -> List[WorkspaceWithSessionInfo]: sql = text( """ SELECT - w.id, w.name, s.active_workspace_id + w.id, w.name, s.id as session_id FROM workspaces w LEFT JOIN sessions s ON w.id = s.active_workspace_id WHERE w.deleted_at IS NULL """ ) - workspaces = await self._execute_select_pydantic_model(WorkspaceActive, sql) + workspaces = await self._execute_select_pydantic_model(WorkspaceWithSessionInfo, sql) return workspaces - async def get_archived_workspaces(self) -> List[Workspace]: + async def get_archived_workspaces(self) -> List[WorkspaceRow]: sql = text( """ SELECT @@ -483,10 +486,10 @@ async def get_archived_workspaces(self) -> List[Workspace]: ORDER BY deleted_at DESC """ ) - workspaces = await self._execute_select_pydantic_model(Workspace, sql) + workspaces = await self._execute_select_pydantic_model(WorkspaceRow, sql) return workspaces - async def get_workspace_by_name(self, name: str) -> Optional[Workspace]: + async def get_workspace_by_name(self, name: str) -> Optional[WorkspaceRow]: sql = text( """ SELECT @@ -495,13 +498,13 @@ async def get_workspace_by_name(self, name: str) -> Optional[Workspace]: WHERE name = :name AND deleted_at IS NULL """ ) - conditions = {"name": name} + conditions = GetWorkspaceByNameConditions(name=name).get_conditions() workspaces = await self._exec_select_conditions_to_pydantic( - Workspace, sql, conditions, should_raise=True + WorkspaceRow, sql, conditions, should_raise=True ) return workspaces[0] if workspaces else None - async def get_archived_workspace_by_name(self, name: str) -> Optional[Workspace]: + async def get_archived_workspace_by_name(self, name: str) -> Optional[WorkspaceRow]: sql = text( """ SELECT @@ -510,9 +513,9 @@ async def get_archived_workspace_by_name(self, name: str) -> Optional[Workspace] WHERE name = :name AND deleted_at IS NOT NULL """ ) - conditions = {"name": name} + conditions = GetWorkspaceByNameConditions(name=name).get_conditions() workspaces = await self._exec_select_conditions_to_pydantic( - Workspace, sql, conditions, should_raise=True + WorkspaceRow, sql, conditions, should_raise=True ) return workspaces[0] if workspaces else None diff --git a/src/codegate/db/models.py b/src/codegate/db/models.py index 0ea07fbb..366712b0 100644 --- a/src/codegate/db/models.py +++ b/src/codegate/db/models.py @@ -30,7 +30,7 @@ class Prompt(BaseModel): workspace_id: Optional[str] -WorskpaceNameStr = Annotated[ +WorkspaceNameStr = Annotated[ str, StringConstraints( strip_whitespace=True, to_lower=True, pattern=r"^[a-zA-Z0-9_-]+$", strict=True @@ -38,12 +38,26 @@ class Prompt(BaseModel): ] -class Workspace(BaseModel): +class WorkspaceRow(BaseModel): + """A workspace row entry. + + Since our model currently includes instructions + in the same table, this is returned as a single + object. + """ + id: str - name: WorskpaceNameStr + name: WorkspaceNameStr custom_instructions: Optional[str] +class GetWorkspaceByNameConditions(BaseModel): + name: WorkspaceNameStr + + def get_conditions(self): + return {"name": self.name} + + class Session(BaseModel): id: str active_workspace_id: str @@ -81,15 +95,24 @@ class GetPromptWithOutputsRow(BaseModel): output_timestamp: Optional[Any] -class WorkspaceActive(BaseModel): +class WorkspaceWithSessionInfo(BaseModel): + """Returns a workspace ID with an optional + session ID. If the session ID is None, then + the workspace is not active. + """ + id: str - name: str - active_workspace_id: Optional[str] + name: WorkspaceNameStr + session_id: Optional[str] class ActiveWorkspace(BaseModel): + """Returns a full active workspace object with the + with the session information. + """ + id: str - name: str + name: WorkspaceNameStr custom_instructions: Optional[str] session_id: str last_update: datetime.datetime diff --git a/src/codegate/pipeline/cli/commands.py b/src/codegate/pipeline/cli/commands.py index 104f8dac..da52c42d 100644 --- a/src/codegate/pipeline/cli/commands.py +++ b/src/codegate/pipeline/cli/commands.py @@ -168,7 +168,7 @@ async def _list_workspaces(self, flags: Dict[str, str], args: List[str]) -> str: respond_str = "" for workspace in workspaces: respond_str += f"- {workspace.name}" - if workspace.active_workspace_id: + if workspace.session_id: respond_str += " **(active)**" respond_str += "\n" return respond_str diff --git a/src/codegate/workspaces/crud.py b/src/codegate/workspaces/crud.py index 7e0daa50..6d177522 100644 --- a/src/codegate/workspaces/crud.py +++ b/src/codegate/workspaces/crud.py @@ -2,7 +2,7 @@ from typing import List, Optional, Tuple from codegate.db.connection import DbReader, DbRecorder -from codegate.db.models import ActiveWorkspace, Session, Workspace, WorkspaceActive +from codegate.db.models import ActiveWorkspace, Session, WorkspaceRow, WorkspaceWithSessionInfo class WorkspaceCrudError(Exception): @@ -28,7 +28,7 @@ class WorkspaceCrud: def __init__(self): self._db_reader = DbReader() - async def add_workspace(self, new_workspace_name: str) -> Workspace: + async def add_workspace(self, new_workspace_name: str) -> WorkspaceRow: """ Add a workspace @@ -43,7 +43,9 @@ async def add_workspace(self, new_workspace_name: str) -> Workspace: workspace_created = await db_recorder.add_workspace(new_workspace_name) return workspace_created - async def rename_workspace(self, old_workspace_name: str, new_workspace_name: str) -> Workspace: + async def rename_workspace( + self, old_workspace_name: str, new_workspace_name: str + ) -> WorkspaceRow: """ Rename a workspace @@ -65,19 +67,19 @@ async def rename_workspace(self, old_workspace_name: str, new_workspace_name: st if not ws: raise WorkspaceDoesNotExistError(f"Workspace {old_workspace_name} does not exist.") db_recorder = DbRecorder() - new_ws = Workspace( + new_ws = WorkspaceRow( id=ws.id, name=new_workspace_name, custom_instructions=ws.custom_instructions ) workspace_renamed = await db_recorder.update_workspace(new_ws) return workspace_renamed - async def get_workspaces(self) -> List[WorkspaceActive]: + async def get_workspaces(self) -> List[WorkspaceWithSessionInfo]: """ Get all workspaces """ return await self._db_reader.get_workspaces() - async def get_archived_workspaces(self) -> List[Workspace]: + async def get_archived_workspaces(self) -> List[WorkspaceRow]: """ Get all archived workspaces """ @@ -91,7 +93,7 @@ async def get_active_workspace(self) -> Optional[ActiveWorkspace]: async def _is_workspace_active( self, workspace_name: str - ) -> Tuple[bool, Optional[Session], Optional[Workspace]]: + ) -> Tuple[bool, Optional[Session], Optional[WorkspaceRow]]: """ Check if the workspace is active alongside the session and workspace objects """ @@ -137,13 +139,13 @@ async def recover_workspace(self, workspace_name: str): async def update_workspace_custom_instructions( self, workspace_name: str, custom_instr_lst: List[str] - ) -> Workspace: + ) -> WorkspaceRow: selected_workspace = await self._db_reader.get_workspace_by_name(workspace_name) if not selected_workspace: raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.") custom_instructions = " ".join(custom_instr_lst) - workspace_update = Workspace( + workspace_update = WorkspaceRow( id=selected_workspace.id, name=selected_workspace.name, custom_instructions=custom_instructions, @@ -195,7 +197,7 @@ async def hard_delete_workspace(self, workspace_name: str): raise WorkspaceCrudError(f"Error deleting workspace {workspace_name}") return - async def get_workspace_by_name(self, workspace_name: str) -> Workspace: + async def get_workspace_by_name(self, workspace_name: str) -> WorkspaceRow: workspace = await self._db_reader.get_workspace_by_name(workspace_name) if not workspace: raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.") diff --git a/tests/pipeline/workspace/test_workspace.py b/tests/pipeline/workspace/test_workspace.py index 32bf5fc6..c1b8ae96 100644 --- a/tests/pipeline/workspace/test_workspace.py +++ b/tests/pipeline/workspace/test_workspace.py @@ -2,8 +2,8 @@ import pytest -from codegate.db.models import Workspace as WorkspaceModel -from codegate.db.models import WorkspaceActive +from codegate.db.models import WorkspaceRow as WorkspaceModel +from codegate.db.models import WorkspaceWithSessionInfo from codegate.pipeline.cli.commands import Workspace @@ -17,18 +17,18 @@ ( [ # We'll make a MagicMock that simulates a workspace - # with 'name' attribute and 'active_workspace_id' set - WorkspaceActive(id="1", name="Workspace1", active_workspace_id="100") + # with 'name' attribute and 'session_id' set + WorkspaceWithSessionInfo(id="1", name="Workspace1", session_id="100") ], - "- Workspace1 **(active)**\n", + "- workspace1 **(active)**\n", ), # Case 3: Multiple workspaces, second one active ( [ - WorkspaceActive(id="1", name="Workspace1", active_workspace_id=None), - WorkspaceActive(id="2", name="Workspace2", active_workspace_id="200"), + WorkspaceWithSessionInfo(id="1", name="Workspace1", session_id=None), + WorkspaceWithSessionInfo(id="2", name="Workspace2", session_id="200"), ], - "- Workspace1\n- Workspace2 **(active)**\n", + "- workspace1\n- workspace2 **(active)**\n", ), ], )