Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 1683184

Browse files
WIP: Record the prompt at the end of the pipeline. Keep DB objects in context
1 parent 17a99ec commit 1683184

File tree

7 files changed

+142
-71
lines changed

7 files changed

+142
-71
lines changed

src/codegate/db/connection.py

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
GetAlertsWithPromptAndOutputRow,
1919
GetPromptWithOutputsRow,
2020
)
21+
from codegate.pipeline.base import PipelineContext
2122

2223
logger = structlog.get_logger("codegate")
2324
alert_queue = asyncio.Queue()
@@ -104,37 +105,20 @@ async def _insert_pydantic_model(
104105
return None
105106

106107
async def record_request(
107-
self, normalized_request: ChatCompletionRequest, is_fim_request: bool, provider_str: str
108+
self, prompt_params: Optional[Prompt] = None
108109
) -> Optional[Prompt]:
109-
request_str = None
110-
if isinstance(normalized_request, BaseModel):
111-
request_str = normalized_request.model_dump_json(exclude_none=True, exclude_unset=True)
112-
else:
113-
try:
114-
request_str = json.dumps(normalized_request)
115-
except Exception as e:
116-
logger.error(f"Failed to serialize output: {normalized_request}", error=str(e))
117-
118-
if request_str is None:
119-
logger.warning("No request found to record.")
120-
return
121-
122-
# Create a new prompt record
123-
prompt_params = Prompt(
124-
id=str(uuid.uuid4()), # Generate a new UUID for the prompt
125-
timestamp=datetime.datetime.now(datetime.timezone.utc),
126-
provider=provider_str,
127-
type="fim" if is_fim_request else "chat",
128-
request=request_str,
129-
)
110+
if prompt_params is None:
111+
return None
130112
sql = text(
131113
"""
132114
INSERT INTO prompts (id, timestamp, provider, request, type)
133115
VALUES (:id, :timestamp, :provider, :request, :type)
134116
RETURNING *
135117
"""
136118
)
137-
return await self._insert_pydantic_model(prompt_params, sql)
119+
recorded_request = await self._insert_pydantic_model(prompt_params, sql)
120+
logger.info(f"Recorded request: {recorded_request}")
121+
return recorded_request
138122

139123
async def _record_output(self, prompt: Prompt, output_str: str) -> Optional[Output]:
140124
output_params = Output(
@@ -152,6 +136,28 @@ async def _record_output(self, prompt: Prompt, output_str: str) -> Optional[Outp
152136
)
153137
return await self._insert_pydantic_model(output_params, sql)
154138

139+
async def record_outputs(self, outputs: List[Output]) -> List[Output]:
140+
if not outputs:
141+
return
142+
sql = text(
143+
"""
144+
INSERT INTO outputs (id, prompt_id, timestamp, output)
145+
VALUES (:id, :prompt_id, :timestamp, :output)
146+
RETURNING *
147+
"""
148+
)
149+
# We can insert each alert independently in parallel.
150+
outputs_tasks = []
151+
async with asyncio.TaskGroup() as tg:
152+
for output in outputs:
153+
try:
154+
outputs_tasks.append(tg.create_task(self._insert_pydantic_model(output, sql)))
155+
except Exception as e:
156+
logger.error(f"Failed to record alert: {output}.", error=str(e))
157+
recorded_outputs = [output.result() for output in outputs_tasks]
158+
logger.info(f"Recorded outputs: {recorded_outputs}")
159+
return recorded_outputs
160+
155161
async def record_output_stream(
156162
self, prompt: Prompt, model_response: AsyncIterator
157163
) -> AsyncGenerator:
@@ -193,7 +199,7 @@ async def record_output_non_stream(
193199

194200
return await self._record_output(prompt, output_str)
195201

196-
async def record_alerts(self, alerts: List[Alert]) -> None:
202+
async def record_alerts(self, alerts: List[Alert]) -> List[Alert]:
197203
if not alerts:
198204
return
199205
sql = text(
@@ -208,15 +214,25 @@ async def record_alerts(self, alerts: List[Alert]) -> None:
208214
"""
209215
)
210216
# We can insert each alert independently in parallel.
217+
alerts_tasks = []
211218
async with asyncio.TaskGroup() as tg:
212219
for alert in alerts:
213220
try:
214221
result = tg.create_task(self._insert_pydantic_model(alert, sql))
222+
alerts_tasks.append(result)
215223
if result and alert.trigger_category == "critical":
216224
await alert_queue.put(f"New alert detected: {alert.timestamp}")
217225
except Exception as e:
218226
logger.error(f"Failed to record alert: {alert}.", error=str(e))
219-
return None
227+
recorded_alerts = [alert.result() for alert in alerts_tasks]
228+
logger.info(f"Recorded alerts: {recorded_alerts}")
229+
return recorded_alerts
230+
231+
async def record_context(self, context: PipelineContext) -> None:
232+
logger.info(f"Recording context: {context}")
233+
await self.record_request(context.input_request)
234+
await self.record_outputs(context.output_responses)
235+
await self.record_alerts(context.alerts_raised)
220236

221237

222238
class DbReader(DbCodeGate):

src/codegate/pipeline/base.py

Lines changed: 63 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55
from abc import ABC, abstractmethod
66
from dataclasses import dataclass, field
77
from enum import Enum
8-
from typing import Any, Dict, List, Optional
8+
from typing import Any, AsyncIterator, Dict, List, Optional
99

1010
import structlog
11-
from litellm import ChatCompletionRequest
11+
from litellm import ChatCompletionRequest, ModelResponse
12+
from pydantic import BaseModel
1213

13-
from codegate.db.models import Alert
14+
from codegate.db.models import Alert, Output, Prompt
1415
from codegate.pipeline.secrets.manager import SecretsManager
1516

1617
logger = structlog.get_logger("codegate")
@@ -73,6 +74,9 @@ class PipelineContext:
7374
metadata: Dict[str, Any] = field(default_factory=dict)
7475
sensitive: Optional[PipelineSensitiveData] = field(default_factory=lambda: None)
7576
alerts_raised: List[Alert] = field(default_factory=list)
77+
prompt_id: Optional[str] = field(default_factory=lambda: None)
78+
input_request: Optional[Prompt] = field(default_factory=lambda: None)
79+
output_responses: List[Output] = field(default_factory=list)
7680

7781
def add_code_snippet(self, snippet: CodeSnippet):
7882
self.code_snippets.append(snippet)
@@ -90,9 +94,8 @@ def add_alert(
9094
"""
9195
Add an alert to the pipeline step alerts_raised.
9296
"""
93-
if not self.metadata.get("prompt_id"):
94-
logger.warning("No prompt_id found in context. Alert will not be created")
95-
return
97+
if self.prompt_id is None:
98+
self.prompt_id = str(uuid.uuid4())
9699

97100
if not code_snippet and not trigger_string:
98101
logger.warning("No code snippet or trigger string provided for alert. Will not create")
@@ -103,15 +106,57 @@ def add_alert(
103106
self.alerts_raised.append(
104107
Alert(
105108
id=str(uuid.uuid4()),
106-
prompt_id=self.metadata["prompt_id"],
109+
prompt_id=self.prompt_id,
107110
code_snippet=code_snippet_str,
108111
trigger_string=trigger_string,
109112
trigger_type=step_name,
110113
trigger_category=severity_category.value,
111114
timestamp=datetime.datetime.now(datetime.timezone.utc),
112115
)
113116
)
117+
logger.info(f"Added alert to context: {self.alerts_raised[-1]}")
118+
119+
def add_input_request(
120+
self, normalized_request: ChatCompletionRequest, is_fim_request: bool, provider: str
121+
) -> None:
122+
try:
123+
if self.prompt_id is None:
124+
self.prompt_id = str(uuid.uuid4())
114125

126+
request_str = json.dumps(normalized_request)
127+
128+
self.input_request = Prompt(
129+
id=self.prompt_id,
130+
timestamp=datetime.datetime.now(datetime.timezone.utc),
131+
provider=provider,
132+
type="fim" if is_fim_request else "chat",
133+
request=request_str,
134+
)
135+
logger.info(f"Added input request to context: {self.input_request}")
136+
except Exception as e:
137+
logger.warning(f"Failed to serialize input request: {normalized_request}", error=str(e))
138+
139+
def add_output(self, model_response: ModelResponse) -> None:
140+
try:
141+
if self.prompt_id is None:
142+
self.prompt_id = str(uuid.uuid4())
143+
144+
if isinstance(model_response, BaseModel):
145+
output_str = model_response.model_dump_json(exclude_none=True, exclude_unset=True)
146+
else:
147+
output_str = json.dumps(model_response)
148+
149+
self.output_responses.append(Output(
150+
id=self.prompt_id,
151+
prompt_id=self.input_request.id,
152+
timestamp=datetime.datetime.now(datetime.timezone.utc),
153+
output=output_str,
154+
)
155+
)
156+
logger.info(f"Added output to context: {self.output_responses[-1]}")
157+
except Exception as e:
158+
logger.error(f"Failed to serialize output: {model_response}", error=str(e))
159+
return
115160

116161
@dataclass
117162
class PipelineResponse:
@@ -212,16 +257,17 @@ async def process(
212257

213258

214259
class InputPipelineInstance:
215-
def __init__(self, pipeline_steps: List[PipelineStep], secret_manager: SecretsManager):
260+
def __init__(self, pipeline_steps: List[PipelineStep], secret_manager: SecretsManager, is_fim: bool):
216261
self.pipeline_steps = pipeline_steps
217262
self.secret_manager = secret_manager
263+
self.is_fim = is_fim
218264
self.context = PipelineContext()
219265

220266
async def process_request(
221267
self,
222268
request: ChatCompletionRequest,
223269
provider: str,
224-
prompt_id: str,
270+
# prompt_id: str,
225271
model: str,
226272
api_key: Optional[str] = None,
227273
api_base: Optional[str] = None,
@@ -236,7 +282,7 @@ async def process_request(
236282
provider=provider,
237283
api_base=api_base,
238284
)
239-
self.context.metadata["prompt_id"] = prompt_id
285+
# self.context.metadata["prompt_id"] = prompt_id
240286
self.context.metadata["extra_headers"] = extra_headers
241287
current_request = request
242288

@@ -254,23 +300,26 @@ async def process_request(
254300
if result.context is not None:
255301
self.context = result.context
256302

303+
# Create the input request at the end so we make sure the secrets are obfuscated
304+
self.context.add_input_request(current_request, is_fim_request=self.is_fim, provider=provider)
257305
return PipelineResult(request=current_request, context=self.context)
258306

259307

260308
class SequentialPipelineProcessor:
261-
def __init__(self, pipeline_steps: List[PipelineStep], secret_manager: SecretsManager):
309+
def __init__(self, pipeline_steps: List[PipelineStep], secret_manager: SecretsManager, is_fim: bool):
262310
self.pipeline_steps = pipeline_steps
263311
self.secret_manager = secret_manager
312+
self.is_fim = is_fim
264313

265314
def create_instance(self) -> InputPipelineInstance:
266315
"""Create a new pipeline instance for processing a request"""
267-
return InputPipelineInstance(self.pipeline_steps, self.secret_manager)
316+
return InputPipelineInstance(self.pipeline_steps, self.secret_manager, self.is_fim)
268317

269318
async def process_request(
270319
self,
271320
request: ChatCompletionRequest,
272321
provider: str,
273-
prompt_id: str,
322+
# prompt_id: str,
274323
model: str,
275324
api_key: Optional[str] = None,
276325
api_base: Optional[str] = None,
@@ -279,5 +328,5 @@ async def process_request(
279328
"""Create a new pipeline instance and process the request"""
280329
instance = self.create_instance()
281330
return await instance.process_request(
282-
request, provider, prompt_id, model, api_key, api_base, extra_headers
331+
request, provider, model, api_key, api_base, extra_headers
283332
)

src/codegate/pipeline/factory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@ def create_input_pipeline(self) -> SequentialPipelineProcessor:
3232
SystemPrompt(Config.get_config().prompts.default_chat),
3333
CodegateContextRetriever(),
3434
]
35-
return SequentialPipelineProcessor(input_steps, self.secrets_manager)
35+
return SequentialPipelineProcessor(input_steps, self.secrets_manager, is_fim=False)
3636

3737
def create_fim_pipeline(self) -> SequentialPipelineProcessor:
3838
fim_steps: List[PipelineStep] = [
3939
CodegateSecrets(),
4040
]
41-
return SequentialPipelineProcessor(fim_steps, self.secrets_manager)
41+
return SequentialPipelineProcessor(fim_steps, self.secrets_manager, is_fim=True)
4242

4343
def create_output_pipeline(self) -> OutputPipelineProcessor:
4444
output_steps: List[OutputPipelineStep] = [

src/codegate/pipeline/output.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ async def process_stream(
115115
async for chunk in stream:
116116
# Store chunk content in buffer
117117
self._buffer_chunk(chunk)
118+
self._input_context.add_output(chunk)
118119

119120
# Process chunk through each step of the pipeline
120121
current_chunks = [chunk]

src/codegate/providers/base.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ async def _run_input_pipeline(
109109
api_key: Optional[str],
110110
api_base: Optional[str],
111111
is_fim_request: bool,
112-
prompt_id: str,
112+
# prompt_id: str,
113113
) -> PipelineResult:
114114
# Decide which pipeline processor to use
115115
if is_fim_request:
@@ -125,7 +125,7 @@ async def _run_input_pipeline(
125125
result = await pipeline_processor.process_request(
126126
request=normalized_request,
127127
provider=self.provider_route_name,
128-
prompt_id=prompt_id,
128+
# prompt_id=prompt_id,
129129
model=normalized_request.get("model"),
130130
api_key=api_key,
131131
api_base=api_base,
@@ -194,10 +194,11 @@ async def _cleanup_after_streaming(
194194
async for item in stream:
195195
yield item
196196
finally:
197-
# Ensure sensitive data is cleaned up after the stream is consumed
198-
if context and context.sensitive:
199-
context.sensitive.secure_cleanup()
200-
await self._db_recorder.record_alerts(context.alerts_raised)
197+
if context:
198+
await self._db_recorder.record_context(context)
199+
# Ensure sensitive data is cleaned up after the stream is consumed
200+
if context.sensitive:
201+
context.sensitive.secure_cleanup()
201202

202203
async def complete(
203204
self, data: Dict, api_key: Optional[str], is_fim_request: bool
@@ -215,22 +216,22 @@ async def complete(
215216
"""
216217
normalized_request = self._input_normalizer.normalize(data)
217218
streaming = normalized_request.get("stream", False)
218-
prompt_db = await self._db_recorder.record_request(
219-
normalized_request, is_fim_request, self.provider_route_name
220-
)
219+
# prompt_db = await self._db_recorder.record_request(
220+
# normalized_request, is_fim_request, self.provider_route_name
221+
# )
221222

222-
prompt_db_id = prompt_db.id if prompt_db is not None else None
223+
# prompt_db_id = prompt_db.id if prompt_db is not None else None
223224
input_pipeline_result = await self._run_input_pipeline(
224225
normalized_request,
225226
api_key,
226227
data.get("base_url"),
227228
is_fim_request,
228-
prompt_id=prompt_db_id,
229+
# prompt_id=prompt_db_id,
229230
)
230231
if input_pipeline_result.response:
231-
await self._db_recorder.record_alerts(input_pipeline_result.context.alerts_raised)
232+
# await self._db_recorder.record_alerts(input_pipeline_result.context.alerts_raised)
232233
return await self._pipeline_response_formatter.handle_pipeline_response(
233-
input_pipeline_result.response, streaming, prompt_db=prompt_db
234+
input_pipeline_result.response, streaming, context=input_pipeline_result.context
234235
)
235236

236237
provider_request = self._input_normalizer.denormalize(input_pipeline_result.request)
@@ -247,17 +248,18 @@ async def complete(
247248
normalized_response = self._output_normalizer.normalize(model_response)
248249
pipeline_output = self._run_output_pipeline(normalized_response)
249250
# Record the output and alerts in the database can be done in parallel
250-
async with asyncio.TaskGroup() as tg:
251-
tg.create_task(
252-
self._db_recorder.record_output_non_stream(prompt_db, model_response)
253-
)
254-
if input_pipeline_result and input_pipeline_result.context:
255-
tg.create_task(
256-
self._db_recorder.record_alerts(input_pipeline_result.context.alerts_raised)
257-
)
251+
# async with asyncio.TaskGroup() as tg:
252+
# tg.create_task(
253+
# self._db_recorder.record_output_non_stream(prompt_db, model_response)
254+
# )
255+
# if input_pipeline_result and input_pipeline_result.context:
256+
# tg.create_task(
257+
# self._db_recorder.record_alerts(input_pipeline_result.context.alerts_raised)
258+
# )
259+
await self._db_recorder.record_context(input_pipeline_result.context)
258260
return self._output_normalizer.denormalize(pipeline_output)
259261

260-
model_response = self._db_recorder.record_output_stream(prompt_db, model_response)
262+
# model_response = self._db_recorder.record_output_stream(prompt_db, model_response)
261263
pipeline_output_stream = await self._run_output_stream_pipeline(
262264
input_pipeline_result.context, model_response, is_fim_request=is_fim_request
263265
)

0 commit comments

Comments (0)