Skip to content

Allow workspace operations to work without caring about the case #736

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/codegate/api/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ async def list_workspaces() -> v1_models.ListWorkspacesResponse:
"""List all workspaces."""
wslist = await wscrud.get_workspaces()

resp = v1_models.ListWorkspacesResponse.from_db_workspaces_active(wslist)
resp = v1_models.ListWorkspacesResponse.from_db_workspaces_with_sessioninfo(wslist)

return resp

Expand Down
9 changes: 4 additions & 5 deletions src/codegate/api/v1_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,18 @@ class ListWorkspacesResponse(pydantic.BaseModel):
workspaces: list[Workspace]

@classmethod
def from_db_workspaces_active(
cls, db_workspaces: List[db_models.WorkspaceActive]
def from_db_workspaces_with_sessioninfo(
cls, db_workspaces: List[db_models.WorkspaceWithSessionInfo]
) -> "ListWorkspacesResponse":
return cls(
workspaces=[
Workspace(name=ws.name, is_active=ws.active_workspace_id is not None)
for ws in db_workspaces
Workspace(name=ws.name, is_active=ws.session_id is not None) for ws in db_workspaces
]
)

@classmethod
def from_db_workspaces(
cls, db_workspaces: List[db_models.Workspace]
cls, db_workspaces: List[db_models.WorkspaceRow]
) -> "ListWorkspacesResponse":
return cls(workspaces=[Workspace(name=ws.name, is_active=False) for ws in db_workspaces])

Expand Down
41 changes: 22 additions & 19 deletions src/codegate/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@
Alert,
GetAlertsWithPromptAndOutputRow,
GetPromptWithOutputsRow,
GetWorkspaceByNameConditions,
Output,
Prompt,
Session,
Workspace,
WorkspaceActive,
WorkspaceRow,
WorkspaceWithSessionInfo,
)
from codegate.pipeline.base import PipelineContext

Expand Down Expand Up @@ -263,15 +264,17 @@ async def record_context(self, context: Optional[PipelineContext]) -> None:
except Exception as e:
logger.error(f"Failed to record context: {context}.", error=str(e))

