From 3f3ce71ce37c35547aa3c7e97711495336304b7e Mon Sep 17 00:00:00 2001 From: Alejandro Ponce Date: Mon, 17 Feb 2025 17:51:41 +0200 Subject: [PATCH] Record output from non-streaming responses in DB Closes: #938 We were not recording the output of non-streamed responses in the DB. Our API filters out all conversations that don't contain an answer. Because non-streamed responses were never recorded, any non-streamed conversation was filtered out entirely. Note: This still doesn't add an output pipeline for non-streamed responses. --- src/codegate/providers/base.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/src/codegate/providers/base.py b/src/codegate/providers/base.py index 3edced69..d22afcc0 100644 --- a/src/codegate/providers/base.py +++ b/src/codegate/providers/base.py @@ -133,12 +133,23 @@ async def _run_output_stream_pipeline( denormalized_stream = self._output_normalizer.denormalize_streaming(pipeline_output_stream) return denormalized_stream - def _run_output_pipeline( + async def _run_output_pipeline( self, - normalized_response: ModelResponse, + input_context: PipelineContext, + model_response: Any, ) -> ModelResponse: - # we don't have a pipeline for non-streamed output yet - return normalized_response + """ + Run the output pipeline for a single response. + + For the moment we don't have a pipeline for non-streamed output, so we + just normalize the response and record the context. It is done here to match + the behaviour of the streaming pipeline. 
+ """ + normalized_response = self._output_normalizer.normalize(model_response) + input_context.add_output(normalized_response) + await self._db_recorder.record_context(input_context) + output_result = self._output_normalizer.denormalize(normalized_response) + return output_result async def _run_input_pipeline( self, @@ -263,10 +274,7 @@ async def complete( is_fim_request=is_fim_request, ) if not streaming: - normalized_response = self._output_normalizer.normalize(model_response) - pipeline_output = self._run_output_pipeline(normalized_response) - await self._db_recorder.record_context(input_pipeline_result.context) - return self._output_normalizer.denormalize(pipeline_output) + return await self._run_output_pipeline(input_pipeline_result.context, model_response) pipeline_output_stream = await self._run_output_stream_pipeline( input_pipeline_result.context, model_response, is_fim_request=is_fim_request # type: ignore