
Commit 51cf196

feat: improve cache system to collect the last output
Due to the cache system, we were collecting only the initial output of FIM, which could include incomplete output. Add an update method to the cache so we can collect all of the output that comes from FIM and associate it with the same request.
Closes: #472
1 parent 8b95d7f commit 51cf196
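
In outline, the fix makes the FIM cache remember the ID of the first request it stored, and report whether a later, more complete chunk of output for the same request should be added as a new record or should update the existing one. A rough, self-contained sketch of that idea (ToyFimCache, CachedEntry, could_store and the sample keys/ids are illustrative names, not the actual codegate classes shown in the diff below):

from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass
class CachedEntry:
    initial_id: str  # id of the first DB row written for this request


class ToyFimCache:
    def __init__(self) -> None:
        self._cache: Dict[str, CachedEntry] = {}

    def could_store(self, hash_key: str, request_id: str) -> Tuple[bool, str, str]:
        entry = self._cache.get(hash_key)
        if entry is None:
            # First sighting of this request: remember its id and insert new rows.
            self._cache[hash_key] = CachedEntry(initial_id=request_id)
            return True, "add", request_id
        # Seen before: update the rows stored under the original id.
        return True, "update", entry.initial_id


cache = ToyFimCache()
print(cache.could_store("hash-abc", "req-1"))  # (True, 'add', 'req-1')
print(cache.could_store("hash-abc", "req-2"))  # (True, 'update', 'req-1')

Repeated calls with the same hash key keep returning the original request id, which is what lets the later, complete output overwrite the earlier partial one instead of creating a duplicate row.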

File tree

4 files changed (+129 -54 lines):

src/codegate/dashboard/dashboard.py (+2 -2)
src/codegate/db/connection.py (+69 -30)
src/codegate/db/fim_cache.py (+51 -18)
tests/db/test_fim_cache.py (+7 -4)

src/codegate/dashboard/dashboard.py (+2 -2)

@@ -1,5 +1,5 @@
 import asyncio
-from typing import AsyncGenerator, List
+from typing import AsyncGenerator, List, Optional

 import structlog
 from fastapi import APIRouter, Depends
@@ -36,7 +36,7 @@ def get_messages(db_reader: DbReader = Depends(get_db_reader)) -> List[Conversat


 @dashboard_router.get("/dashboard/alerts")
-def get_alerts(db_reader: DbReader = Depends(get_db_reader)) -> List[AlertConversation]:
+def get_alerts(db_reader: DbReader = Depends(get_db_reader)) -> List[Optional[AlertConversation]]:
     """
     Get all the messages from the database and return them as a list of conversations.
     """

src/codegate/db/connection.py (+69 -30)

@@ -5,7 +5,7 @@

 import structlog
 from pydantic import BaseModel
-from sqlalchemy import text
+from sqlalchemy import TextClause, text
 from sqlalchemy.ext.asyncio import create_async_engine

 from codegate.db.fim_cache import FimCache
@@ -30,8 +30,8 @@ def __init__(self, sqlite_path: Optional[str] = None):
             current_dir = Path(__file__).parent
             sqlite_path = (
                 current_dir.parent.parent.parent / "codegate_volume" / "db" / "codegate.db"
-            )
-        self._db_path = Path(sqlite_path).absolute()
+            )  # type: ignore
+        self._db_path = Path(sqlite_path).absolute()  # type: ignore
         self._db_path.parent.mkdir(parents=True, exist_ok=True)
         logger.debug(f"Initializing DB from path: {self._db_path}")
         engine_dict = {
@@ -82,15 +82,15 @@ async def init_db(self):
         finally:
             await self._async_db_engine.dispose()

-    async def _insert_pydantic_model(
-        self, model: BaseModel, sql_insert: text
+    async def _execute_update_pydantic_model(
+        self, model: BaseModel, sql_command: TextClause  #
     ) -> Optional[BaseModel]:
         # There are create method in queries.py automatically generated by sqlc
         # However, the methods are buggy for Pydancti and don't work as expected.
         # Manually writing the SQL query to insert Pydantic models.
         async with self._async_db_engine.begin() as conn:
             try:
-                result = await conn.execute(sql_insert, model.model_dump())
+                result = await conn.execute(sql_command, model.model_dump())
                 row = result.first()
                 if row is None:
                     return None
@@ -99,7 +99,7 @@ async def _insert_pydantic_model(
                 model_class = model.__class__
                 return model_class(**row._asdict())
             except Exception as e:
-                logger.error(f"Failed to insert model: {model}.", error=str(e))
+                logger.error(f"Failed to update model: {model}.", error=str(e))
                 return None

     async def record_request(self, prompt_params: Optional[Prompt] = None) -> Optional[Prompt]:
@@ -112,18 +112,39 @@ async def record_request(self, prompt_params: Optional[Prompt] = None) -> Option
            RETURNING *
            """
        )
-        recorded_request = await self._insert_pydantic_model(prompt_params, sql)
+        recorded_request = await self._execute_update_pydantic_model(prompt_params, sql)
         # Uncomment to debug the recorded request
         # logger.debug(f"Recorded request: {recorded_request}")
-        return recorded_request
+        return recorded_request  # type: ignore

-    async def record_outputs(self, outputs: List[Output]) -> Optional[Output]:
+    async def update_request(self, initial_id: str,
+                             prompt_params: Optional[Prompt] = None) -> Optional[Prompt]:
+        if prompt_params is None:
+            return None
+        prompt_params.id = initial_id  # overwrite the initial id of the request
+        sql = text(
+            """
+            UPDATE prompts
+            SET timestamp = :timestamp, provider = :provider, request = :request, type = :type
+            WHERE id = :id
+            RETURNING *
+            """
+        )
+        updated_request = await self._execute_update_pydantic_model(prompt_params, sql)
+        # Uncomment to debug the recorded request
+        # logger.debug(f"Recorded request: {recorded_request}")
+        return updated_request  # type: ignore
+
+    async def record_outputs(self, outputs: List[Output],
+                             initial_id: Optional[str]) -> Optional[Output]:
         if not outputs:
             return

         first_output = outputs[0]
         # Create a single entry on DB but encode all of the chunks in the stream as a list
         # of JSON objects in the field `output`
+        if initial_id:
+            first_output.prompt_id = initial_id
         output_db = Output(
             id=first_output.id,
             prompt_id=first_output.prompt_id,
@@ -143,14 +164,14 @@ async def record_outputs(self, outputs: List[Output]) -> Optional[Output]:
            RETURNING *
            """
        )
-        recorded_output = await self._insert_pydantic_model(output_db, sql)
+        recorded_output = await self._execute_update_pydantic_model(output_db, sql)
         # Uncomment to debug
         # logger.debug(f"Recorded output: {recorded_output}")
-        return recorded_output
+        return recorded_output  # type: ignore

-    async def record_alerts(self, alerts: List[Alert]) -> List[Alert]:
+    async def record_alerts(self, alerts: List[Alert], initial_id: Optional[str]) -> List[Alert]:
         if not alerts:
-            return
+            return []
         sql = text(
             """
             INSERT INTO alerts (
@@ -167,7 +188,9 @@ async def record_alerts(self, alerts: List[Alert]) -> List[Alert]:
         async with asyncio.TaskGroup() as tg:
             for alert in alerts:
                 try:
-                    result = tg.create_task(self._insert_pydantic_model(alert, sql))
+                    if initial_id:
+                        alert.prompt_id = initial_id
+                    result = tg.create_task(self._execute_update_pydantic_model(alert, sql))
                     alerts_tasks.append(result)
                 except Exception as e:
                     logger.error(f"Failed to record alert: {alert}.", error=str(e))
@@ -182,33 +205,49 @@ async def record_alerts(self, alerts: List[Alert]) -> List[Alert]:
         # logger.debug(f"Recorded alerts: {recorded_alerts}")
         return recorded_alerts

-    def _should_record_context(self, context: Optional[PipelineContext]) -> bool:
-        """Check if the context should be recorded in DB"""
+    def _should_record_context(self, context: Optional[PipelineContext]) -> tuple:
+        """Check if the context should be recorded in DB and determine the action."""
         if context is None or context.metadata.get("stored_in_db", False):
-            return False
+            return False, None, None

         if not context.input_request:
             logger.warning("No input request found. Skipping recording context.")
-            return False
+            return False, None, None

         # If it's not a FIM prompt, we don't need to check anything else.
         if context.input_request.type != "fim":
-            return True
+            return True, 'add', ''  # Default to add if not FIM, since no cache check is required

-        return fim_cache.could_store_fim_request(context)
+        return fim_cache.could_store_fim_request(context)  # type: ignore

     async def record_context(self, context: Optional[PipelineContext]) -> None:
         try:
-            if not self._should_record_context(context):
+            if not context:
+                logger.info("No context provided, skipping")
                 return
-            await self.record_request(context.input_request)
-            await self.record_outputs(context.output_responses)
-            await self.record_alerts(context.alerts_raised)
-            context.metadata["stored_in_db"] = True
-            logger.info(
-                f"Recorded context in DB. Output chunks: {len(context.output_responses)}. "
-                f"Alerts: {len(context.alerts_raised)}."
-            )
+            should_record, action, initial_id = self._should_record_context(context)
+            if not should_record:
+                logger.info("Skipping record of context, not needed")
+                return
+            if action == 'add':
+                await self.record_request(context.input_request)
+                await self.record_outputs(context.output_responses, None)
+                await self.record_alerts(context.alerts_raised, None)
+                context.metadata["stored_in_db"] = True
+                logger.info(
+                    f"Recorded context in DB. Output chunks: {len(context.output_responses)}. "
+                    f"Alerts: {len(context.alerts_raised)}."
+                )
+            else:
+                # update them
+                await self.update_request(initial_id, context.input_request)
+                await self.record_outputs(context.output_responses, initial_id)
+                await self.record_alerts(context.alerts_raised, initial_id)
+                context.metadata["stored_in_db"] = True
+                logger.info(
+                    f"Recorded context in DB. Output chunks: {len(context.output_responses)}. "
+                    f"Alerts: {len(context.alerts_raised)}."
+                )
         except Exception as e:
             logger.error(f"Failed to record context: {context}.", error=str(e))
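
For context on the renamed _execute_update_pydantic_model helper above: it binds a Pydantic model's fields to the named parameters of a SQLAlchemy text() statement, so the same code path can serve both the INSERT and the new UPDATE ... RETURNING queries. A minimal sketch of that pattern, assuming SQLAlchemy 2.x with the aiosqlite driver, Pydantic v2, and an SQLite build that supports RETURNING (the Prompt model here is simplified, not the codegate schema):

import asyncio

from pydantic import BaseModel
from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine


class Prompt(BaseModel):
    id: str
    request: str


async def main() -> None:
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    async with engine.begin() as conn:
        await conn.execute(text("CREATE TABLE prompts (id TEXT PRIMARY KEY, request TEXT)"))
        insert = text("INSERT INTO prompts (id, request) VALUES (:id, :request) RETURNING *")
        update = text("UPDATE prompts SET request = :request WHERE id = :id RETURNING *")
        # First sighting: insert the (possibly partial) output under a new id.
        first = Prompt(id="p1", request="partial output")
        row = (await conn.execute(insert, first.model_dump())).first()
        # Later: the same prompt id is updated with the complete output.
        final = Prompt(id="p1", request="complete output")
        row = (await conn.execute(update, final.model_dump())).first()
        print(Prompt(**row._asdict()))
    await engine.dispose()


asyncio.run(main())

The RETURNING * clause is what lets the helper rebuild and return the model from the row that was actually written, for inserts and updates alike.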

src/codegate/db/fim_cache.py (+51 -18)

@@ -18,6 +18,7 @@ class CachedFim(BaseModel):

     timestamp: datetime.datetime
     critical_alerts: List[Alert]
+    initial_id: str


 class FimCache:
@@ -86,16 +87,42 @@ def _calculate_hash_key(self, message: str, provider: str) -> str:

     def _add_cache_entry(self, hash_key: str, context: PipelineContext):
         """Add a new cache entry"""
+        if not context.input_request:
+            logger.warning("No input request found. Skipping creating a mapping entry")
+            return
         critical_alerts = [
             alert
             for alert in context.alerts_raised
             if alert.trigger_category == AlertSeverity.CRITICAL.value
         ]
         new_cache = CachedFim(
-            timestamp=context.input_request.timestamp, critical_alerts=critical_alerts
+            timestamp=context.input_request.timestamp, critical_alerts=critical_alerts,
+            initial_id=context.input_request.id
         )
         self.cache[hash_key] = new_cache
         logger.info(f"Added cache entry for hash key: {hash_key}")
+        return self.cache[hash_key]
+
+    def _update_cache_entry(self, hash_key: str, context: PipelineContext):
+        """Update an existing cache entry without changing the timestamp."""
+        existing_entry = self.cache.get(hash_key)
+        if existing_entry is not None:
+            # Update critical alerts while retaining the original timestamp.
+            critical_alerts = [
+                alert
+                for alert in context.alerts_raised
+                if alert.trigger_category == AlertSeverity.CRITICAL.value
+            ]
+            # Update the entry in the cache with new critical alerts but keep the old timestamp.
+            updated_cache = CachedFim(
+                timestamp=existing_entry.timestamp, critical_alerts=critical_alerts,
+                initial_id=existing_entry.initial_id
+            )
+            self.cache[hash_key] = updated_cache
+            logger.info(f"Updated cache entry for hash key: {hash_key}")
+        else:
+            # Log a warning if trying to update a non-existent entry - ideally should not happen.
+            logger.warning(f"Attempted to update non-existent cache entry for hash key: {hash_key}")

     def _are_new_alerts_present(self, context: PipelineContext, cached_entry: CachedFim) -> bool:
         """Check if there are new alerts present"""
@@ -108,29 +135,35 @@ def _are_new_alerts_present(self, context: PipelineContext, cached_entry: Cached

     def _is_cached_entry_old(self, context: PipelineContext, cached_entry: CachedFim) -> bool:
         """Check if the cached entry is old"""
+        if not context.input_request:
+            logger.warning("No input request found. Skipping checking if the cache entry is old")
+            return False
         elapsed_seconds = (context.input_request.timestamp - cached_entry.timestamp).total_seconds()
-        return elapsed_seconds > Config.get_config().max_fim_hash_lifetime
+        config = Config.get_config()
+        if config is None:
+            logger.warning("No configuration found. Skipping checking if the cache entry is old")
+            return True
+        return elapsed_seconds > Config.get_config().max_fim_hash_lifetime  # type: ignore

     def could_store_fim_request(self, context: PipelineContext):
+        if not context.input_request:
+            logger.warning("No input request found. Skipping creating a mapping entry")
+            return False, '', ''
         # Couldn't process the user message. Skip creating a mapping entry.
         message = self._extract_message_from_fim_request(context.input_request.request)
         if message is None:
             logger.warning(f"Couldn't read FIM message: {message}. Will not record to DB.")
-            return False
+            return False, '', ''

-        hash_key = self._calculate_hash_key(message, context.input_request.provider)
+        hash_key = self._calculate_hash_key(message, context.input_request.provider)  # type: ignore
         cached_entry = self.cache.get(hash_key, None)
-        if cached_entry is None:
-            self._add_cache_entry(hash_key, context)
-            return True
-
-        if self._is_cached_entry_old(context, cached_entry):
-            self._add_cache_entry(hash_key, context)
-            return True
-
-        if self._are_new_alerts_present(context, cached_entry):
-            self._add_cache_entry(hash_key, context)
-            return True
-
-        logger.debug(f"FIM entry already in cache: {hash_key}.")
-        return False
+        if cached_entry is None or self._is_cached_entry_old(
+                context, cached_entry) or self._are_new_alerts_present(context, cached_entry):
+            cached_entry = self._add_cache_entry(hash_key, context)
+            if cached_entry is None:
+                logger.warning("Failed to add cache entry")
+                return False, '', ''
+            return True, 'add', cached_entry.initial_id
+
+        self._update_cache_entry(hash_key, context)
+        return True, 'update', cached_entry.initial_id
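
The staleness guard above compares the incoming request's timestamp against the cached entry's original timestamp and a configured maximum lifetime. A small, self-contained illustration of that check (the lifetime value here is an assumed placeholder, not codegate's real max_fim_hash_lifetime default):

from datetime import datetime, timedelta, timezone

MAX_FIM_HASH_LIFETIME = 900  # assumed lifetime in seconds, placeholder value


def is_cached_entry_old(request_ts: datetime, cached_ts: datetime) -> bool:
    # The entry is stale once the gap between the new request and the cached
    # timestamp exceeds the configured lifetime, so it gets re-added rather than updated.
    return (request_ts - cached_ts).total_seconds() > MAX_FIM_HASH_LIFETIME


now = datetime.now(timezone.utc)
print(is_cached_entry_old(now, now - timedelta(days=1)))     # True  -> re-add entry
print(is_cached_entry_old(now, now - timedelta(seconds=5)))  # False -> update entry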

tests/db/test_fim_cache.py (+7 -4)

@@ -127,7 +127,7 @@ def test_extract_message_from_fim_request(test_request, expected_result_content)

 def test_are_new_alerts_present():
     fim_cache = FimCache()
-    cached_entry = CachedFim(timestamp=datetime.now(), critical_alerts=[])
+    cached_entry = CachedFim(timestamp=datetime.now(), critical_alerts=[], initial_id="1")
     context = PipelineContext()
     context.alerts_raised = [mock.MagicMock(trigger_category=AlertSeverity.CRITICAL.value)]
     result = fim_cache._are_new_alerts_present(context, cached_entry)
@@ -146,6 +146,7 @@ def test_are_new_alerts_present():
                 trigger_string=None,
             )
         ],
+        initial_id='2'
     )
     result = fim_cache._are_new_alerts_present(context, populated_cache)
     assert result is False
@@ -155,15 +156,17 @@ def test_are_new_alerts_present():
     "cached_entry, is_old",
     [
         (
-            CachedFim(timestamp=datetime.now(timezone.utc) - timedelta(days=1), critical_alerts=[]),
+            CachedFim(timestamp=datetime.now(timezone.utc) - timedelta(days=1),
+                      critical_alerts=[], initial_id='1'),
             True,
         ),
-        (CachedFim(timestamp=datetime.now(timezone.utc), critical_alerts=[]), False),
+        (CachedFim(timestamp=datetime.now(timezone.utc), critical_alerts=[],
+                   initial_id='2'), False),
     ],
 )
 def test_is_cached_entry_old(cached_entry, is_old):
     context = PipelineContext()
-    context.add_input_request("test", True, "test_provider")
+    context.add_input_request("test", True, "test_provider")  # type: ignore
     fim_cache = FimCache()
     result = fim_cache._is_cached_entry_old(context, cached_entry)
     assert result == is_old
