Skip to content

Created necessary methods for Persona CRUD #1232

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 93 additions & 1 deletion src/codegate/api/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
from codegate import __version__
from codegate.api import v1_models, v1_processing
from codegate.db.connection import AlreadyExistsError, DbReader
from codegate.db.models import AlertSeverity, WorkspaceWithModel
from codegate.db.models import AlertSeverity, Persona, WorkspaceWithModel
from codegate.muxing.persona import (
PersonaDoesNotExistError,
PersonaManager,
PersonaSimilarDescriptionError,
)
from codegate.providers import crud as provendcrud
from codegate.workspaces import crud

Expand All @@ -21,6 +26,7 @@
v1 = APIRouter()
wscrud = crud.WorkspaceCrud()
pcrud = provendcrud.ProviderCrud()
persona_manager = PersonaManager()

# This is a singleton object
dbreader = DbReader()
Expand Down Expand Up @@ -665,3 +671,89 @@ async def get_workspace_token_usage(workspace_name: str) -> v1_models.TokenUsage
except Exception:
logger.exception("Error while getting messages")
raise HTTPException(status_code=500, detail="Internal server error")


@v1.get("/personas", tags=["Personas"], generate_unique_id_function=uniq_name)
async def list_personas() -> List[Persona]:
"""List all personas."""
try:
personas = await dbreader.get_all_personas()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not leverage the persona manager here to keep the code consistent? Calling the dbreader directly here seems like an anti-pattern.

return personas
except Exception:
logger.exception("Error while getting personas")
raise HTTPException(status_code=500, detail="Internal server error")


@v1.get("/personas/{persona_name}", tags=["Personas"], generate_unique_id_function=uniq_name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use an annotated object to do input validation on persona_name? I'd like to restrict the characters that folks can use here to prevent vulnerabilities.

async def get_persona(persona_name: str) -> Persona:
"""Get a persona by name."""
try:
persona = await dbreader.get_persona_by_name(persona_name)
if not persona:
raise HTTPException(status_code=404, detail=f"Persona {persona_name} not found")
return persona
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we be catching the exception when a persona doesn't exist?

except Exception as e:
if isinstance(e, HTTPException):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the db reader be raising an HTTP exception? This sounds off.

raise e
logger.exception(f"Error while getting persona {persona_name}")
raise HTTPException(status_code=500, detail="Internal server error")


@v1.post("/personas", tags=["Personas"], generate_unique_id_function=uniq_name, status_code=201)
async def create_persona(request: v1_models.PersonaRequest) -> Persona:
    """
    Register a new persona and return the stored record.

    The persona manager adds the persona; it is then read back by name so
    the response reflects what the database holds. Responds with 409 when
    the description is too similar to another persona or the persona
    already exists, and with 500 for any other failure.
    """
    try:
        await persona_manager.add_persona(request.name, request.description)
        return await dbreader.get_persona_by_name(request.name)
    except (PersonaSimilarDescriptionError, AlreadyExistsError) as conflict:
        # Both known failures are conflicts; only the detail text differs.
        logger.exception("Error while creating persona")
        if isinstance(conflict, PersonaSimilarDescriptionError):
            detail = "Persona has a similar description to another"
        else:
            detail = "Persona already exists"
        raise HTTPException(status_code=409, detail=detail)
    except Exception:
        logger.exception("Error while creating persona")
        raise HTTPException(status_code=500, detail="Internal server error")


@v1.put("/personas/{persona_name}", tags=["Personas"], generate_unique_id_function=uniq_name)
async def update_persona(persona_name: str, request: v1_models.PersonaUpdateRequest) -> Persona:
    """
    Rename and/or re-describe the persona called ``persona_name``.

    After the persona manager applies the update, the persona is read back
    under its new name and returned. Responds with 404 when the persona does
    not exist, 409 when the new description is too similar to another
    persona or the persona already exists, and 500 for any other failure.
    """
    known_errors = (
        PersonaSimilarDescriptionError,
        PersonaDoesNotExistError,
        AlreadyExistsError,
    )
    try:
        await persona_manager.update_persona(
            persona_name, request.new_name, request.new_description
        )
        return await dbreader.get_persona_by_name(request.new_name)
    except known_errors as known:
        logger.exception("Error while updating persona")
        # Checked in the same precedence order as separate except clauses would be.
        if isinstance(known, PersonaSimilarDescriptionError):
            raise HTTPException(
                status_code=409, detail="Persona has a similar description to another"
            )
        if isinstance(known, PersonaDoesNotExistError):
            raise HTTPException(status_code=404, detail="Persona does not exist")
        raise HTTPException(status_code=409, detail="Persona already exists")
    except Exception:
        logger.exception("Error while updating persona")
        raise HTTPException(status_code=500, detail="Internal server error")


@v1.delete(
    "/personas/{persona_name}",
    tags=["Personas"],
    generate_unique_id_function=uniq_name,
    status_code=204,
)
async def delete_persona(persona_name: str):
    """
    Delete the persona called ``persona_name``.

    Returns an empty 204 response on success. Responds with 404 when the
    persona manager reports that the persona does not exist, and with 500
    for any other failure.
    """
    try:
        await persona_manager.delete_persona(persona_name)
        return Response(status_code=204)
    except PersonaDoesNotExistError:
        # Log message names the delete operation so log searches for
        # update failures do not pick up failed deletions.
        logger.exception("Error while deleting persona")
        raise HTTPException(status_code=404, detail="Persona does not exist")
    except Exception:
        logger.exception("Error while deleting persona")
        raise HTTPException(status_code=500, detail="Internal server error")
18 changes: 18 additions & 0 deletions src/codegate/api/v1_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,3 +315,21 @@ class ModelByProvider(pydantic.BaseModel):

def __str__(self):
return f"{self.provider_name} / {self.name}"


class PersonaRequest(pydantic.BaseModel):
    """
    Model for creating a new Persona.

    Request body of the persona creation endpoint, which forwards both
    fields to the persona manager.
    """

    # Name of the persona to create.
    name: str
    # Description text of the persona; the persona manager embeds it
    # when the persona is added.
    description: str


class PersonaUpdateRequest(pydantic.BaseModel):
    """
    Model for updating a Persona.

    Request body of the persona update endpoint. The persona to update is
    identified by the name in the URL path; both fields here are required
    and hold the values to store.
    """

    # Name the persona should have after the update.
    new_name: str
    # Description the persona should have after the update.
    new_description: str
67 changes: 60 additions & 7 deletions src/codegate/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,15 +561,41 @@ async def add_persona(self, persona: PersonaEmbedding) -> None:
)

