diff --git a/src/codegate/dashboard/dashboard.py b/src/codegate/dashboard/dashboard.py
index 4ed39f34..19352b51 100644
--- a/src/codegate/dashboard/dashboard.py
+++ b/src/codegate/dashboard/dashboard.py
@@ -1,5 +1,5 @@
 import asyncio
-from typing import AsyncGenerator, List
+from typing import AsyncGenerator, List, Optional

 import structlog
 from fastapi import APIRouter, Depends
@@ -36,7 +36,7 @@ def get_messages(db_reader: DbReader = Depends(get_db_reader)) -> List[Conversat


 @dashboard_router.get("/dashboard/alerts")
-def get_alerts(db_reader: DbReader = Depends(get_db_reader)) -> List[AlertConversation]:
+def get_alerts(db_reader: DbReader = Depends(get_db_reader)) -> List[Optional[AlertConversation]]:
     """
     Get all the messages from the database and return them as a list of conversations.
     """
diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py
index c8fb60d0..af7c3b98 100644
--- a/src/codegate/db/connection.py
+++ b/src/codegate/db/connection.py
@@ -5,7 +5,7 @@

 import structlog
 from pydantic import BaseModel
-from sqlalchemy import text
+from sqlalchemy import TextClause, text
 from sqlalchemy.ext.asyncio import create_async_engine

 from codegate.db.fim_cache import FimCache
@@ -30,8 +30,8 @@ def __init__(self, sqlite_path: Optional[str] = None):
             current_dir = Path(__file__).parent
             sqlite_path = (
                 current_dir.parent.parent.parent / "codegate_volume" / "db" / "codegate.db"
-            )
-        self._db_path = Path(sqlite_path).absolute()
+            )  # type: ignore
+        self._db_path = Path(sqlite_path).absolute()  # type: ignore
         self._db_path.parent.mkdir(parents=True, exist_ok=True)
         logger.debug(f"Initializing DB from path: {self._db_path}")
         engine_dict = {
@@ -82,15 +82,15 @@ async def init_db(self):
         finally:
             await self._async_db_engine.dispose()

-    async def _insert_pydantic_model(
-        self, model: BaseModel, sql_insert: text
+    async def _execute_update_pydantic_model(
+        self, model: BaseModel, sql_command: TextClause  #
     ) -> Optional[BaseModel]:
         # There are create method in queries.py automatically generated by sqlc
         # However, the methods are buggy for Pydancti and don't work as expected.
         # Manually writing the SQL query to insert Pydantic models.
         async with self._async_db_engine.begin() as conn:
             try:
-                result = await conn.execute(sql_insert, model.model_dump())
+                result = await conn.execute(sql_command, model.model_dump())
                 row = result.first()
                 if row is None:
                     return None
@@ -99,7 +99,7 @@ async def _insert_pydantic_model(
                 model_class = model.__class__
                 return model_class(**row._asdict())
             except Exception as e:
-                logger.error(f"Failed to insert model: {model}.", error=str(e))
+                logger.error(f"Failed to update model: {model}.", error=str(e))
                 return None

     async def record_request(self, prompt_params: Optional[Prompt] = None) -> Optional[Prompt]:
@@ -112,18 +112,39 @@ async def record_request(self, prompt_params: Optional[Prompt] = None) -> Option
            RETURNING *
            """
        )
-        recorded_request = await self._insert_pydantic_model(prompt_params, sql)
+        recorded_request = await self._execute_update_pydantic_model(prompt_params, sql)
         # Uncomment to debug the recorded request
         # logger.debug(f"Recorded request: {recorded_request}")
-        return recorded_request
+        return recorded_request  # type: ignore

-    async def record_outputs(self, outputs: List[Output]) -> Optional[Output]:
+    async def update_request(self, initial_id: str,
+                             prompt_params: Optional[Prompt] = None) -> Optional[Prompt]:
+        if prompt_params is None:
+            return None
+        prompt_params.id = initial_id  # overwrite the initial id of the request
+        sql = text(
+            """
+            UPDATE prompts
+            SET timestamp = :timestamp, provider = :provider, request = :request, type = :type
+            WHERE id = :id
+            RETURNING *
+            """
+        )
+        updated_request = await self._execute_update_pydantic_model(prompt_params, sql)
+        # Uncomment to debug the updated request
+        # logger.debug(f"Updated request: {updated_request}")
+        return updated_request  # type: ignore
+
+    async def record_outputs(self, outputs: List[Output],
+                             initial_id: Optional[str]) -> Optional[Output]:
         if not outputs:
             return
         first_output = outputs[0]
         # Create a single entry on DB but encode all of the chunks in the stream as a list
         # of JSON objects in the field `output`
+        if initial_id:
+            first_output.prompt_id = initial_id
         output_db = Output(
             id=first_output.id,
             prompt_id=first_output.prompt_id,
@@ -143,14 +164,14 @@ async def record_outputs(self, outputs: List[Output]) -> Optional[Output]:
            RETURNING *
            """
        )
-        recorded_output = await self._insert_pydantic_model(output_db, sql)
+        recorded_output = await self._execute_update_pydantic_model(output_db, sql)
         # Uncomment to debug
         # logger.debug(f"Recorded output: {recorded_output}")
-        return recorded_output
+        return recorded_output  # type: ignore

-    async def record_alerts(self, alerts: List[Alert]) -> List[Alert]:
+    async def record_alerts(self, alerts: List[Alert], initial_id: Optional[str]) -> List[Alert]:
         if not alerts:
-            return
+            return []
         sql = text(
             """
             INSERT INTO alerts (
@@ -167,7 +188,9 @@ async def record_alerts(self, alerts: List[Alert]) -> List[Alert]:
         async with asyncio.TaskGroup() as tg:
             for alert in alerts:
                 try:
-                    result = tg.create_task(self._insert_pydantic_model(alert, sql))
+                    if initial_id:
+                        alert.prompt_id = initial_id
+                    result = tg.create_task(self._execute_update_pydantic_model(alert, sql))
                     alerts_tasks.append(result)
                 except Exception as e:
                     logger.error(f"Failed to record alert: {alert}.", error=str(e))
@@ -182,33 +205,49 @@ async def record_alerts(self, alerts: List[Alert]) -> List[Alert]:
         # logger.debug(f"Recorded alerts: {recorded_alerts}")
         return recorded_alerts

-    def _should_record_context(self, context: Optional[PipelineContext]) -> bool:
-        """Check if the context should be recorded in DB"""
+    def _should_record_context(self, context: Optional[PipelineContext]) -> tuple:
+        """Check if the context should be recorded in DB and determine the action."""
         if context is None or context.metadata.get("stored_in_db", False):
-            return False
+            return False, None, None
         if not context.input_request:
             logger.warning("No input request found. Skipping recording context.")
-            return False
+            return False, None, None

         # If it's not a FIM prompt, we don't need to check anything else.
         if context.input_request.type != "fim":
-            return True
+            return True, 'add', ''  # Default to add if not FIM, since no cache check is required

-        return fim_cache.could_store_fim_request(context)
+        return fim_cache.could_store_fim_request(context)  # type: ignore

     async def record_context(self, context: Optional[PipelineContext]) -> None:
         try:
-            if not self._should_record_context(context):
+            if not context:
+                logger.info("No context provided, skipping")
                 return
-            await self.record_request(context.input_request)
-            await self.record_outputs(context.output_responses)
-            await self.record_alerts(context.alerts_raised)
-            context.metadata["stored_in_db"] = True
-            logger.info(
-                f"Recorded context in DB. Output chunks: {len(context.output_responses)}. "
-                f"Alerts: {len(context.alerts_raised)}."
-            )
+            should_record, action, initial_id = self._should_record_context(context)
+            if not should_record:
+                logger.info("Skipping record of context, not needed")
+                return
+            if action == 'add':
+                await self.record_request(context.input_request)
+                await self.record_outputs(context.output_responses, None)
+                await self.record_alerts(context.alerts_raised, None)
+                context.metadata["stored_in_db"] = True
+                logger.info(
+                    f"Recorded context in DB. Output chunks: {len(context.output_responses)}. "
+                    f"Alerts: {len(context.alerts_raised)}."
+                )
+            else:
+                # update them
+                await self.update_request(initial_id, context.input_request)
+                await self.record_outputs(context.output_responses, initial_id)
+                await self.record_alerts(context.alerts_raised, initial_id)
+                context.metadata["stored_in_db"] = True
+                logger.info(
+                    f"Recorded context in DB. Output chunks: {len(context.output_responses)}. "
+                    f"Alerts: {len(context.alerts_raised)}."
+                )
         except Exception as e:
             logger.error(f"Failed to record context: {context}.", error=str(e))
diff --git a/src/codegate/db/fim_cache.py b/src/codegate/db/fim_cache.py
index 2a2d8761..e5a488b6 100644
--- a/src/codegate/db/fim_cache.py
+++ b/src/codegate/db/fim_cache.py
@@ -18,6 +18,7 @@ class CachedFim(BaseModel):
     timestamp: datetime.datetime
     critical_alerts: List[Alert]
+    initial_id: str


 class FimCache:
@@ -86,16 +87,42 @@ def _calculate_hash_key(self, message: str, provider: str) -> str:

     def _add_cache_entry(self, hash_key: str, context: PipelineContext):
         """Add a new cache entry"""
+        if not context.input_request:
+            logger.warning("No input request found. Skipping creating a mapping entry")
+            return
         critical_alerts = [
             alert
             for alert in context.alerts_raised
             if alert.trigger_category == AlertSeverity.CRITICAL.value
         ]
         new_cache = CachedFim(
-            timestamp=context.input_request.timestamp, critical_alerts=critical_alerts
+            timestamp=context.input_request.timestamp, critical_alerts=critical_alerts,
+            initial_id=context.input_request.id
         )
         self.cache[hash_key] = new_cache
         logger.info(f"Added cache entry for hash key: {hash_key}")
+        return self.cache[hash_key]
+
+    def _update_cache_entry(self, hash_key: str, context: PipelineContext):
+        """Update an existing cache entry without changing the timestamp."""
+        existing_entry = self.cache.get(hash_key)
+        if existing_entry is not None:
+            # Update critical alerts while retaining the original timestamp.
+            critical_alerts = [
+                alert
+                for alert in context.alerts_raised
+                if alert.trigger_category == AlertSeverity.CRITICAL.value
+            ]
+            # Update the entry in the cache with new critical alerts but keep the old timestamp.
+            updated_cache = CachedFim(
+                timestamp=existing_entry.timestamp, critical_alerts=critical_alerts,
+                initial_id=existing_entry.initial_id
+            )
+            self.cache[hash_key] = updated_cache
+            logger.info(f"Updated cache entry for hash key: {hash_key}")
+        else:
+            # Log a warning if trying to update a non-existent entry - ideally should not happen.
+            logger.warning(f"Attempted to update non-existent cache entry for hash key: {hash_key}")

     def _are_new_alerts_present(self, context: PipelineContext, cached_entry: CachedFim) -> bool:
         """Check if there are new alerts present"""
@@ -108,29 +135,35 @@ def _are_new_alerts_present(self, context: PipelineContext, cached_entry: Cached

     def _is_cached_entry_old(self, context: PipelineContext, cached_entry: CachedFim) -> bool:
         """Check if the cached entry is old"""
+        if not context.input_request:
+            logger.warning("No input request found. Skipping checking if the cache entry is old")
+            return False
         elapsed_seconds = (context.input_request.timestamp - cached_entry.timestamp).total_seconds()
-        return elapsed_seconds > Config.get_config().max_fim_hash_lifetime
+        config = Config.get_config()
+        if config is None:
+            logger.warning("No configuration found. Skipping checking if the cache entry is old")
+            return True
+        return elapsed_seconds > Config.get_config().max_fim_hash_lifetime  # type: ignore

     def could_store_fim_request(self, context: PipelineContext):
+        if not context.input_request:
+            logger.warning("No input request found. Skipping creating a mapping entry")
+            return False, '', ''
         # Couldn't process the user message. Skip creating a mapping entry.
         message = self._extract_message_from_fim_request(context.input_request.request)
         if message is None:
             logger.warning(f"Couldn't read FIM message: {message}. Will not record to DB.")
-            return False
+            return False, '', ''

-        hash_key = self._calculate_hash_key(message, context.input_request.provider)
+        hash_key = self._calculate_hash_key(message, context.input_request.provider)  # type: ignore
         cached_entry = self.cache.get(hash_key, None)
-        if cached_entry is None:
-            self._add_cache_entry(hash_key, context)
-            return True
-
-        if self._is_cached_entry_old(context, cached_entry):
-            self._add_cache_entry(hash_key, context)
-            return True
-
-        if self._are_new_alerts_present(context, cached_entry):
-            self._add_cache_entry(hash_key, context)
-            return True
-
-        logger.debug(f"FIM entry already in cache: {hash_key}.")
-        return False
+        if cached_entry is None or self._is_cached_entry_old(
+                context, cached_entry) or self._are_new_alerts_present(context, cached_entry):
+            cached_entry = self._add_cache_entry(hash_key, context)
+            if cached_entry is None:
+                logger.warning("Failed to add cache entry")
+                return False, '', ''
+            return True, 'add', cached_entry.initial_id
+
+        self._update_cache_entry(hash_key, context)
+        return True, 'update', cached_entry.initial_id
diff --git a/tests/db/test_fim_cache.py b/tests/db/test_fim_cache.py
index 6da4de5f..c6b5506e 100644
--- a/tests/db/test_fim_cache.py
+++ b/tests/db/test_fim_cache.py
@@ -127,7 +127,7 @@ def test_extract_message_from_fim_request(test_request, expected_result_content)

 def test_are_new_alerts_present():
     fim_cache = FimCache()
-    cached_entry = CachedFim(timestamp=datetime.now(), critical_alerts=[])
+    cached_entry = CachedFim(timestamp=datetime.now(), critical_alerts=[], initial_id="1")
     context = PipelineContext()
     context.alerts_raised = [mock.MagicMock(trigger_category=AlertSeverity.CRITICAL.value)]
     result = fim_cache._are_new_alerts_present(context, cached_entry)
@@ -146,6 +146,7 @@ def test_are_new_alerts_present():
                 trigger_string=None,
             )
         ],
+        initial_id='2'
     )
     result = fim_cache._are_new_alerts_present(context, populated_cache)
     assert result is False
@@ -155,15 +156,17 @@
     "cached_entry, is_old",
     [
         (
-            CachedFim(timestamp=datetime.now(timezone.utc) - timedelta(days=1), critical_alerts=[]),
+            CachedFim(timestamp=datetime.now(timezone.utc) - timedelta(days=1),
+                      critical_alerts=[], initial_id='1'),
             True,
         ),
-        (CachedFim(timestamp=datetime.now(timezone.utc), critical_alerts=[]), False),
+        (CachedFim(timestamp=datetime.now(timezone.utc), critical_alerts=[],
+                   initial_id='2'), False),
     ],
 )
 def test_is_cached_entry_old(cached_entry, is_old):
     context = PipelineContext()
-    context.add_input_request("test", True, "test_provider")
+    context.add_input_request("test", True, "test_provider")  # type: ignore
     fim_cache = FimCache()
     result = fim_cache._is_cached_entry_old(context, cached_entry)
     assert result == is_old
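
The core change is easier to see outside the diff: FIM deduplication now returns a three-field tuple (should_record, action, initial_id) instead of a bare bool, and record_context branches on action to either insert new rows or update the rows written for the original FIM request. The sketch below is illustrative only, with made-up function names (decide, record) and ids; it is not codegate code.

from typing import Optional, Tuple


def decide(prompt_type: str, fim_cache_hit: bool, cached_initial_id: str,
           new_id: str) -> Tuple[bool, Optional[str], Optional[str]]:
    """Hypothetical stand-in for the (should_record, action, initial_id) tuple."""
    if prompt_type != "fim":
        # Non-FIM prompts are always inserted; no cache lookup is needed.
        return True, "add", ""
    if fim_cache_hit:
        # A recent, identical FIM request exists: update the rows written for it.
        return True, "update", cached_initial_id
    # First time (or stale cache entry): insert and remember the prompt id.
    return True, "add", new_id


def record(prompt_type: str, fim_cache_hit: bool = False,
           cached_initial_id: str = "", new_id: str = "prompt-1") -> str:
    """Mirrors the branching in record_context; returns the action taken."""
    should_record, action, initial_id = decide(prompt_type, fim_cache_hit,
                                               cached_initial_id, new_id)
    if not should_record:
        return "skipped"
    if action == "add":
        # corresponds to record_request(...), record_outputs(..., None), record_alerts(..., None)
        return f"inserted rows for prompt {initial_id or new_id}"
    # corresponds to update_request(initial_id, ...), record_outputs/record_alerts(..., initial_id)
    return f"updated rows for prompt {initial_id}"


print(record("chat"))                                                   # inserted rows for prompt prompt-1
print(record("fim"))                                                    # inserted rows for prompt prompt-1
print(record("fim", fim_cache_hit=True, cached_initial_id="prompt-1"))  # updated rows for prompt prompt-1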