From a5a5b73c8e68b772865f86f0e2040207ce3af619 Mon Sep 17 00:00:00 2001 From: Alejandro Ponce Date: Tue, 17 Dec 2024 11:46:50 +0100 Subject: [PATCH] Refine caching non-copilot FIM calls Closes: #376 We were using the whole content of the prompt to hash the requests that don't come from copilot. This is inefficient since the prompt between requests can vary quite a lot. Instead, use the filepath included in every request. After some investigation, `copilot` puts the filepath at the top of the prompt while the rest of the providers include the filepath at the bottom of the context section. --- src/codegate/db/connection.py | 29 +++++++++++++++++----------- tests/db/test_connection.py | 36 +++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 11 deletions(-) create mode 100644 tests/db/test_connection.py diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 521cef7d..f4c0c064 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -2,6 +2,7 @@ import hashlib import json import re +from datetime import timedelta from pathlib import Path from typing import List, Optional @@ -200,20 +201,23 @@ def _extract_request_message(self, request: str) -> Optional[dict]: def _create_hash_key(self, message: str, provider: str) -> str: """Creates a hash key from the message and includes the provider""" - # Try to extract the path from the message. Most of the times is at the top of the message. - # The pattern was generated using ChatGPT. Should match common occurrences like: + # Try to extract the path from the FIM message. The path is in FIM request in these formats: # folder/testing_file.py # Path: file3.py - pattern = r"(?:[a-zA-Z]:\\|\/)?(?:[^\s\/]+\/)*[^\s\/]+\.[^\s\/]+" - match = re.search(pattern, message) - # Copilot it's the only provider that has an easy path to extract. - # Other providers are harder to extact. This part needs to be revisited for the moment - # hasing the entire request message. 
- if match is None or provider != "copilot": - logger.warning("No path found in message or not copilot. Creating hash from message.") + pattern = r"^#.*?\b([a-zA-Z0-9_\-\/]+\.\w+)\b" + matches = re.findall(pattern, message, re.MULTILINE) + # If no path is found, hash the entire prompt message. + if not matches: + logger.warning("No path found in messages. Creating hash cache from message.") message_to_hash = f"{message}-{provider}" else: - message_to_hash = f"{match.group(0)}-{provider}" + # Copilot puts the path at the top of the file. Continue providers contain + # several paths, the one in which the fim is triggered is the last one. + if provider == "copilot": + filepath = matches[0] + else: + filepath = matches[-1] + message_to_hash = f"{filepath}-{provider}" logger.debug(f"Message to hash: {message_to_hash}") hashed_content = hashlib.sha256(message_to_hash.encode("utf-8")).hexdigest() @@ -247,7 +251,10 @@ def _should_record_context(self, context: Optional[PipelineContext]) -> bool: elapsed_seconds = (context.input_request.timestamp - old_timestamp).total_seconds() if elapsed_seconds < Config.get_config().max_fim_hash_lifetime: - logger.info(f"Skipping context recording. Elapsed time: {elapsed_seconds} seconds.") + logger.info( + f"Skipping DB context recording. " + f"Elapsed time since last FIM cache: {timedelta(seconds=elapsed_seconds)}." 
+ ) return False async def record_context(self, context: Optional[PipelineContext]) -> None: diff --git a/tests/db/test_connection.py b/tests/db/test_connection.py new file mode 100644 index 00000000..35a8f60e --- /dev/null +++ b/tests/db/test_connection.py @@ -0,0 +1,36 @@ +import hashlib +from unittest.mock import patch + +import pytest + +from codegate.db.connection import DbRecorder + + +@patch("codegate.db.connection.DbRecorder.__init__", return_value=None) +def mock_db_recorder(mocked_init) -> DbRecorder: + db_recorder = DbRecorder() + return db_recorder + + +fim_message = """ +# Path: folder/testing_file.py +# another_folder/another_file.py + +This is a test message +""" + + +@pytest.mark.parametrize( + "message, provider, expected_message_to_hash", + [ + ("This is a test message", "test_provider", "This is a test message-test_provider"), + (fim_message, "copilot", "folder/testing_file.py-copilot"), + (fim_message, "other", "another_folder/another_file.py-other"), + ], +) +def test_create_hash_key(message, provider, expected_message_to_hash): + mocked_db_recorder = mock_db_recorder() + expected_hash = hashlib.sha256(expected_message_to_hash.encode("utf-8")).hexdigest() + + result_hash = mocked_db_recorder._create_hash_key(message, provider) + assert result_hash == expected_hash