Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

feat: initial work on endpoints for creating/updating workspace config #1107

Merged
merged 23 commits into from
Mar 5, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
7cd8ec5
feat: initial work on endpoints for creating/updating
alex-mcgovern Feb 19, 2025
b0141ab
Merge branch 'main' of github.com:stacklok/codegate into feat/configu…
alex-mcgovern Feb 19, 2025
57ec255
fix: return newly created `FullWorkspace` from POST /api/v1/workspaces
alex-mcgovern Feb 19, 2025
cd83f56
formatting
alex-mcgovern Feb 19, 2025
efbfa64
test: create workspace with config happy path
alex-mcgovern Feb 19, 2025
7e91e5e
Merge branch 'main' of github.com:stacklok/codegate into feat/configu…
alex-mcgovern Feb 19, 2025
358f7e5
Merge branch 'main' of github.com:stacklok/codegate into feat/configu…
alex-mcgovern Feb 20, 2025
af7251d
1 db per test
alex-mcgovern Feb 20, 2025
b87c276
test: basic happy path test for create/update workspace config
alex-mcgovern Feb 21, 2025
033e16b
Merge branch 'main' of github.com:stacklok/codegate into feat/configu…
alex-mcgovern Feb 21, 2025
37f7e78
fix failing test
alex-mcgovern Feb 21, 2025
56cc0de
chore: fmt pass
alex-mcgovern Feb 21, 2025
1a69daf
fix: internal server error when no config passed
alex-mcgovern Feb 21, 2025
7de81b5
tidy up
alex-mcgovern Feb 21, 2025
36a9551
test: more integration tests
alex-mcgovern Mar 4, 2025
2a50d9f
chore: tidy ups
alex-mcgovern Mar 4, 2025
76bb197
Merge branch 'main' of github.com:stacklok/codegate into feat/configu…
alex-mcgovern Mar 4, 2025
3b4787d
chore: revert openapi changes
alex-mcgovern Mar 4, 2025
d329538
lint fixes
alex-mcgovern Mar 4, 2025
bb4f838
Merge branch 'main' into feat/configure-workspace-endpoint
alex-mcgovern Mar 4, 2025
e5eb6b4
remove manual rollbacks, ensure re-raising all exceptions
alex-mcgovern Mar 4, 2025
7bcab23
Merge branch 'main' of github.com:stacklok/codegate into feat/configu…
alex-mcgovern Mar 4, 2025
baa7aa6
Merge branch 'main' into feat/configure-workspace-endpoint
alex-mcgovern Mar 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 27 additions & 17 deletions src/codegate/api/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,19 +249,12 @@ async def activate_workspace(request: v1_models.ActivateWorkspaceRequest, status
@v1.post("/workspaces", tags=["Workspaces"], generate_unique_id_function=uniq_name, status_code=201)
async def create_workspace(
request: v1_models.CreateOrRenameWorkspaceRequest,
) -> v1_models.Workspace:
) -> v1_models.FullWorkspace:
"""Create a new workspace."""
if request.rename_to is not None:
return await rename_workspace(request)
return await create_new_workspace(request)


