5
5
6
6
import structlog
7
7
from pydantic import BaseModel
8
- from sqlalchemy import text
8
+ from sqlalchemy import TextClause , text
9
9
from sqlalchemy .ext .asyncio import create_async_engine
10
10
11
11
from codegate .db .fim_cache import FimCache
@@ -30,8 +30,8 @@ def __init__(self, sqlite_path: Optional[str] = None):
30
30
current_dir = Path (__file__ ).parent
31
31
sqlite_path = (
32
32
current_dir .parent .parent .parent / "codegate_volume" / "db" / "codegate.db"
33
- )
34
- self ._db_path = Path (sqlite_path ).absolute ()
33
+ ) # type: ignore
34
+ self ._db_path = Path (sqlite_path ).absolute () # type: ignore
35
35
self ._db_path .parent .mkdir (parents = True , exist_ok = True )
36
36
logger .debug (f"Initializing DB from path: { self ._db_path } " )
37
37
engine_dict = {
@@ -82,15 +82,15 @@ async def init_db(self):
82
82
finally :
83
83
await self ._async_db_engine .dispose ()
84
84
85
    async def _execute_update_pydantic_model(
        self, model: BaseModel, sql_command: TextClause
    ) -> Optional[BaseModel]:
        """Execute an INSERT/UPDATE statement bound from a Pydantic model.

        The model's fields (via ``model_dump()``) become the named parameters
        of ``sql_command``; the statement is expected to end with
        ``RETURNING *`` so the affected row can be read back.

        Returns a new instance of the model's class built from the returned
        row, or None when no row came back or the statement failed.
        """
        # There are create method in queries.py automatically generated by sqlc
        # However, the methods are buggy for Pydantic and don't work as expected.
        # Manually writing the SQL query to insert Pydantic models.
        async with self._async_db_engine.begin() as conn:
            try:
                result = await conn.execute(sql_command, model.model_dump())
                row = result.first()
                if row is None:
                    return None

                # Rebuild the same Pydantic type from the RETURNING row.
                model_class = model.__class__
                return model_class(**row._asdict())
            except Exception as e:
                logger.error(f"Failed to update model: {model}.", error=str(e))
                return None
104
104
105
105
async def record_request (self , prompt_params : Optional [Prompt ] = None ) -> Optional [Prompt ]:
@@ -112,18 +112,39 @@ async def record_request(self, prompt_params: Optional[Prompt] = None) -> Option
112
112
RETURNING *
113
113
"""
114
114
)
115
- recorded_request = await self ._insert_pydantic_model (prompt_params , sql )
115
+ recorded_request = await self ._execute_update_pydantic_model (prompt_params , sql )
116
116
# Uncomment to debug the recorded request
117
117
# logger.debug(f"Recorded request: {recorded_request}")
118
- return recorded_request
118
+ return recorded_request # type: ignore
119
119
120
- async def record_outputs (self , outputs : List [Output ]) -> Optional [Output ]:
120
+ async def update_request (self , initial_id : str ,
121
+ prompt_params : Optional [Prompt ] = None ) -> Optional [Prompt ]:
122
+ if prompt_params is None :
123
+ return None
124
+ prompt_params .id = initial_id # overwrite the initial id of the request
125
+ sql = text (
126
+ """
127
+ UPDATE prompts
128
+ SET timestamp = :timestamp, provider = :provider, request = :request, type = :type
129
+ WHERE id = :id
130
+ RETURNING *
131
+ """
132
+ )
133
+ updated_request = await self ._execute_update_pydantic_model (prompt_params , sql )
134
+ # Uncomment to debug the recorded request
135
+ # logger.debug(f"Recorded request: {recorded_request}")
136
+ return updated_request # type: ignore
137
+
138
+ async def record_outputs (self , outputs : List [Output ],
139
+ initial_id : Optional [str ]) -> Optional [Output ]:
121
140
if not outputs :
122
141
return
123
142
124
143
first_output = outputs [0 ]
125
144
# Create a single entry on DB but encode all of the chunks in the stream as a list
126
145
# of JSON objects in the field `output`
146
+ if initial_id :
147
+ first_output .prompt_id = initial_id
127
148
output_db = Output (
128
149
id = first_output .id ,
129
150
prompt_id = first_output .prompt_id ,
@@ -143,14 +164,14 @@ async def record_outputs(self, outputs: List[Output]) -> Optional[Output]:
143
164
RETURNING *
144
165
"""
145
166
)
146
- recorded_output = await self ._insert_pydantic_model (output_db , sql )
167
+ recorded_output = await self ._execute_update_pydantic_model (output_db , sql )
147
168
# Uncomment to debug
148
169
# logger.debug(f"Recorded output: {recorded_output}")
149
- return recorded_output
170
+ return recorded_output # type: ignore
150
171
151
- async def record_alerts (self , alerts : List [Alert ]) -> List [Alert ]:
172
+ async def record_alerts (self , alerts : List [Alert ], initial_id : Optional [ str ] ) -> List [Alert ]:
152
173
if not alerts :
153
- return
174
+ return []
154
175
sql = text (
155
176
"""
156
177
INSERT INTO alerts (
@@ -167,7 +188,9 @@ async def record_alerts(self, alerts: List[Alert]) -> List[Alert]:
167
188
async with asyncio .TaskGroup () as tg :
168
189
for alert in alerts :
169
190
try :
170
- result = tg .create_task (self ._insert_pydantic_model (alert , sql ))
191
+ if initial_id :
192
+ alert .prompt_id = initial_id
193
+ result = tg .create_task (self ._execute_update_pydantic_model (alert , sql ))
171
194
alerts_tasks .append (result )
172
195
except Exception as e :
173
196
logger .error (f"Failed to record alert: { alert } ." , error = str (e ))
@@ -182,33 +205,49 @@ async def record_alerts(self, alerts: List[Alert]) -> List[Alert]:
182
205
# logger.debug(f"Recorded alerts: {recorded_alerts}")
183
206
return recorded_alerts
184
207
185
- def _should_record_context (self , context : Optional [PipelineContext ]) -> bool :
186
- """Check if the context should be recorded in DB"""
208
+ def _should_record_context (self , context : Optional [PipelineContext ]) -> tuple :
209
+ """Check if the context should be recorded in DB and determine the action. """
187
210
if context is None or context .metadata .get ("stored_in_db" , False ):
188
- return False
211
+ return False , None , None
189
212
190
213
if not context .input_request :
191
214
logger .warning ("No input request found. Skipping recording context." )
192
- return False
215
+ return False , None , None
193
216
194
217
# If it's not a FIM prompt, we don't need to check anything else.
195
218
if context .input_request .type != "fim" :
196
- return True
219
+ return True , 'add' , '' # Default to add if not FIM, since no cache check is required
197
220
198
- return fim_cache .could_store_fim_request (context )
221
+ return fim_cache .could_store_fim_request (context ) # type: ignore
199
222
200
223
async def record_context (self , context : Optional [PipelineContext ]) -> None :
201
224
try :
202
- if not self ._should_record_context (context ):
225
+ if not context :
226
+ logger .info ("No context provided, skipping" )
203
227
return
204
- await self .record_request (context .input_request )
205
- await self .record_outputs (context .output_responses )
206
- await self .record_alerts (context .alerts_raised )
207
- context .metadata ["stored_in_db" ] = True
208
- logger .info (
209
- f"Recorded context in DB. Output chunks: { len (context .output_responses )} . "
210
- f"Alerts: { len (context .alerts_raised )} ."
211
- )
228
+ should_record , action , initial_id = self ._should_record_context (context )
229
+ if not should_record :
230
+ logger .info ("Skipping record of context, not needed" )
231
+ return
232
+ if action == 'add' :
233
+ await self .record_request (context .input_request )
234
+ await self .record_outputs (context .output_responses , None )
235
+ await self .record_alerts (context .alerts_raised , None )
236
+ context .metadata ["stored_in_db" ] = True
237
+ logger .info (
238
+ f"Recorded context in DB. Output chunks: { len (context .output_responses )} . "
239
+ f"Alerts: { len (context .alerts_raised )} ."
240
+ )
241
+ else :
242
+ # update them
243
+ await self .update_request (initial_id , context .input_request )
244
+ await self .record_outputs (context .output_responses , initial_id )
245
+ await self .record_alerts (context .alerts_raised , initial_id )
246
+ context .metadata ["stored_in_db" ] = True
247
+ logger .info (
248
+ f"Recorded context in DB. Output chunks: { len (context .output_responses )} . "
249
+ f"Alerts: { len (context .alerts_raised )} ."
250
+ )
212
251
except Exception as e :
213
252
logger .error (f"Failed to record context: { context } ." , error = str (e ))
214
253
0 commit comments