async def add_workspace(self, workspace_name: str) -> Workspace:
async def add_workspace(self, workspace_name: str) -> WorkspaceRow:
"""Add a new workspace to the DB.

This handles validation and insertion of a new workspace.

It may raise a ValidationError if the workspace name is invalid.
or a AlreadyExistsError if the workspace already exists.
"""
workspace = Workspace(id=str(uuid.uuid4()), name=workspace_name, custom_instructions=None)
workspace = WorkspaceRow(
id=str(uuid.uuid4()), name=workspace_name, custom_instructions=None
)
sql = text(
"""
INSERT INTO workspaces (id, name)
Expand All @@ -289,7 +292,7 @@ async def add_workspace(self, workspace_name: str) -> Workspace:
raise AlreadyExistsError(f"Workspace {workspace_name} already exists.")
return added_workspace

async def update_workspace(self, workspace: Workspace) -> Workspace:
async def update_workspace(self, workspace: WorkspaceRow) -> WorkspaceRow:
sql = text(
"""
UPDATE workspaces SET
Expand Down Expand Up @@ -319,7 +322,7 @@ async def update_session(self, session: Session) -> Optional[Session]:
active_session = await self._execute_update_pydantic_model(session, sql, should_raise=True)
return active_session

async def soft_delete_workspace(self, workspace: Workspace) -> Optional[Workspace]:
async def soft_delete_workspace(self, workspace: WorkspaceRow) -> Optional[WorkspaceRow]:
sql = text(
"""
UPDATE workspaces
Expand All @@ -333,7 +336,7 @@ async def soft_delete_workspace(self, workspace: Workspace) -> Optional[Workspac
)
return deleted_workspace

async def hard_delete_workspace(self, workspace: Workspace) -> Optional[Workspace]:
async def hard_delete_workspace(self, workspace: WorkspaceRow) -> Optional[WorkspaceRow]:
sql = text(
"""
DELETE FROM workspaces
Expand All @@ -346,7 +349,7 @@ async def hard_delete_workspace(self, workspace: Workspace) -> Optional[Workspac
)
return deleted_workspace

async def recover_workspace(self, workspace: Workspace) -> Optional[Workspace]:
async def recover_workspace(self, workspace: WorkspaceRow) -> Optional[WorkspaceRow]:
sql = text(
"""
UPDATE workspaces
Expand Down Expand Up @@ -460,20 +463,20 @@ async def get_alerts_with_prompt_and_output(
)
return prompts

async def get_workspaces(self) -> List[WorkspaceActive]:
async def get_workspaces(self) -> List[WorkspaceWithSessionInfo]:
sql = text(
"""
SELECT
w.id, w.name, s.active_workspace_id
w.id, w.name, s.id as session_id
FROM workspaces w
LEFT JOIN sessions s ON w.id = s.active_workspace_id
WHERE w.deleted_at IS NULL
"""
)
workspaces = await self._execute_select_pydantic_model(WorkspaceActive, sql)
workspaces = await self._execute_select_pydantic_model(WorkspaceWithSessionInfo, sql)
return workspaces

async def get_archived_workspaces(self) -> List[Workspace]:
async def get_archived_workspaces(self) -> List[WorkspaceRow]:
sql = text(
"""
SELECT
Expand All @@ -483,10 +486,10 @@ async def get_archived_workspaces(self) -> List[Workspace]:
ORDER BY deleted_at DESC
"""
)
workspaces = await self._execute_select_pydantic_model(Workspace, sql)
workspaces = await self._execute_select_pydantic_model(WorkspaceRow, sql)
return workspaces

async def get_workspace_by_name(self, name: str) -> Optional[Workspace]:
async def get_workspace_by_name(self, name: str) -> Optional[WorkspaceRow]:
sql = text(
"""
SELECT
Expand All @@ -495,13 +498,13 @@ async def get_workspace_by_name(self, name: str) -> Optional[Workspace]:
WHERE name = :name AND deleted_at IS NULL
"""
)
conditions = {"name": name}
conditions = GetWorkspaceByNameConditions(name=name).get_conditions()
workspaces = await self._exec_select_conditions_to_pydantic(
Workspace, sql, conditions, should_raise=True
WorkspaceRow, sql, conditions, should_raise=True
)
return workspaces[0] if workspaces else None

async def get_archived_workspace_by_name(self, name: str) -> Optional[Workspace]:
async def get_archived_workspace_by_name(self, name: str) -> Optional[WorkspaceRow]:
sql = text(
"""
SELECT
Expand All @@ -510,9 +513,9 @@ async def get_archived_workspace_by_name(self, name: str) -> Optional[Workspace]
WHERE name = :name AND deleted_at IS NOT NULL
"""
)
conditions = {"name": name}
conditions = GetWorkspaceByNameConditions(name=name).get_conditions()
workspaces = await self._exec_select_conditions_to_pydantic(
Workspace, sql, conditions, should_raise=True
WorkspaceRow, sql, conditions, should_raise=True
)
return workspaces[0] if workspaces else None

Expand Down
37 changes: 30 additions & 7 deletions src/codegate/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,34 @@ class Prompt(BaseModel):
workspace_id: Optional[str]


WorskpaceNameStr = Annotated[
WorkspaceNameStr = Annotated[
str,
StringConstraints(
strip_whitespace=True, to_lower=True, pattern=r"^[a-zA-Z0-9_-]+$", strict=True
),
]


class Workspace(BaseModel):
class WorkspaceRow(BaseModel):
"""A workspace row entry.

Since our model currently includes instructions
in the same table, this is returned as a single
object.
"""

id: str
name: WorskpaceNameStr
name: WorkspaceNameStr
custom_instructions: Optional[str]


class GetWorkspaceByNameConditions(BaseModel):
name: WorkspaceNameStr

def get_conditions(self):
return {"name": self.name}


class Session(BaseModel):
id: str
active_workspace_id: str
Expand Down Expand Up @@ -81,15 +95,24 @@ class GetPromptWithOutputsRow(BaseModel):
output_timestamp: Optional[Any]


class WorkspaceActive(BaseModel):
class WorkspaceWithSessionInfo(BaseModel):
"""Returns a workspace ID with an optional
session ID. If the session ID is None, then
the workspace is not active.
"""

id: str
name: str
active_workspace_id: Optional[str]
name: WorkspaceNameStr
session_id: Optional[str]


class ActiveWorkspace(BaseModel):
"""Returns a full active workspace object with the
with the session information.
"""

id: str
name: str
name: WorkspaceNameStr
custom_instructions: Optional[str]
session_id: str
last_update: datetime.datetime
2 changes: 1 addition & 1 deletion src/codegate/pipeline/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ async def _list_workspaces(self, flags: Dict[str, str], args: List[str]) -> str:
respond_str = ""
for workspace in workspaces:
respond_str += f"- {workspace.name}"
if workspace.active_workspace_id:
if workspace.session_id:
respond_str += " **(active)**"
respond_str += "\n"
return respond_str
Expand Down
22 changes: 12 additions & 10 deletions src/codegate/workspaces/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import List, Optional, Tuple

from codegate.db.connection import DbReader, DbRecorder
from codegate.db.models import ActiveWorkspace, Session, Workspace, WorkspaceActive
from codegate.db.models import ActiveWorkspace, Session, WorkspaceRow, WorkspaceWithSessionInfo


class WorkspaceCrudError(Exception):
Expand All @@ -28,7 +28,7 @@ class WorkspaceCrud:
def __init__(self):
self._db_reader = DbReader()

async def add_workspace(self, new_workspace_name: str) -> Workspace:
async def add_workspace(self, new_workspace_name: str) -> WorkspaceRow:
"""
Add a workspace

Expand All @@ -43,7 +43,9 @@ async def add_workspace(self, new_workspace_name: str) -> Workspace:
workspace_created = await db_recorder.add_workspace(new_workspace_name)
return workspace_created

async def rename_workspace(self, old_workspace_name: str, new_workspace_name: str) -> Workspace:
async def rename_workspace(
self, old_workspace_name: str, new_workspace_name: str
) -> WorkspaceRow:
"""
Rename a workspace

Expand All @@ -65,19 +67,19 @@ async def rename_workspace(self, old_workspace_name: str, new_workspace_name: st
if not ws:
raise WorkspaceDoesNotExistError(f"Workspace {old_workspace_name} does not exist.")
db_recorder = DbRecorder()
new_ws = Workspace(
new_ws = WorkspaceRow(
id=ws.id, name=new_workspace_name, custom_instructions=ws.custom_instructions
)
workspace_renamed = await db_recorder.update_workspace(new_ws)
return workspace_renamed

async def get_workspaces(self) -> List[WorkspaceActive]:
async def get_workspaces(self) -> List[WorkspaceWithSessionInfo]:
"""
Get all workspaces
"""
return await self._db_reader.get_workspaces()

async def get_archived_workspaces(self) -> List[Workspace]:
async def get_archived_workspaces(self) -> List[WorkspaceRow]:
"""
Get all archived workspaces
"""
Expand All @@ -91,7 +93,7 @@ async def get_active_workspace(self) -> Optional[ActiveWorkspace]:

async def _is_workspace_active(
self, workspace_name: str
) -> Tuple[bool, Optional[Session], Optional[Workspace]]:
) -> Tuple[bool, Optional[Session], Optional[WorkspaceRow]]:
"""
Check if the workspace is active alongside the session and workspace objects
"""
Expand Down Expand Up @@ -137,13 +139,13 @@ async def recover_workspace(self, workspace_name: str):

async def update_workspace_custom_instructions(
self, workspace_name: str, custom_instr_lst: List[str]
) -> Workspace:
) -> WorkspaceRow:
selected_workspace = await self._db_reader.get_workspace_by_name(workspace_name)
if not selected_workspace:
raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.")

custom_instructions = " ".join(custom_instr_lst)
workspace_update = Workspace(
workspace_update = WorkspaceRow(
id=selected_workspace.id,
name=selected_workspace.name,
custom_instructions=custom_instructions,
Expand Down Expand Up @@ -195,7 +197,7 @@ async def hard_delete_workspace(self, workspace_name: str):
raise WorkspaceCrudError(f"Error deleting workspace {workspace_name}")
return

async def get_workspace_by_name(self, workspace_name: str) -> Workspace:
async def get_workspace_by_name(self, workspace_name: str) -> WorkspaceRow:
workspace = await self._db_reader.get_workspace_by_name(workspace_name)
if not workspace:
raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.")
Expand Down
16 changes: 8 additions & 8 deletions tests/pipeline/workspace/test_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import pytest

from codegate.db.models import Workspace as WorkspaceModel
from codegate.db.models import WorkspaceActive
from codegate.db.models import WorkspaceRow as WorkspaceModel
from codegate.db.models import WorkspaceWithSessionInfo
from codegate.pipeline.cli.commands import Workspace


Expand All @@ -17,18 +17,18 @@
(
[
# We'll make a MagicMock that simulates a workspace
# with 'name' attribute and 'active_workspace_id' set
WorkspaceActive(id="1", name="Workspace1", active_workspace_id="100")
# with 'name' attribute and 'session_id' set
WorkspaceWithSessionInfo(id="1", name="Workspace1", session_id="100")
],
"- Workspace1 **(active)**\n",
"- workspace1 **(active)**\n",
),
# Case 3: Multiple workspaces, second one active
(
[
WorkspaceActive(id="1", name="Workspace1", active_workspace_id=None),
WorkspaceActive(id="2", name="Workspace2", active_workspace_id="200"),
WorkspaceWithSessionInfo(id="1", name="Workspace1", session_id=None),
WorkspaceWithSessionInfo(id="2", name="Workspace2", session_id="200"),
],
"- Workspace1\n- Workspace2 **(active)**\n",
"- workspace1\n- workspace2 **(active)**\n",
),
],
)
Expand Down
Loading