Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 4a48941

Browse files
Changed flow for recording objects on DB. Keep objects in context
With this change the objects that are going to be stored in DB are kept in the `context` of the pipeline. The pipeline and its `context` are used by all providers, including Copilot. We would need to find a good place in Copilot provider to record the context in DB, e.g. when all the chunks have been transmitted and the stream is about to be closed.
1 parent 1683184 commit 4a48941

File tree

11 files changed

+162
-165
lines changed

11 files changed

+162
-165
lines changed

src/codegate/dashboard/post_processing.py

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -86,26 +86,39 @@ async def parse_output(output_str: str) -> Tuple[Optional[str], Optional[str]]:
8686
logger.warning(f"Error parsing output: {output_str}. {e}")
8787
return None, None
8888

89-
output_message = ""
89+
def _parse_single_output(single_output: dict) -> str:
90+
single_chat_id = single_output.get("id")
91+
single_output_message = ""
92+
for choice in single_output.get("choices", []):
93+
if not isinstance(choice, dict):
94+
continue
95+
content_dict = choice.get("delta", {}) or choice.get("message", {})
96+
single_output_message += content_dict.get("content", "")
97+
return single_output_message, single_chat_id
98+
99+
full_output_message = ""
90100
chat_id = None
91101
if isinstance(output, list):
92102
for output_chunk in output:
93-
if not isinstance(output_chunk, dict):
94-
continue
95-
chat_id = chat_id or output_chunk.get("id")
96-
for choice in output_chunk.get("choices", []):
97-
if not isinstance(choice, dict):
98-
continue
99-
delta_dict = choice.get("delta", {})
100-
output_message += delta_dict.get("content", "")
103+
output_message, output_chat_id = "", None
104+
if isinstance(output_chunk, dict):
105+
output_message, output_chat_id = _parse_single_output(output_chunk)
106+
elif isinstance(output_chunk, str):
107+
try:
108+
output_decoded = json.loads(output_chunk)
109+
output_message, output_chat_id = _parse_single_output(output_decoded)
110+
except Exception:
111+
logger.error(f"Error reading chunk: {output_chunk}")
112+
else:
113+
logger.warning(
114+
f"Could not handle output: {output_chunk}", out_type=type(output_chunk)
115+
)
116+
chat_id = chat_id or output_chat_id
117+
full_output_message += output_message
101118
elif isinstance(output, dict):
102-
chat_id = chat_id or output.get("id")
103-
for choice in output.get("choices", []):
104-
if not isinstance(choice, dict):
105-
continue
106-
output_message += choice.get("message", {}).get("content", "")
119+
full_output_message, chat_id = _parse_single_output(output)
107120

108-
return output_message, chat_id
121+
return full_output_message, chat_id
109122

110123

