Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Refine caching non-copilot FIM calls #392

Merged
merged 1 commit into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 18 additions & 11 deletions src/codegate/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import hashlib
import json
import re
from datetime import timedelta
from pathlib import Path
from typing import List, Optional

Expand Down Expand Up @@ -200,20 +201,23 @@ def _extract_request_message(self, request: str) -> Optional[dict]:

def _create_hash_key(self, message: str, provider: str) -> str:
"""Creates a hash key from the message and includes the provider"""
# Try to extract the path from the message. Most of the time it is at the top of the message.
# The pattern was generated using ChatGPT. Should match common occurrences like:
# Try to extract the path from the FIM message. The path is in FIM request in these formats:
# folder/testing_file.py
# Path: file3.py
pattern = r"(?:[a-zA-Z]:\\|\/)?(?:[^\s\/]+\/)*[^\s\/]+\.[^\s\/]+"
match = re.search(pattern, message)
# Copilot is the only provider that has an easy path to extract.
# Other providers are harder to extract. This part needs to be revisited; for the moment
# we hash the entire request message.
if match is None or provider != "copilot":
logger.warning("No path found in message or not copilot. Creating hash from message.")
pattern = r"^#.*?\b([a-zA-Z0-9_\-\/]+\.\w+)\b"
matches = re.findall(pattern, message, re.MULTILINE)
# If no path is found, hash the entire prompt message.
if not matches:
logger.warning("No path found in messages. Creating hash cache from message.")
message_to_hash = f"{message}-{provider}"
else:
message_to_hash = f"{match.group(0)}-{provider}"
# Copilot puts the path at the top of the file. Continue providers contain
# several paths, the one in which the fim is triggered is the last one.
if provider == "copilot":
filepath = matches[0]
else:
filepath = matches[-1]
message_to_hash = f"{filepath}-{provider}"

logger.debug(f"Message to hash: {message_to_hash}")
hashed_content = hashlib.sha256(message_to_hash.encode("utf-8")).hexdigest()
Expand Down Expand Up @@ -247,7 +251,10 @@ def _should_record_context(self, context: Optional[PipelineContext]) -> bool:

elapsed_seconds = (context.input_request.timestamp - old_timestamp).total_seconds()
if elapsed_seconds < Config.get_config().max_fim_hash_lifetime:
logger.info(f"Skipping context recording. Elapsed time: {elapsed_seconds} seconds.")
logger.info(
f"Skipping DB context recording. "
f"Elapsed time since last FIM cache: {timedelta(seconds=elapsed_seconds)}."
)
return False

async def record_context(self, context: Optional[PipelineContext]) -> None:
Expand Down
36 changes: 36 additions & 0 deletions tests/db/test_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import hashlib
from unittest.mock import patch

import pytest

from codegate.db.connection import DbRecorder


@patch("codegate.db.connection.DbRecorder.__init__", return_value=None)
def mock_db_recorder(mocked_init) -> DbRecorder:
    """Build a ``DbRecorder`` whose ``__init__`` is patched out.

    The ``patch`` decorator swaps ``DbRecorder.__init__`` for a mock that
    returns ``None`` while this function runs and hands that mock in as
    ``mocked_init``, so the instance is created without executing the real
    constructor.
    """
    return DbRecorder()


# Sample FIM prompt: two commented path lines followed by free text. The
# parametrized hash-key test cases below use it as input.
fim_message = (
    "\n"
    "# Path: folder/testing_file.py\n"
    "# another_folder/another_file.py\n"
    "\n"
    "This is a test message\n"
)


@pytest.mark.parametrize(
    "message, provider, expected_message_to_hash",
    [
        # No commented path line in the prompt: the whole message is hashed.
        ("This is a test message", "test_provider", "This is a test message-test_provider"),
        # Copilot: the first path found in the prompt is expected.
        (fim_message, "copilot", "folder/testing_file.py-copilot"),
        # Any other provider: the last path found in the prompt is expected.
        (fim_message, "other", "another_folder/another_file.py-other"),
    ],
)
def test_create_hash_key(message, provider, expected_message_to_hash):
    """Check that ``_create_hash_key`` hashes the expected string for each case.

    Every case supplies the prompt, the provider name, and the string whose
    SHA-256 hex digest the method is expected to return.
    """
    recorder = mock_db_recorder()
    actual_digest = recorder._create_hash_key(message, provider)

    wanted_digest = hashlib.sha256(expected_message_to_hash.encode("utf-8")).hexdigest()
    assert actual_digest == wanted_digest
Loading