Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit da69ec0

Browse files
Validate persona description is sufficiently different (#1225)
Closes: #1218 Check if the description for a new persona is different enough from the existing personas descriptions. This is done to correctly differentiate between personas
1 parent 81f0389 commit da69ec0

File tree

5 files changed

+250
-404
lines changed

5 files changed

+250
-404
lines changed

src/codegate/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,14 @@ class Config:
5757
force_certs: bool = False
5858

5959
max_fim_hash_lifetime: int = 60 * 5 # Time in seconds. Default is 5 minutes.
60+
6061
# Min value is 0 (max similarity), max value is 2 (orthogonal)
6162
# The value 0.75 was found through experimentation. See /tests/muxing/test_semantic_router.py
63+
# It's the threshold value to determine if a query matches a persona.
6264
persona_threshold = 0.75
65+
# The value 0.3 was found through experimentation. See /tests/muxing/test_semantic_router.py
66+
# It's the threshold value to determine if a persona description is similar to existing personas
67+
persona_diff_desc_threshold = 0.3
6368

6469
# Provider URLs with defaults
6570
provider_urls: Dict[str, str] = field(default_factory=lambda: DEFAULT_PROVIDER_URLS.copy())

src/codegate/db/connection.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,6 +1004,26 @@ async def get_persona_by_name(self, persona_name: str) -> Optional[Persona]:
10041004
)
10051005
return personas[0] if personas else None
10061006

1007+
async def get_distance_to_existing_personas(
1008+
self, query_embedding: np.ndarray
1009+
) -> List[PersonaDistance]:
1010+
"""
1011+
Get the distance between a persona and a query embedding.
1012+
"""
1013+
sql = """
1014+
SELECT
1015+
id,
1016+
name,
1017+
description,
1018+
vec_distance_cosine(description_embedding, :query_embedding) as distance
1019+
FROM personas
1020+
"""
1021+
conditions = {"query_embedding": query_embedding}
1022+
persona_distances = await self._exec_vec_db_query_to_pydantic(
1023+
sql, conditions, PersonaDistance
1024+
)
1025+
return persona_distances
1026+
10071027
async def get_distance_to_persona(
10081028
self, persona_id: str, query_embedding: np.ndarray
10091029
) -> PersonaDistance:

src/codegate/db/models.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,8 @@ class MuxRule(BaseModel):
245245

246246
def nd_array_custom_before_validator(x):
247247
# custome before validation logic
248+
if isinstance(x, bytes):
249+
return np.frombuffer(x, dtype=np.float32)
248250
return x
249251

250252

src/codegate/muxing/semantic_router.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ class PersonaDoesNotExistError(Exception):
2828
pass
2929

3030

31+
class PersonaSimilarDescriptionError(Exception):
32+
pass
33+
34+
3135
class SemanticRouter:
3236

3337
def __init__(self):
@@ -36,6 +40,7 @@ def __init__(self):
3640
self._embeddings_model = f"{conf.model_base_path}/{conf.embedding_model}"
3741
self._n_gpu = conf.chat_model_n_gpu_layers
3842
self._persona_threshold = conf.persona_threshold
43+
self._persona_diff_desc_threshold = conf.persona_diff_desc_threshold
3944
self._db_recorder = DbRecorder()
4045
self._db_reader = DbReader()
4146

@@ -105,12 +110,38 @@ async def _embed_text(self, text: str) -> np.ndarray:
105110
logger.debug("Text embedded in semantic routing", text=cleaned_text[:50])
106111
return np.array(embed_list[0], dtype=np.float32)
107112

113+
async def _is_persona_description_diff(self, emb_persona_desc: np.ndarray) -> bool:
114+
"""
115+
Check if the persona description is different enough from existing personas.
116+
"""
117+
# The distance calculation is done in the database
118+
persona_distances = await self._db_reader.get_distance_to_existing_personas(
119+
emb_persona_desc
120+
)
121+
if not persona_distances:
122+
return True
123+
124+
for persona_distance in persona_distances:
125+
logger.info(
126+
f"Persona description distance to {persona_distance.name}",
127+
distance=persona_distance.distance,
128+
)
129+
# If the distance is less than the threshold, the persona description is too similar
130+
if persona_distance.distance < self._persona_diff_desc_threshold:
131+
return False
132+
return True
133+
108134
async def add_persona(self, persona_name: str, persona_desc: str) -> None:
109135
"""
110136
Add a new persona to the database. The persona description is embedded
111137
and stored in the database.
112138
"""
113139
emb_persona_desc = await self._embed_text(persona_desc)
140+
if not await self._is_persona_description_diff(emb_persona_desc):
141+
raise PersonaSimilarDescriptionError(
142+
"The persona description is too similar to existing personas."
143+
)
144+
114145
new_persona = db_models.PersonaEmbedding(
115146
id=str(uuid.uuid4()),
116147
name=persona_name,

0 commit comments

Comments
 (0)