Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 39ea5bf

Browse files
Some fixes on Persona CRUD (#1241)
- Make all the API requests interface with PersonaManager class - Add get personas in PersonaManager class - Validate personas names
1 parent e763971 commit 39ea5bf

File tree

4 files changed

+101
-16
lines changed

4 files changed

+101
-16
lines changed

src/codegate/api/v1.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,7 @@ async def get_workspace_token_usage(workspace_name: str) -> v1_models.TokenUsage
677677
async def list_personas() -> List[Persona]:
678678
"""List all personas."""
679679
try:
680-
personas = await dbreader.get_all_personas()
680+
personas = await persona_manager.get_all_personas()
681681
return personas
682682
except Exception:
683683
logger.exception("Error while getting personas")
@@ -688,15 +688,11 @@ async def list_personas() -> List[Persona]:
688688
async def get_persona(persona_name: str) -> Persona:
689689
"""Get a persona by name."""
690690
try:
691-
persona = await dbreader.get_persona_by_name(persona_name)
692-
if not persona:
693-
raise HTTPException(status_code=404, detail=f"Persona {persona_name} not found")
691+
persona = await persona_manager.get_persona(persona_name)
694692
return persona
695-
except Exception as e:
696-
if isinstance(e, HTTPException):
697-
raise e
698-
logger.exception(f"Error while getting persona {persona_name}")
699-
raise HTTPException(status_code=500, detail="Internal server error")
693+
except PersonaDoesNotExistError:
694+
logger.exception("Error while getting persona")
695+
raise HTTPException(status_code=404, detail="Persona does not exist")
700696

701697

702698
@v1.post("/personas", tags=["Personas"], generate_unique_id_function=uniq_name, status_code=201)
@@ -712,6 +708,15 @@ async def create_persona(request: v1_models.PersonaRequest) -> Persona:
712708
except AlreadyExistsError:
713709
logger.exception("Error while creating persona")
714710
raise HTTPException(status_code=409, detail="Persona already exists")
711+
except ValidationError:
712+
logger.exception("Error while creating persona")
713+
raise HTTPException(
714+
status_code=400,
715+
detail=(
716+
"Persona has invalid name, check is alphanumeric "
717+
"and only contains dashes and underscores"
718+
),
719+
)
715720
except Exception:
716721
logger.exception("Error while creating persona")
717722
raise HTTPException(status_code=500, detail="Internal server error")
@@ -735,6 +740,15 @@ async def update_persona(persona_name: str, request: v1_models.PersonaUpdateRequ
735740
except AlreadyExistsError:
736741
logger.exception("Error while updating persona")
737742
raise HTTPException(status_code=409, detail="Persona already exists")
743+
except ValidationError:
744+
logger.exception("Error while creating persona")
745+
raise HTTPException(
746+
status_code=400,
747+
detail=(
748+
"Persona has invalid name, check is alphanumeric "
749+
"and only contains dashes and underscores"
750+
),
751+
)
738752
except Exception:
739753
logger.exception("Error while updating persona")
740754
raise HTTPException(status_code=500, detail="Internal server error")

src/codegate/db/models.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,15 @@
33
from typing import Annotated, Any, Dict, List, Optional
44

55
import numpy as np
6-
from pydantic import BaseModel, BeforeValidator, ConfigDict, PlainSerializer, StringConstraints
6+
import regex as re
7+
from pydantic import (
8+
BaseModel,
9+
BeforeValidator,
10+
ConfigDict,
11+
PlainSerializer,
12+
StringConstraints,
13+
field_validator,
14+
)
715

816

917
class AlertSeverity(str, Enum):
@@ -266,6 +274,8 @@ def nd_array_custom_serializer(x):
266274
PlainSerializer(nd_array_custom_serializer, return_type=str),
267275
]
268276

277+
VALID_PERSONA_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_ -]+$")
278+
269279

270280
class Persona(BaseModel):
271281
"""
@@ -276,6 +286,15 @@ class Persona(BaseModel):
276286
name: str
277287
description: str
278288

289+
@field_validator("name", mode="after")
290+
@classmethod
291+
def validate_persona_name(cls, value: str) -> str:
292+
if VALID_PERSONA_NAME_PATTERN.match(value):
293+
return value
294+
raise ValueError(
295+
"Invalid persona name. It should be alphanumeric with underscores and dashes."
296+
)
297+
279298

280299
class PersonaEmbedding(Persona):
281300
"""

src/codegate/muxing/persona.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import unicodedata
22
import uuid
3-
from typing import Optional
3+
from typing import List, Optional
44

55
import numpy as np
66
import regex as re
@@ -165,6 +165,21 @@ async def add_persona(self, persona_name: str, persona_desc: str) -> None:
165165
await self._db_recorder.add_persona(new_persona)
166166
logger.info(f"Added persona {persona_name} to the database.")
167167