111124
async def _get_question_answer(
@@ -124,19 +137,23 @@ async def _get_question_answer(
124137
output_msg_str, chat_id = output_task.result()
125138

126139
# If we couldn't parse the request or output, return None
127-
if not request_msg_str or not output_msg_str or not chat_id:
140+
if not request_msg_str:
128141
return None, None
129142

130143
request_message = ChatMessage(
131144
message=request_msg_str,
132145
timestamp=row.timestamp,
133146
message_id=row.id,
134147
)
135-
output_message = ChatMessage(
136-
message=output_msg_str,
137-
timestamp=row.output_timestamp,
138-
message_id=row.output_id,
139-
)
148+
if output_msg_str:
149+
output_message = ChatMessage(
150+
message=output_msg_str,
151+
timestamp=row.output_timestamp,
152+
message_id=row.output_id,
153+
)
154+
else:
155+
output_message = None
156+
chat_id = row.id
140157
return QuestionAnswer(question=request_message, answer=output_message), chat_id
141158

142159

src/codegate/dashboard/request_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class QuestionAnswer(BaseModel):
2222
"""
2323

2424
question: ChatMessage
25-
answer: ChatMessage
25+
answer: Optional[ChatMessage]
2626

2727

2828
class PartialConversation(BaseModel):

src/codegate/db/connection.py

Lines changed: 37 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
import asyncio
2-
import copy
3-
import datetime
42
import json
5-
import uuid
63
from pathlib import Path
7-
from typing import AsyncGenerator, AsyncIterator, List, Optional
4+
from typing import List, Optional
85

96
import structlog
10-
from litellm import ChatCompletionRequest, ModelResponse
117
from pydantic import BaseModel
128
from sqlalchemy import text
139
from sqlalchemy.ext.asyncio import create_async_engine
@@ -35,7 +31,7 @@ def __init__(self, sqlite_path: Optional[str] = None):
3531
)
3632
self._db_path = Path(sqlite_path).absolute()
3733
self._db_path.parent.mkdir(parents=True, exist_ok=True)
38-
logger.debug(f"Initializing DB from path: {self._db_path}")
34+
logger.info(f"Initializing DB from path: {self._db_path}")
3935
engine_dict = {
4036
"url": f"sqlite+aiosqlite:///{self._db_path}",
4137
"echo": False, # Set to False in production
@@ -104,9 +100,7 @@ async def _insert_pydantic_model(
104100
logger.error(f"Failed to insert model: {model}.", error=str(e))
105101
return None
106102

107-
async def record_request(
108-
self, prompt_params: Optional[Prompt] = None
109-
) -> Optional[Prompt]:
103+
async def record_request(self, prompt_params: Optional[Prompt] = None) -> Optional[Prompt]:
110104
if prompt_params is None:
111105
return None
112106
sql = text(
@@ -117,87 +111,38 @@ async def record_request(
117111
"""
118112
)
119113
recorded_request = await self._insert_pydantic_model(prompt_params, sql)
120-
logger.info(f"Recorded request: {recorded_request}")
114+
logger.debug(f"Recorded request: {recorded_request}")
121115
return recorded_request
122116

123-
async def _record_output(self, prompt: Prompt, output_str: str) -> Optional[Output]:
124-
output_params = Output(
125-
id=str(uuid.uuid4()),
126-
prompt_id=prompt.id,
127-
timestamp=datetime.datetime.now(datetime.timezone.utc),
128-
output=output_str,
129-
)
130-
sql = text(
131-
"""
132-
INSERT INTO outputs (id, prompt_id, timestamp, output)
133-
VALUES (:id, :prompt_id, :timestamp, :output)
134-
RETURNING *
135-
"""
136-
)
137-
return await self._insert_pydantic_model(output_params, sql)
138-
139-
async def record_outputs(self, outputs: List[Output]) -> List[Output]:
117+
async def record_outputs(self, outputs: List[Output]) -> Optional[Output]:
140118
if not outputs:
141119
return
120+
121+
first_output = outputs[0]
122+
# Create a single entry on DB but encode all of the chunks in the stream as a list
123+
# of JSON objects in the field `output`
124+
output_db = Output(
125+
id=first_output.id,
126+
prompt_id=first_output.prompt_id,
127+
timestamp=first_output.timestamp,
128+
output=first_output.output,
129+
)
130+
full_outputs = []
131+
# Just store the model responses in the list of JSON objects.
132+
for output in outputs:
133+
full_outputs.append(output.output)
134+
output_db.output = json.dumps(full_outputs)
135+
142136
sql = text(
143137
"""
144138
INSERT INTO outputs (id, prompt_id, timestamp, output)
145139
VALUES (:id, :prompt_id, :timestamp, :output)
146140
RETURNING *
147141
"""
148142
)
149-
# We can insert each alert independently in parallel.
150-
outputs_tasks = []
151-
async with asyncio.TaskGroup() as tg:
152-
for output in outputs:
153-
try:
154-
outputs_tasks.append(tg.create_task(self._insert_pydantic_model(output, sql)))
155-
except Exception as e:
156-
logger.error(f"Failed to record alert: {output}.", error=str(e))
157-
recorded_outputs = [output.result() for output in outputs_tasks]
158-
logger.info(f"Recorded outputs: {recorded_outputs}")
159-
return recorded_outputs
160-
161-
async def record_output_stream(
162-
self, prompt: Prompt, model_response: AsyncIterator
163-
) -> AsyncGenerator:
164-
output_chunks = []
165-
async for chunk in model_response:
166-
if isinstance(chunk, BaseModel):
167-
chunk_to_record = chunk.model_dump(exclude_none=True, exclude_unset=True)
168-
output_chunks.append(chunk_to_record)
169-
elif isinstance(chunk, dict):
170-
output_chunks.append(copy.deepcopy(chunk))
171-
else:
172-
output_chunks.append({"chunk": str(chunk)})
173-
yield chunk
174-
175-
if output_chunks:
176-
# Record the output chunks
177-
output_str = json.dumps(output_chunks)
178-
await self._record_output(prompt, output_str)
179-
180-
async def record_output_non_stream(
181-
self, prompt: Optional[Prompt], model_response: ModelResponse
182-
) -> Optional[Output]:
183-
if prompt is None:
184-
logger.warning("No prompt found to record output.")
185-
return
186-
187-
output_str = None
188-
if isinstance(model_response, BaseModel):
189-
output_str = model_response.model_dump_json(exclude_none=True, exclude_unset=True)
190-
else:
191-
try:
192-
output_str = json.dumps(model_response)
193-
except Exception as e:
194-
logger.error(f"Failed to serialize output: {model_response}", error=str(e))
195-
196-
if output_str is None:
197-
logger.warning("No output found to record.")
198-
return
199-
200-
return await self._record_output(prompt, output_str)
143+
recorded_output = await self._insert_pydantic_model(output_db, sql)
144+
logger.debug(f"Recorded output: {recorded_output}")
145+
return recorded_output
201146

202147
async def record_alerts(self, alerts: List[Alert]) -> List[Alert]:
203148
if not alerts:
@@ -220,16 +165,24 @@ async def record_alerts(self, alerts: List[Alert]) -> List[Alert]:
220165
try:
221166
result = tg.create_task(self._insert_pydantic_model(alert, sql))
222167
alerts_tasks.append(result)
223-
if result and alert.trigger_category == "critical":
224-
await alert_queue.put(f"New alert detected: {alert.timestamp}")
225168
except Exception as e:
226169
logger.error(f"Failed to record alert: {alert}.", error=str(e))
227-
recorded_alerts = [alert.result() for alert in alerts_tasks]
228-
logger.info(f"Recorded alerts: {recorded_alerts}")
170+
171+
recorded_alerts = []
172+
for alert_coro in alerts_tasks:
173+
alert_result = alert_coro.result()
174+
recorded_alerts.append(alert_result)
175+
if alert_result and alert_result.trigger_category == "critical":
176+
await alert_queue.put(f"New alert detected: {alert.timestamp}")
177+
178+
logger.debug(f"Recorded alerts: {recorded_alerts}")
229179
return recorded_alerts
230180

231181
async def record_context(self, context: PipelineContext) -> None:
232-
logger.info(f"Recording context: {context}")
182+
logger.info(
183+
f"Recording context in DB. Output chunks: {len(context.output_responses)}. "
184+
f"Alerts: {len(context.alerts_raised)}."
185+
)
233186
await self.record_request(context.input_request)
234187
await self.record_outputs(context.output_responses)
235188
await self.record_alerts(context.alerts_raised)

0 commit comments

Comments (0)