Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 4a06bac

Browse files
Mocked db_recorder object in tests
1 parent 6a7c00a commit 4a06bac

File tree

2 files changed

+22
-8
lines changed

2 files changed

+22
-8
lines changed

src/codegate/pipeline/output.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def __init__(
7676
self,
7777
pipeline_steps: list[OutputPipelineStep],
7878
input_context: Optional[PipelineContext] = None,
79+
db_recorder: Optional[DbRecorder] = None,
7980
):
8081
self._input_context = input_context
8182
self._pipeline_steps = pipeline_steps
@@ -84,7 +85,10 @@ def __init__(
8485
# the remaining content in the buffer when the stream ends, we need
8586
# to store the parameters like model, timestamp, etc.
8687
self._buffered_chunk = None
87-
self._db_recorder = DbRecorder()
88+
if not db_recorder:
89+
self._db_recorder = DbRecorder()
90+
else:
91+
self._db_recorder = db_recorder
8892

8993
def _buffer_chunk(self, chunk: ModelResponse) -> None:
9094
"""

tests/pipeline/test_output.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import List
2+
from unittest.mock import MagicMock
23

34
import pytest
45
from litellm import ModelResponse
@@ -94,7 +95,8 @@ async def test_single_step_processing(self):
9495
"""Test processing a stream through a single step"""
9596
step = MockOutputPipelineStep("test_step", modify_content=True)
9697
context = MockContext()
97-
instance = OutputPipelineInstance([step], context)
98+
db_recorder = MagicMock()
99+
instance = OutputPipelineInstance([step], context, db_recorder)
98100

99101
async def mock_stream():
100102
yield create_model_response("Hello")
@@ -118,7 +120,8 @@ async def test_multiple_steps_processing(self):
118120
MockOutputPipelineStep("step2", modify_content=True),
119121
]
120122
context = MockContext()
121-
instance = OutputPipelineInstance(steps, context)
123+
db_recorder = MagicMock()
124+
instance = OutputPipelineInstance(steps, context, db_recorder)
122125

123126
async def mock_stream():
124127
yield create_model_response("Hello")
@@ -197,7 +200,8 @@ async def process_chunk(
197200
return []
198201

199202
context = MockContext()
200-
instance = OutputPipelineInstance([ReplacementStep()], context)
203+
db_recorder = MagicMock()
204+
instance = OutputPipelineInstance([ReplacementStep()], context, db_recorder)
201205

202206
async def mock_stream():
203207
yield create_model_response("he")
@@ -221,7 +225,8 @@ async def test_buffer_processing(self):
221225
"""Test that content is properly buffered and cleared"""
222226
step = MockOutputPipelineStep("test_step")
223227
context = MockContext()
224-
instance = OutputPipelineInstance([step], context)
228+
db_recorder = MagicMock()
229+
instance = OutputPipelineInstance([step], context, db_recorder)
225230

226231
async def mock_stream():
227232
yield create_model_response("Hello")
@@ -242,7 +247,8 @@ async def test_empty_stream(self):
242247
"""Test handling of an empty stream"""
243248
step = MockOutputPipelineStep("test_step")
244249
context = MockContext()
245-
instance = OutputPipelineInstance([step], context)
250+
db_recorder = MagicMock()
251+
instance = OutputPipelineInstance([step], context, db_recorder)
246252

247253
async def mock_stream():
248254
if False:
@@ -275,7 +281,10 @@ async def process_chunk(
275281
assert input_context.metadata["test"] == "value"
276282
return [chunk]
277283

278-
instance = OutputPipelineInstance([ContextCheckingStep()], input_context=input_context)
284+
db_recorder = MagicMock()
285+
instance = OutputPipelineInstance(
286+
[ContextCheckingStep()], input_context=input_context, db_recorder=db_recorder
287+
)
279288

280289
async def mock_stream():
281290
yield create_model_response("test")
@@ -288,7 +297,8 @@ async def test_buffer_flush_on_stream_end(self):
288297
"""Test that buffer is properly flushed when stream ends"""
289298
step = MockOutputPipelineStep("test_step", should_pause=True)
290299
context = MockContext()
291-
instance = OutputPipelineInstance([step], context)
300+
db_recorder = MagicMock()
301+
instance = OutputPipelineInstance([step], context, db_recorder)
292302

293303
async def mock_stream():
294304
yield create_model_response("Hello")

0 commit comments

Comments
 (0)