
Commit 9c57e92

Merge pull request #408 from stacklok/cachce-non-python-fim-reqs
fix: FIM not caching correctly non-python files
2 parents (9b47715 + af52b43) · commit 9c57e92

6 files changed: +331 -123 lines changed


src/codegate/codegate_logging.py

Lines changed: 3 additions & 3 deletions
@@ -50,10 +50,10 @@ def _missing_(cls, value: str) -> Optional["LogFormat"]:
 
 def add_origin(logger, log_method, event_dict):
     # Add 'origin' if it's bound to the logger but not explicitly in the event dict
-    if 'origin' not in event_dict and hasattr(logger, '_context'):
-        origin = logger._context.get('origin')
+    if "origin" not in event_dict and hasattr(logger, "_context"):
+        origin = logger._context.get("origin")
         if origin:
-            event_dict['origin'] = origin
+            event_dict["origin"] = origin
     return event_dict
 
 
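For context, add_origin is a structlog processor: it receives each event dict and copies the bound origin into it before rendering. A minimal sketch of how such a processor slots into a processor chain (the configure call and log message are illustrative, not from this commit):

import structlog

def add_origin(logger, log_method, event_dict):
    # Copy 'origin' from the logger's bound context into the event dict
    if "origin" not in event_dict and hasattr(logger, "_context"):
        origin = logger._context.get("origin")
        if origin:
            event_dict["origin"] = origin
    return event_dict

structlog.configure(processors=[add_origin, structlog.processors.JSONRenderer()])
log = structlog.get_logger().bind(origin="copilot_proxy")
log.info("request received")  # prints JSON that includes "origin": "copilot_proxy"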

src/codegate/db/connection.py

Lines changed: 15 additions & 74 deletions
@@ -1,8 +1,5 @@
 import asyncio
-import hashlib
 import json
-import re
-from datetime import timedelta
 from pathlib import Path
 from typing import List, Optional
 
@@ -11,7 +8,7 @@
 from sqlalchemy import text
 from sqlalchemy.ext.asyncio import create_async_engine
 
-from codegate.config import Config
+from codegate.db.fim_cache import FimCache
 from codegate.db.models import Alert, Output, Prompt
 from codegate.db.queries import (
     AsyncQuerier,
@@ -22,7 +19,7 @@
 
 logger = structlog.get_logger("codegate")
 alert_queue = asyncio.Queue()
-fim_entries = {}
+fim_cache = FimCache()
 
 
 class DbCodeGate:
@@ -183,47 +180,6 @@ async def record_alerts(self, alerts: List[Alert]) -> List[Alert]:
         logger.debug(f"Recorded alerts: {recorded_alerts}")
         return recorded_alerts
 
-    def _extract_request_message(self, request: str) -> Optional[dict]:
-        """Extract the user message from the FIM request"""
-        try:
-            parsed_request = json.loads(request)
-        except Exception as e:
-            logger.exception(f"Failed to extract request message: {request}", error=str(e))
-            return None
-
-        messages = [message for message in parsed_request["messages"] if message["role"] == "user"]
-        if len(messages) != 1:
-            logger.warning(f"Expected one user message, found {len(messages)}.")
-            return None
-
-        content_message = messages[0].get("content")
-        return content_message
-
-    def _create_hash_key(self, message: str, provider: str) -> str:
-        """Creates a hash key from the message and includes the provider"""
-        # Try to extract the path from the FIM message. The path is in FIM request in these formats:
-        # folder/testing_file.py
-        # Path: file3.py
-        pattern = r"^#.*?\b([a-zA-Z0-9_\-\/]+\.\w+)\b"
-        matches = re.findall(pattern, message, re.MULTILINE)
-        # If no path is found, hash the entire prompt message.
-        if not matches:
-            logger.warning("No path found in messages. Creating hash cache from message.")
-            message_to_hash = f"{message}-{provider}"
-        else:
-            # Copilot puts the path at the top of the file. Continue providers contain
-            # several paths, the one in which the fim is triggered is the last one.
-            if provider == "copilot":
-                filepath = matches[0]
-            else:
-                filepath = matches[-1]
-            message_to_hash = f"{filepath}-{provider}"
-
-        logger.debug(f"Message to hash: {message_to_hash}")
-        hashed_content = hashlib.sha256(message_to_hash.encode("utf-8")).hexdigest()
-        logger.debug(f"Hashed contnet: {hashed_content}")
-        return hashed_content
-
     def _should_record_context(self, context: Optional[PipelineContext]) -> bool:
         """Check if the context should be recorded in DB"""
         if context is None or context.metadata.get("stored_in_db", False):
@@ -237,37 +193,22 @@ def _should_record_context(self, context: Optional[PipelineContext]) -> bool:
         if context.input_request.type != "fim":
             return True
 
-        # Couldn't process the user message. Skip creating a mapping entry.
-        message = self._extract_request_message(context.input_request.request)
-        if message is None:
-            logger.warning(f"Couldn't read FIM message: {message}. Will not record to DB.")
-            return False
-
-        hash_key = self._create_hash_key(message, context.input_request.provider)
-        old_timestamp = fim_entries.get(hash_key, None)
-        if old_timestamp is None:
-            fim_entries[hash_key] = context.input_request.timestamp
-            return True
+        return fim_cache.could_store_fim_request(context)
 
-        elapsed_seconds = (context.input_request.timestamp - old_timestamp).total_seconds()
-        if elapsed_seconds < Config.get_config().max_fim_hash_lifetime:
+    async def record_context(self, context: Optional[PipelineContext]) -> None:
+        try:
+            if not self._should_record_context(context):
+                return
+            await self.record_request(context.input_request)
+            await self.record_outputs(context.output_responses)
+            await self.record_alerts(context.alerts_raised)
+            context.metadata["stored_in_db"] = True
             logger.info(
-                f"Skipping DB context recording. "
-                f"Elapsed time since last FIM cache: {timedelta(seconds=elapsed_seconds)}."
+                f"Recorded context in DB. Output chunks: {len(context.output_responses)}. "
+                f"Alerts: {len(context.alerts_raised)}."
             )
-            return False
-
-    async def record_context(self, context: Optional[PipelineContext]) -> None:
-        if not self._should_record_context(context):
-            return
-        await self.record_request(context.input_request)
-        await self.record_outputs(context.output_responses)
-        await self.record_alerts(context.alerts_raised)
-        context.metadata["stored_in_db"] = True
-        logger.info(
-            f"Recorded context in DB. Output chunks: {len(context.output_responses)}. "
-            f"Alerts: {len(context.alerts_raised)}."
-        )
+        except Exception as e:
+            logger.error(f"Failed to record context: {context}.", error=str(e))
 
 
 class DbReader(DbCodeGate):
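The refactor keeps the cache-key contract intact: the removed _create_hash_key and the new FimCache._calculate_hash_key both hash "<filepath>-<provider>" with SHA-256, so behavior for Python files is unchanged. A minimal sketch of the derivation (example values are illustrative):

import hashlib

def calculate_hash_key(filepath: str, provider: str) -> str:
    # Same derivation used before and after this refactor
    return hashlib.sha256(f"{filepath}-{provider}".encode("utf-8")).hexdigest()

# Repeated FIM requests against the same file and provider map to one cache entry
key1 = calculate_hash_key("folder/testing_file.py", "copilot")
key2 = calculate_hash_key("folder/testing_file.py", "copilot")
assert key1 == key2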

src/codegate/db/fim_cache.py

Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
+import datetime
+import hashlib
+import json
+import re
+from typing import Dict, List, Optional
+
+import structlog
+from pydantic import BaseModel
+
+from codegate.config import Config
+from codegate.db.models import Alert
+from codegate.pipeline.base import AlertSeverity, PipelineContext
+
+logger = structlog.get_logger("codegate")
+
+
+class CachedFim(BaseModel):
+
+    timestamp: datetime.datetime
+    critical_alerts: List[Alert]
+
+
+class FimCache:
+
+    def __init__(self):
+        self.cache: Dict[str, CachedFim] = {}
+
+    def _extract_message_from_fim_request(self, request: str) -> Optional[str]:
+        """Extract the user message from the FIM request"""
+        try:
+            parsed_request = json.loads(request)
+        except Exception as e:
+            logger.error(f"Failed to extract request message: {request}", error=str(e))
+            return None
+
+        if not isinstance(parsed_request, dict):
+            logger.warning(f"Expected a dictionary, got {type(parsed_request)}.")
+            return None
+
+        messages = [
+            message
+            for message in parsed_request.get("messages", [])
+            if isinstance(message, dict) and message.get("role", "") == "user"
+        ]
+        if len(messages) != 1:
+            logger.warning(f"Expected one user message, found {len(messages)}.")
+            return None
+
+        content_message = messages[0].get("content")
+        return content_message
+
+    def _match_filepath(self, message: str, provider: str) -> Optional[str]:
+        # Try to extract the path from the FIM message. The path is in FIM request as a comment:
+        # folder/testing_file.py
+        # Path: file3.py
+        # // Path: file3.js <-- Javascript
+        pattern = r"^(#|//|<!--|--|%|;).*?\b([a-zA-Z0-9_\-\/]+\.\w+)\b"
+        matches = re.findall(pattern, message, re.MULTILINE)
+        # If no path is found, hash the entire prompt message.
+        if not matches:
+            return None
+
+        # Extract only the paths (2nd group from the match)
+        paths = [match[1] for match in matches]
+
+        # Copilot puts the path at the top of the file. Continue providers contain
+        # several paths, the one in which the fim is triggered is the last one.
+        if provider == "copilot":
+            return paths[0]
+        else:
+            return paths[-1]
+
+    def _calculate_hash_key(self, message: str, provider: str) -> str:
+        """Creates a hash key from the message and includes the provider"""
+        filepath = self._match_filepath(message, provider)
+        if filepath is None:
+            logger.warning("No path found in messages. Creating hash key from message.")
+            message_to_hash = f"{message}-{provider}"
+        else:
+            message_to_hash = f"{filepath}-{provider}"
+
+        logger.debug(f"Message to hash: {message_to_hash}")
+        hashed_content = hashlib.sha256(message_to_hash.encode("utf-8")).hexdigest()
+        logger.debug(f"Hashed content: {hashed_content}")
+        return hashed_content
+
+    def _add_cache_entry(self, hash_key: str, context: PipelineContext):
+        """Add a new cache entry"""
+        critical_alerts = [
+            alert
+            for alert in context.alerts_raised
+            if alert.trigger_category == AlertSeverity.CRITICAL.value
+        ]
+        new_cache = CachedFim(
+            timestamp=context.input_request.timestamp, critical_alerts=critical_alerts
+        )
+        self.cache[hash_key] = new_cache
+        logger.info(f"Added cache entry for hash key: {hash_key}")
+
+    def _are_new_alerts_present(self, context: PipelineContext, cached_entry: CachedFim) -> bool:
+        """Check if there are new alerts present"""
+        new_critical_alerts = [
+            alert
+            for alert in context.alerts_raised
+            if alert.trigger_category == AlertSeverity.CRITICAL.value
+        ]
+        return len(new_critical_alerts) > len(cached_entry.critical_alerts)
+
+    def _is_cached_entry_old(self, context: PipelineContext, cached_entry: CachedFim) -> bool:
+        """Check if the cached entry is old"""
+        elapsed_seconds = (context.input_request.timestamp - cached_entry.timestamp).total_seconds()
+        return elapsed_seconds > Config.get_config().max_fim_hash_lifetime
+
+    def could_store_fim_request(self, context: PipelineContext):
+        # Couldn't process the user message. Skip creating a mapping entry.
+        message = self._extract_message_from_fim_request(context.input_request.request)
+        if message is None:
+            logger.warning(f"Couldn't read FIM message: {message}. Will not record to DB.")
+            return False
+
+        hash_key = self._calculate_hash_key(message, context.input_request.provider)
+        cached_entry = self.cache.get(hash_key, None)
+        if cached_entry is None:
+            self._add_cache_entry(hash_key, context)
+            return True
+
+        if self._is_cached_entry_old(context, cached_entry):
+            self._add_cache_entry(hash_key, context)
+            return True
+
+        if self._are_new_alerts_present(context, cached_entry):
+            self._add_cache_entry(hash_key, context)
+            return True
+
+        logger.debug(f"FIM entry already in cache: {hash_key}.")
+        return False
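The widened comment-prefix group in _match_filepath is the heart of this fix: the pattern removed from connection.py only recognized "#" comments, so FIM prompts from non-Python files never yielded a path and were hashed by their full, constantly changing prompt text, defeating the cache. A quick sketch of the difference (the JavaScript message is illustrative):

import re

OLD_PATTERN = r"^#.*?\b([a-zA-Z0-9_\-\/]+\.\w+)\b"  # Python-style comments only
NEW_PATTERN = r"^(#|//|<!--|--|%|;).*?\b([a-zA-Z0-9_\-\/]+\.\w+)\b"

js_message = "// Path: src/app.js\nfunction greet() {}"

print(re.findall(OLD_PATTERN, js_message, re.MULTILINE))  # [] -- no path found
print(re.findall(NEW_PATTERN, js_message, re.MULTILINE))  # [('//', 'src/app.js')]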

src/codegate/providers/copilot/provider.py

Lines changed: 8 additions & 10 deletions
@@ -1,12 +1,11 @@
 import asyncio
 import re
 import ssl
-from src.codegate.codegate_logging import setup_logging
-import structlog
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple, Union
 from urllib.parse import unquote, urljoin, urlparse
 
+import structlog
 from litellm.types.utils import Delta, ModelResponse, StreamingChoices
 
 from codegate.ca.codegate_ca import CertificateAuthority
@@ -22,6 +21,7 @@
     CopilotPipeline,
 )
 from codegate.providers.copilot.streaming import SSEProcessor
+from src.codegate.codegate_logging import setup_logging
 
 setup_logging()
 logger = structlog.get_logger("codegate").bind(origin="copilot_proxy")
@@ -206,7 +206,7 @@ async def _request_to_target(self, headers: list[str], body: bytes):
         logger.debug("=" * 40)
 
         for i in range(0, len(body), CHUNK_SIZE):
-            chunk = body[i: i + CHUNK_SIZE]
+            chunk = body[i : i + CHUNK_SIZE]
             self.target_transport.write(chunk)
 
     def connection_made(self, transport: asyncio.Transport) -> None:
@@ -269,9 +269,7 @@ def _check_buffer_size(self, new_data: bytes) -> bool:
         """Check if adding new data would exceed buffer size limit"""
         return len(self.buffer) + len(new_data) <= MAX_BUFFER_SIZE
 
-    async def _forward_data_through_pipeline(
-        self, data: bytes
-    ) -> Union[HttpRequest, HttpResponse]:
+    async def _forward_data_through_pipeline(self, data: bytes) -> Union[HttpRequest, HttpResponse]:
         http_request = http_request_from_bytes(data)
         if not http_request:
             # we couldn't parse this into an HTTP request, so we just pass through
@@ -287,7 +285,7 @@ async def _forward_data_through_pipeline(
 
         if context and context.shortcut_response:
             # Send shortcut response
-            data_prefix = b'data:'
+            data_prefix = b"data:"
             http_response = HttpResponse(
                 http_request.version,
                 200,
@@ -299,7 +297,7 @@ async def _forward_data_through_pipeline(
                     "Content-Type: application/json",
                     "Transfer-Encoding: chunked",
                 ],
-                data_prefix + body
+                data_prefix + body,
             )
             return http_response
 
@@ -639,7 +637,7 @@ async def get_target_url(path: str) -> Optional[str]:
     # Check for prefix match
     for route in VALIDATED_ROUTES:
         # For prefix matches, keep the rest of the path
-        remaining_path = path[len(route.path):]
+        remaining_path = path[len(route.path) :]
         logger.debug(f"Remaining path: {remaining_path}")
         # Make sure we don't end up with double slashes
         if remaining_path and remaining_path.startswith("/"):
@@ -793,7 +791,7 @@ def data_received(self, data: bytes) -> None:
             self._proxy_transport_write(headers)
             logger.debug(f"Headers sent: {headers}")
 
-            data = data[header_end + 4:]
+            data = data[header_end + 4 :]
 
             self._process_chunk(data)
 
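Most of the provider changes are Black-style formatting; spaces around the slice colon do not change behavior, as a quick sketch shows (the CHUNK_SIZE value here is illustrative, not the provider's actual constant):

CHUNK_SIZE = 4  # illustrative value
body = b"0123456789"
# body[i: i + CHUNK_SIZE] and body[i : i + CHUNK_SIZE] are the same slice
chunks = [body[i : i + CHUNK_SIZE] for i in range(0, len(body), CHUNK_SIZE)]
print(chunks)  # [b'0123', b'4567', b'89']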

tests/db/test_connection.py

Lines changed: 0 additions & 36 deletions
This file was deleted.