try:
# For Pydantic we convert the numpy array to string when serializing with .model_dumpy()
# We need to convert it back to a numpy array before inserting it into the DB.
persona_dict = persona.model_dump()
persona_dict["description_embedding"] = persona.description_embedding
await self._execute_with_no_return(sql, persona_dict)
await self._execute_with_no_return(sql, persona.model_dump())
except IntegrityError as e:
logger.debug(f"Exception type: {type(e)}")
raise AlreadyExistsError(f"Persona '{persona.name}' already exists.")

async def update_persona(self, persona: PersonaEmbedding) -> None:
    """
    Overwrite the stored row of an existing Persona.

    The row is matched on ``persona.id``; its name, description and
    description embedding are replaced with the values carried by ``persona``.

    Raises:
        AlreadyExistsError: when the database reports an integrity error
            for the update.
    """
    update_stmt = text(
        """
        UPDATE personas
        SET name = :name,
            description = :description,
            description_embedding = :description_embedding
        WHERE id = :id
        """
    )
    params = persona.model_dump()

    try:
        await self._execute_with_no_return(update_stmt, params)
    except IntegrityError as err:
        logger.debug(f"Exception type: {type(err)}")
        raise AlreadyExistsError(f"Persona '{persona.name}' already exists.")

async def delete_persona(self, persona_id: str) -> None:
    """
    Remove the Persona row whose ``id`` equals ``persona_id``.
    """
    delete_stmt = text("DELETE FROM personas WHERE id = :id")
    await self._execute_with_no_return(delete_stmt, {"id": persona_id})