168+
async def get_persona(self, persona_name: str) -> db_models.Persona:
169+
"""
170+
Get a persona from the database by name.
171+
"""
172+
persona = await self._db_reader.get_persona_by_name(persona_name)
173+
if not persona:
174+
raise PersonaDoesNotExistError(f"Persona {persona_name} does not exist.")
175+
return persona
176+
177+
async def get_all_personas(self) -> List[db_models.Persona]:
178+
"""
179+
Get all personas from the database.
180+
"""
181+
return await self._db_reader.get_all_personas()
182+
168183
async def update_persona(
169184
self, persona_name: str, new_persona_name: str, new_persona_desc: str
170185
) -> None:

tests/muxing/test_persona.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import List
44

55
import pytest
6-
from pydantic import BaseModel
6+
from pydantic import BaseModel, ValidationError
77

88
from codegate.db import connection
99
from codegate.muxing.persona import (
@@ -54,7 +54,7 @@ async def test_add_persona(semantic_router_mocked_db: PersonaManager):
5454
persona_name = "test_persona"
5555
persona_desc = "test_persona_desc"
5656
await semantic_router_mocked_db.add_persona(persona_name, persona_desc)
57-
retrieved_persona = await semantic_router_mocked_db._db_reader.get_persona_by_name(persona_name)
57+
retrieved_persona = await semantic_router_mocked_db.get_persona(persona_name)
5858
assert retrieved_persona.name == persona_name
5959
assert retrieved_persona.description == persona_desc
6060

@@ -72,6 +72,18 @@ async def test_add_duplicate_persona(semantic_router_mocked_db: PersonaManager):
7272
await semantic_router_mocked_db.add_persona(persona_name, updated_description)
7373

7474

75+
@pytest.mark.asyncio
76+
async def test_add_persona_invalid_name(semantic_router_mocked_db: PersonaManager):
77+
"""Test adding a persona to the database."""
78+
persona_name = "test_persona&"
79+
persona_desc = "test_persona_desc"
80+
with pytest.raises(ValidationError):
81+
await semantic_router_mocked_db.add_persona(persona_name, persona_desc)
82+
83+
with pytest.raises(PersonaDoesNotExistError):
84+
await semantic_router_mocked_db.delete_persona(persona_name)
85+
86+
7587
@pytest.mark.asyncio
7688
async def test_persona_not_exist_match(semantic_router_mocked_db: PersonaManager):
7789
"""Test checking persona match when persona does not exist"""
@@ -235,7 +247,7 @@ class PersonaMatchTest(BaseModel):
235247

236248
# DevOps/SRE Engineer Persona
237249
devops_sre = PersonaMatchTest(
238-
persona_name="devops/sre engineer",
250+
persona_name="devops sre engineer",
239251
persona_desc="""
240252
Expert in infrastructure automation, deployment pipelines, and operational reliability.
241253
Specializes in building and maintaining scalable, resilient, and secure infrastructure.
@@ -441,8 +453,8 @@ async def test_delete_persona(semantic_router_mocked_db: PersonaManager):
441453

442454
await semantic_router_mocked_db.delete_persona(persona_name)
443455

444-
persona_found = await semantic_router_mocked_db._db_reader.get_persona_by_name(persona_name)
445-
assert persona_found is None
456+
with pytest.raises(PersonaDoesNotExistError):
457+
await semantic_router_mocked_db.get_persona(persona_name)
446458

447459

448460
@pytest.mark.asyncio
@@ -451,3 +463,28 @@ async def test_delete_persona_not_exists(semantic_router_mocked_db: PersonaManag
451463

452464
with pytest.raises(PersonaDoesNotExistError):
453465
await semantic_router_mocked_db.delete_persona(persona_name)
466+
467+
468+
@pytest.mark.asyncio
469+
async def test_get_personas(semantic_router_mocked_db: PersonaManager):
470+
"""Test getting personas from the database."""
471+
persona_name = "test_persona"
472+
persona_desc = "test_persona_desc"
473+
await semantic_router_mocked_db.add_persona(persona_name, persona_desc)
474+
475+
persona_name_2 = "test_persona_2"
476+
persona_desc_2 = "foo and bar"
477+
await semantic_router_mocked_db.add_persona(persona_name_2, persona_desc_2)
478+
479+
all_personas = await semantic_router_mocked_db.get_all_personas()
480+
assert len(all_personas) == 2
481+
assert all_personas[0].name == persona_name
482+
assert all_personas[1].name == persona_name_2
483+
484+
485+
@pytest.mark.asyncio
486+
async def test_get_personas_empty(semantic_router_mocked_db: PersonaManager):
487+
"""Test adding a persona to the database."""
488+
489+
all_personas = await semantic_router_mocked_db.get_all_personas()
490+
assert len(all_personas) == 0

0 commit comments

Comments
 (0)