async def create_new_workspace(
request: v1_models.CreateOrRenameWorkspaceRequest,
) -> v1_models.Workspace:
# Input validation is done in the model
try:
_ = await wscrud.add_workspace(request.name)
_ = await wscrud.add_workspace(
request.name, request.config.custom_instructions, request.config.muxing_rules
)
except AlreadyExistsError:
raise HTTPException(status_code=409, detail="Workspace already exists")
except ValidationError:
Expand All @@ -280,15 +273,26 @@ async def create_new_workspace(
return v1_models.Workspace(name=request.name, is_active=False)


async def rename_workspace(
@v1.put(
"/workspaces/{workspace_name}",
tags=["Workspaces"],
generate_unique_id_function=uniq_name,
status_code=201,
)
async def update_workspace(
workspace_name: str,
request: v1_models.CreateOrRenameWorkspaceRequest,
) -> v1_models.Workspace:
) -> v1_models.FullWorkspace:
"""Update a workspace."""
try:
_ = await wscrud.rename_workspace(request.name, request.rename_to)
workspace_row, mux_rules = await wscrud.update_workspace(
workspace_name,
request.name,
request.config.custom_instructions,
request.config.muxing_rules,
)
except crud.WorkspaceDoesNotExistError:
raise HTTPException(status_code=404, detail="Workspace does not exist")
except AlreadyExistsError:
raise HTTPException(status_code=409, detail="Workspace already exists")
except ValidationError:
raise HTTPException(
status_code=400,
Expand All @@ -302,7 +306,13 @@ async def rename_workspace(
except Exception:
raise HTTPException(status_code=500, detail="Internal server error")

return v1_models.Workspace(name=request.rename_to, is_active=False)
return v1_models.FullWorkspace(
name=workspace_row.name,
config=v1_models.WorkspaceConfig(
custom_instructions=workspace_row.custom_instructions,
muxing_rules=[mux_models.MuxRule.try_from_db_model(mux_rule) for mux_rule in mux_rules],
),
)


@v1.delete(
Expand Down
2 changes: 1 addition & 1 deletion src/codegate/api/v1_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def from_db_workspaces(


class WorkspaceConfig(pydantic.BaseModel):
system_prompt: str
custom_instructions: str

muxing_rules: List[mux_models.MuxRule]

Expand Down
38 changes: 33 additions & 5 deletions src/codegate/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from sqlalchemy import CursorResult, TextClause, event, text
from sqlalchemy.engine import Engine
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker

from codegate.db.fim_cache import FimCache
from codegate.db.models import (
Expand Down Expand Up @@ -610,10 +611,10 @@ async def get_prompts_with_output_alerts_usage_by_workspace_id(
# If trigger category is None we want to get all alerts
trigger_category = trigger_category if trigger_category else "%"
conditions = {"workspace_id": workspace_id, "trigger_category": trigger_category}
rows: List[IntermediatePromptWithOutputUsageAlerts] = (
await self._exec_select_conditions_to_pydantic(
IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True
)
rows: List[
IntermediatePromptWithOutputUsageAlerts
] = await self._exec_select_conditions_to_pydantic(
IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True
)

prompts_dict: Dict[str, GetPromptWithOutputsRow] = {}
Expand Down Expand Up @@ -871,6 +872,33 @@ async def get_muxes_by_workspace(self, workspace_id: str) -> List[MuxRule]:
return muxes


class DbTransaction:
def __init__(self):
self._session = None

async def __aenter__(self):
self._session = sessionmaker(
bind=DbCodeGate()._async_db_engine,
class_=AsyncSession,
expire_on_commit=False,
)()
await self._session.begin()
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
if exc_type:
await self._session.rollback()
else:
await self._session.commit()
await self._session.close()

async def commit(self):
await self._session.commit()

async def rollback(self):
await self._session.rollback()


def init_db_sync(db_path: Optional[str] = None):
"""DB will be initialized in the constructor in case it doesn't exist."""
current_dir = Path(__file__).parent
Expand Down
14 changes: 14 additions & 0 deletions src/codegate/muxing/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pydantic

from codegate.clients.clients import ClientType
from codegate.db.models import MuxRule as DbMuxRule


class MuxMatcherType(str, Enum):
Expand Down Expand Up @@ -36,6 +37,19 @@ class MuxRule(pydantic.BaseModel):
# this depends on the matcher type.
matcher: Optional[str] = None

@classmethod
def try_from_db_model(cls, db_model: DbMuxRule) -> "MuxRule":
try:
return cls(
provider_name=db_model.provider_endpoint_name,
provider_id=db_model.provider_endpoint_id,
model=db_model.provider_model_name,
matcher_type=MuxMatcherType(db_model.matcher_type),
matcher=db_model.matcher_blob,
)
except Exception as e:
raise ValueError(f"Error converting from DbMuxRule: {e}")


class ThingToMatchMux(pydantic.BaseModel):
"""
Expand Down
5 changes: 1 addition & 4 deletions src/codegate/pipeline/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ def help(self) -> str:


class CodegateCommandSubcommand(CodegateCommand):

@property
@abstractmethod
def subcommands(self) -> Dict[str, Callable[[List[str]], Awaitable[str]]]:
Expand Down Expand Up @@ -174,7 +173,6 @@ async def run(self, args: List[str]) -> str:


class Workspace(CodegateCommandSubcommand):

def __init__(self):
self.workspace_crud = crud.WorkspaceCrud()

Expand Down Expand Up @@ -258,7 +256,7 @@ async def _rename_workspace(self, flags: Dict[str, str], args: List[str]) -> str
)

try:
await self.workspace_crud.rename_workspace(old_workspace_name, new_workspace_name)
await self.workspace_crud.update_workspace(old_workspace_name, new_workspace_name)
except crud.WorkspaceDoesNotExistError:
return f"Workspace **{old_workspace_name}** does not exist"
except AlreadyExistsError:
Expand Down Expand Up @@ -410,7 +408,6 @@ def help(self) -> str:


class CustomInstructions(CodegateCommandSubcommand):

def __init__(self):
self.workspace_crud = crud.WorkspaceCrud()

Expand Down
108 changes: 85 additions & 23 deletions src/codegate/workspaces/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from uuid import uuid4 as uuid

from codegate.db import models as db_models
from codegate.db.connection import DbReader, DbRecorder
from codegate.db.connection import DbReader, DbRecorder, DbTransaction
from codegate.muxing import models as mux_models
from codegate.muxing import rulematcher

Expand All @@ -16,6 +16,10 @@ class WorkspaceDoesNotExistError(WorkspaceCrudError):
pass


class WorkspaceNameAlreadyInUseError(WorkspaceCrudError):
pass


class WorkspaceAlreadyActiveError(WorkspaceCrudError):
pass

Expand All @@ -31,34 +35,61 @@ class WorkspaceMuxRuleDoesNotExistError(WorkspaceCrudError):


class WorkspaceCrud:

def __init__(self):
self._db_reader = DbReader()

async def add_workspace(self, new_workspace_name: str) -> db_models.WorkspaceRow:
async def add_workspace(
self,
new_workspace_name: str,
custom_instructions: Optional[str] = None,
muxing_rules: Optional[List[mux_models.MuxRule]] = None,
) -> db_models.WorkspaceRow:
"""
Add a workspace

Args:
name (str): The name of the workspace
new_workspace_name (str): The name of the workspace
system_prompt (Optional[str]): The system prompt for the workspace
muxing_rules (Optional[List[mux_models.MuxRule]]): The muxing rules for the workspace
"""
if new_workspace_name == "":
raise WorkspaceCrudError("Workspace name cannot be empty.")
if new_workspace_name in RESERVED_WORKSPACE_KEYWORDS:
raise WorkspaceCrudError(f"Workspace name {new_workspace_name} is reserved.")
db_recorder = DbRecorder()
workspace_created = await db_recorder.add_workspace(new_workspace_name)
return workspace_created

async def rename_workspace(
self, old_workspace_name: str, new_workspace_name: str
) -> db_models.WorkspaceRow:
async with DbTransaction() as transaction:
try:
db_recorder = DbRecorder()
workspace_created = await db_recorder.add_workspace(new_workspace_name)

if custom_instructions:
workspace_created.custom_instructions = custom_instructions
await db_recorder.update_workspace(workspace_created)

if muxing_rules:
await self.set_muxes(new_workspace_name, muxing_rules)

await transaction.commit()
return workspace_created
except Exception as e:
await transaction.rollback()
raise WorkspaceCrudError(f"Error adding workspace {new_workspace_name}: {str(e)}")

async def update_workspace(
self,
old_workspace_name: str,
new_workspace_name: str,
custom_instructions: Optional[str] = None,
muxing_rules: Optional[List[mux_models.MuxRule]] = None,
) -> Tuple[db_models.WorkspaceRow, List[db_models.MuxRule]]:
"""
Rename a workspace
Update a workspace

Args:
old_name (str): The old name of the workspace
new_name (str): The new name of the workspace
old_workspace_name (str): The old name of the workspace
new_workspace_name (str): The new name of the workspace
system_prompt (Optional[str]): The system prompt for the workspace
muxing_rules (Optional[List[mux_models.MuxRule]]): The muxing rules for the workspace
"""
if new_workspace_name == "":
raise WorkspaceCrudError("Workspace name cannot be empty.")
Expand All @@ -70,15 +101,40 @@ async def rename_workspace(
raise WorkspaceCrudError(f"Workspace name {new_workspace_name} is reserved.")
if old_workspace_name == new_workspace_name:
raise WorkspaceCrudError("Old and new workspace names are the same.")
ws = await self._db_reader.get_workspace_by_name(old_workspace_name)
if not ws:
raise WorkspaceDoesNotExistError(f"Workspace {old_workspace_name} does not exist.")
db_recorder = DbRecorder()
new_ws = db_models.WorkspaceRow(
id=ws.id, name=new_workspace_name, custom_instructions=ws.custom_instructions
)
workspace_renamed = await db_recorder.update_workspace(new_ws)
return workspace_renamed

async with DbTransaction() as transaction:
try:
ws = await self._db_reader.get_workspace_by_name(old_workspace_name)
if not ws:
raise WorkspaceDoesNotExistError(
f"Workspace {old_workspace_name} does not exist."
)

existing_ws = await self._db_reader.get_workspace_by_name(new_workspace_name)
if existing_ws:
raise WorkspaceNameAlreadyInUseError(
f"Workspace name {new_workspace_name} is already in use."
)

db_recorder = DbRecorder()
new_ws = db_models.WorkspaceRow(
id=ws.id, name=new_workspace_name, custom_instructions=ws.custom_instructions
)
workspace_renamed = await db_recorder.update_workspace(new_ws)

if custom_instructions:
workspace_renamed.custom_instructions = custom_instructions
await db_recorder.update_workspace(workspace_renamed)

mux_rules = []
if muxing_rules:
mux_rules = await self.set_muxes(new_workspace_name, muxing_rules)

await transaction.commit()
return workspace_renamed, mux_rules
except Exception as e:
await transaction.rollback()
raise WorkspaceCrudError(f"Error updating workspace {old_workspace_name}: {str(e)}")

async def get_workspaces(self) -> List[db_models.WorkspaceWithSessionInfo]:
"""
Expand Down Expand Up @@ -240,7 +296,9 @@ async def get_muxes(self, workspace_name: str) -> List[mux_models.MuxRule]:

return muxes

async def set_muxes(self, workspace_name: str, muxes: mux_models.MuxRule) -> None:
async def set_muxes(
self, workspace_name: str, muxes: List[mux_models.MuxRule]
) -> List[db_models.MuxRule]:
# Verify if workspace exists
workspace = await self._db_reader.get_workspace_by_name(workspace_name)
if not workspace:
Expand All @@ -261,6 +319,7 @@ async def set_muxes(self, workspace_name: str, muxes: mux_models.MuxRule) -> Non
muxes_with_routes.append((mux, route))

matchers: List[rulematcher.MuxingRuleMatcher] = []
dbmuxes: List[db_models.MuxRule] = []

for mux, route in muxes_with_routes:
new_mux = db_models.MuxRule(
Expand All @@ -273,6 +332,7 @@ async def set_muxes(self, workspace_name: str, muxes: mux_models.MuxRule) -> Non
priority=priority,
)
dbmux = await db_recorder.add_mux(new_mux)
dbmuxes.append(dbmux)

matchers.append(rulematcher.MuxingMatcherFactory.create(dbmux, route))

Expand All @@ -282,6 +342,8 @@ async def set_muxes(self, workspace_name: str, muxes: mux_models.MuxRule) -> Non
mux_registry = await rulematcher.get_muxing_rules_registry()
await mux_registry.set_ws_rules(workspace_name, matchers)

return dbmuxes

async def get_routing_for_mux(self, mux: mux_models.MuxRule) -> rulematcher.ModelRoute:
"""Get the routing for a mux

Expand Down
Loading