class DbReader(DbCodeGate):
def __init__(self, sqlite_path: Optional[str] = None, *args, **kwargs):
Expand All @@ -588,14 +614,20 @@ async def _dump_result_to_pydantic_model(
return None

async def _execute_select_pydantic_model(
self, model_type: Type[BaseModel], sql_command: TextClause
self,
model_type: Type[BaseModel],
sql_command: TextClause,
should_raise: bool = False,
) -> Optional[List[BaseModel]]:
async with self._async_db_engine.begin() as conn:
try:
result = await conn.execute(sql_command)
return await self._dump_result_to_pydantic_model(model_type, result)
except Exception as e:
logger.error(f"Failed to select model: {model_type}.", error=str(e))
# Exposes errors to the caller
if should_raise:
raise e
return None

async def _exec_select_conditions_to_pydantic(
Expand Down Expand Up @@ -1005,7 +1037,7 @@ async def get_persona_by_name(self, persona_name: str) -> Optional[Persona]:
return personas[0] if personas else None

async def get_distance_to_existing_personas(
self, query_embedding: np.ndarray
self, query_embedding: np.ndarray, exclude_id: Optional[str]
) -> List[PersonaDistance]:
"""
Get the distance between a persona and a query embedding.
Expand All @@ -1019,6 +1051,13 @@ async def get_distance_to_existing_personas(
FROM personas
"""
conditions = {"query_embedding": query_embedding}

# Exclude this persona from the SQL query. Used when checking the descriptions
# for updating the persona. Exclude the persona to update itself from the query.
if exclude_id:
sql += " WHERE id != :exclude_id"
conditions["exclude_id"] = exclude_id

persona_distances = await self._exec_vec_db_query_to_pydantic(
sql, conditions, PersonaDistance
)
Expand All @@ -1045,6 +1084,20 @@ async def get_distance_to_persona(
)
return persona_distance[0]

async def get_all_personas(self) -> List[Persona]:
    """
    Fetch the id, name and description of every persona.

    The select helper is called with ``should_raise=True`` so that query
    errors are raised to the caller instead of being turned into ``None``.
    """
    query = text(
        """
        SELECT
            id, name, description
        FROM personas
        """
    )
    return await self._execute_select_pydantic_model(Persona, query, should_raise=True)


class DbTransaction:
def __init__(self):
Expand Down
2 changes: 1 addition & 1 deletion src/codegate/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def nd_array_custom_before_validator(x):

def nd_array_custom_serializer(x):
# custom serialization logic
return str(x)
return x


# Pydantic doesn't support numpy arrays out of the box hence we need to construct a custom type.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unicodedata
import uuid
from typing import Optional

import numpy as np
import regex as re
Expand Down Expand Up @@ -32,11 +33,12 @@ class PersonaSimilarDescriptionError(Exception):
pass


class SemanticRouter:
class PersonaManager:

def __init__(self):
self._inference_engine = LlamaCppInferenceEngine()
Config.load()
conf = Config.get_config()
self._inference_engine = LlamaCppInferenceEngine()
self._embeddings_model = f"{conf.model_base_path}/{conf.embedding_model}"
self._n_gpu = conf.chat_model_n_gpu_layers
self._persona_threshold = conf.persona_threshold
Expand Down Expand Up @@ -110,13 +112,15 @@ async def _embed_text(self, text: str) -> np.ndarray:
logger.debug("Text embedded in semantic routing", text=cleaned_text[:50])
return np.array(embed_list[0], dtype=np.float32)

async def _is_persona_description_diff(self, emb_persona_desc: np.ndarray) -> bool:
async def _is_persona_description_diff(
self, emb_persona_desc: np.ndarray, exclude_id: Optional[str]
) -> bool:
"""
Check if the persona description is different enough from existing personas.
"""
# The distance calculation is done in the database
persona_distances = await self._db_reader.get_distance_to_existing_personas(
emb_persona_desc
emb_persona_desc, exclude_id
)
if not persona_distances:
return True
Expand All @@ -131,16 +135,26 @@ async def _is_persona_description_diff(self, emb_persona_desc: np.ndarray) -> bo
return False
return True

async def add_persona(self, persona_name: str, persona_desc: str) -> None:
async def _validate_persona_description(
self, persona_desc: str, exclude_id: str = None
) -> np.ndarray:
"""
Add a new persona to the database. The persona description is embedded
and stored in the database.
Validate the persona description by embedding the text and checking if it is
different enough from existing personas.
"""
emb_persona_desc = await self._embed_text(persona_desc)
if not await self._is_persona_description_diff(emb_persona_desc):
if not await self._is_persona_description_diff(emb_persona_desc, exclude_id):
raise PersonaSimilarDescriptionError(
"The persona description is too similar to existing personas."
)
return emb_persona_desc

async def add_persona(self, persona_name: str, persona_desc: str) -> None:
"""
Add a new persona to the database. The persona description is embedded
and stored in the database.
"""
emb_persona_desc = await self._validate_persona_description(persona_desc)

new_persona = db_models.PersonaEmbedding(
id=str(uuid.uuid4()),
Expand All @@ -151,6 +165,43 @@ async def add_persona(self, persona_name: str, persona_desc: str) -> None:
await self._db_recorder.add_persona(new_persona)
logger.info(f"Added persona {persona_name} to the database.")

async def update_persona(
    self, persona_name: str, new_persona_name: str, new_persona_desc: str
) -> None:
    """
    Update an existing persona in the database. The name and description are
    updated in the database, but the ID remains the same.

    Raises:
        PersonaDoesNotExistError: when no persona named ``persona_name`` is found.
        PersonaSimilarDescriptionError: when the new description is too similar
            to the description of another persona.
    """
    # First we check if the persona exists, if not we raise an error
    found_persona = await self._db_reader.get_persona_by_name(persona_name)
    if not found_persona:
        raise PersonaDoesNotExistError(f"Persona {persona_name} does not exist.")

    # The persona being updated is excluded from the similarity check so its
    # own stored description does not count as a conflict.
    emb_persona_desc = await self._validate_persona_description(
        new_persona_desc, exclude_id=found_persona.id
    )

    # Then we update the attributes in the database
    updated_persona = db_models.PersonaEmbedding(
        id=found_persona.id,
        name=new_persona_name,
        description=new_persona_desc,
        description_embedding=emb_persona_desc,
    )
    await self._db_recorder.update_persona(updated_persona)
    logger.info(f"Updated persona {persona_name} in the database.")

async def delete_persona(self, persona_name: str) -> None:
    """
    Remove the persona called ``persona_name`` from the database.

    Raises:
        PersonaDoesNotExistError: when no persona with that name is found.
    """
    existing = await self._db_reader.get_persona_by_name(persona_name)
    if existing:
        await self._db_recorder.delete_persona(existing.id)
        logger.info(f"Deleted persona {persona_name} from the database.")
        return

    raise PersonaDoesNotExistError(f"Persona {persona_name} does not exist.")

async def check_persona_match(self, persona_name: str, query: str) -> bool:
"""
Check if the query matches the persona description. A vector similarity
Expand Down
Loading