 import asyncio
-import copy
-import datetime
 import json
-import uuid
 from pathlib import Path
-from typing import AsyncGenerator, AsyncIterator, List, Optional
+from typing import List, Optional
 
 import structlog
-from litellm import ChatCompletionRequest, ModelResponse
 from pydantic import BaseModel
 from sqlalchemy import text
 from sqlalchemy.ext.asyncio import create_async_engine
@@ -35,7 +31,7 @@ def __init__(self, sqlite_path: Optional[str] = None):
         )
         self._db_path = Path(sqlite_path).absolute()
         self._db_path.parent.mkdir(parents=True, exist_ok=True)
-        logger.debug(f"Initializing DB from path: {self._db_path}")
+        logger.info(f"Initializing DB from path: {self._db_path}")
         engine_dict = {
             "url": f"sqlite+aiosqlite:///{self._db_path}",
"echo" : False , # Set to False in production
@@ -104,9 +100,7 @@ async def _insert_pydantic_model(
             logger.error(f"Failed to insert model: {model}.", error=str(e))
             return None
 
-    async def record_request(
-        self, prompt_params: Optional[Prompt] = None
-    ) -> Optional[Prompt]:
+    async def record_request(self, prompt_params: Optional[Prompt] = None) -> Optional[Prompt]:
         if prompt_params is None:
             return None
         sql = text(
@@ -117,87 +111,38 @@ async def record_request(
             """
         )
         recorded_request = await self._insert_pydantic_model(prompt_params, sql)
-        logger.info(f"Recorded request: {recorded_request}")
+        logger.debug(f"Recorded request: {recorded_request}")
         return recorded_request
 
-    async def _record_output(self, prompt: Prompt, output_str: str) -> Optional[Output]:
-        output_params = Output(
-            id=str(uuid.uuid4()),
-            prompt_id=prompt.id,
-            timestamp=datetime.datetime.now(datetime.timezone.utc),
-            output=output_str,
-        )
-        sql = text(
-            """
-            INSERT INTO outputs (id, prompt_id, timestamp, output)
-            VALUES (:id, :prompt_id, :timestamp, :output)
-            RETURNING *
-            """
-        )
-        return await self._insert_pydantic_model(output_params, sql)
-
-    async def record_outputs(self, outputs: List[Output]) -> List[Output]:
+    async def record_outputs(self, outputs: List[Output]) -> Optional[Output]:
         if not outputs:
             return
+
+        first_output = outputs[0]
+        # Create a single DB entry, encoding all of the chunks in the stream
+        # as a list of JSON objects in the `output` field.
+        output_db = Output(
+            id=first_output.id,
+            prompt_id=first_output.prompt_id,
+            timestamp=first_output.timestamp,
+            output=first_output.output,
+        )
+        full_outputs = []
+        # Just store the model responses in the list of JSON objects.
+        for output in outputs:
+            full_outputs.append(output.output)
+        output_db.output = json.dumps(full_outputs)
+
         sql = text(
             """
             INSERT INTO outputs (id, prompt_id, timestamp, output)
             VALUES (:id, :prompt_id, :timestamp, :output)
             RETURNING *
             """
         )
-        # We can insert each alert independently in parallel.
-        outputs_tasks = []
-        async with asyncio.TaskGroup() as tg:
-            for output in outputs:
-                try:
-                    outputs_tasks.append(tg.create_task(self._insert_pydantic_model(output, sql)))
-                except Exception as e:
-                    logger.error(f"Failed to record alert: {output}.", error=str(e))
-        recorded_outputs = [output.result() for output in outputs_tasks]
-        logger.info(f"Recorded outputs: {recorded_outputs}")
-        return recorded_outputs
-
-    async def record_output_stream(
-        self, prompt: Prompt, model_response: AsyncIterator
-    ) -> AsyncGenerator:
-        output_chunks = []
-        async for chunk in model_response:
-            if isinstance(chunk, BaseModel):
-                chunk_to_record = chunk.model_dump(exclude_none=True, exclude_unset=True)
-                output_chunks.append(chunk_to_record)
-            elif isinstance(chunk, dict):
-                output_chunks.append(copy.deepcopy(chunk))
-            else:
-                output_chunks.append({"chunk": str(chunk)})
-            yield chunk
-
-        if output_chunks:
-            # Record the output chunks
-            output_str = json.dumps(output_chunks)
-            await self._record_output(prompt, output_str)
-
-    async def record_output_non_stream(
-        self, prompt: Optional[Prompt], model_response: ModelResponse
-    ) -> Optional[Output]:
-        if prompt is None:
-            logger.warning("No prompt found to record output.")
-            return
-
-        output_str = None
-        if isinstance(model_response, BaseModel):
-            output_str = model_response.model_dump_json(exclude_none=True, exclude_unset=True)
-        else:
-            try:
-                output_str = json.dumps(model_response)
-            except Exception as e:
-                logger.error(f"Failed to serialize output: {model_response}", error=str(e))
-
-        if output_str is None:
-            logger.warning("No output found to record.")
-            return
-
-        return await self._record_output(prompt, output_str)
+        recorded_output = await self._insert_pydantic_model(output_db, sql)
+        logger.debug(f"Recorded output: {recorded_output}")
+        return recorded_output
 
     async def record_alerts(self, alerts: List[Alert]) -> List[Alert]:
         if not alerts:
@@ -220,16 +165,24 @@ async def record_alerts(self, alerts: List[Alert]) -> List[Alert]:
                 try:
                     result = tg.create_task(self._insert_pydantic_model(alert, sql))
                     alerts_tasks.append(result)
-                    if result and alert.trigger_category == "critical":
-                        await alert_queue.put(f"New alert detected: {alert.timestamp}")
                 except Exception as e:
                     logger.error(f"Failed to record alert: {alert}.", error=str(e))
-        recorded_alerts = [alert.result() for alert in alerts_tasks]
-        logger.info(f"Recorded alerts: {recorded_alerts}")
+
+        recorded_alerts = []
+        for alert_coro in alerts_tasks:
+            alert_result = alert_coro.result()
+            recorded_alerts.append(alert_result)
+            if alert_result and alert_result.trigger_category == "critical":
+                await alert_queue.put(f"New alert detected: {alert_result.timestamp}")
+
+        logger.debug(f"Recorded alerts: {recorded_alerts}")
         return recorded_alerts
 
     async def record_context(self, context: PipelineContext) -> None:
-        logger.info(f"Recording context: {context}")
+        logger.info(
+            f"Recording context in DB. Output chunks: {len(context.output_responses)}. "
+            f"Alerts: {len(context.alerts_raised)}."
+        )
         await self.record_request(context.input_request)
         await self.record_outputs(context.output_responses)
         await self.record_alerts(context.alerts_raised)
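
The rewritten record_outputs() above collapses what used to be one row per stream chunk into a single outputs row whose output column holds a JSON-encoded list of chunks. A minimal sketch of reading that column back, assuming the table layout from the INSERT statement above; the helper name and the use of the stdlib sqlite3 module are illustrative, not part of this commit:

import json
import sqlite3

# Hypothetical helper, not part of this commit: read back the chunks that
# record_outputs() packed into the single JSON-encoded `output` column.
def load_output_chunks(db_path: str, prompt_id: str) -> list:
    conn = sqlite3.connect(db_path)
    try:
        row = conn.execute(
            "SELECT output FROM outputs WHERE prompt_id = ?", (prompt_id,)
        ).fetchone()
    finally:
        conn.close()
    # The column holds json.dumps(full_outputs): a list with one element
    # per streamed chunk, in arrival order.
    return json.loads(row[0]) if row else []

The tradeoff is one insert per response instead of one per chunk, at the cost of no longer being able to query individual chunks in SQL.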
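record_alerts() keeps the asyncio.TaskGroup fan-out that the old record_outputs() used. A self-contained sketch of that pattern (Python 3.11+; insert_one() is a hypothetical stand-in for _insert_pydantic_model()):

import asyncio

async def insert_all(items: list) -> list:
    async def insert_one(item):
        await asyncio.sleep(0)  # simulate the DB round-trip
        return item

    tasks = []
    async with asyncio.TaskGroup() as tg:
        for item in items:
            tasks.append(tg.create_task(insert_one(item)))
    # The `async with` block does not exit until every task is done,
    # so calling .result() here never blocks or raises InvalidStateError.
    return [t.result() for t in tasks]

print(asyncio.run(insert_all(["a", "b", "c"])))  # -> ['a', 'b', 'c']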