From 641438bfa52482bfa4d954835dae1bc76aafb4cf Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Wed, 27 Nov 2024 13:26:15 +0100 Subject: [PATCH 1/2] Add normalizer instead of abusing LiteLLM adapter There were two problems with how our classes were structured (there are more, but these are the two I attempted to solve with this PR): 1) The adapter was supposed to be used in the litellm-based providers to do the translation using litellm's adapters, but we stuffed way too much logic into it and started leaking that logic to other providers. 2) While litellm uses the OpenAI format for input and output, other providers (llama.cpp and the soon-to-be-hosted vllm) don't. We need a way to canonicalize their requests and replies to the OpenAI format. This PR adds a new module called normalizer that takes over some of the work from the adapter and is only responsible for converting requests and replies to the OpenAI format. This way our pipelines always work on the OpenAI format internally, both the current input pipeline and the output pipeline. The completion handler now really only does completion (previously it was a confusing class that did several things) and the adapter is now better hidden in the litellmshim module. To ship the PR faster, there are only two normalizers: OpenAI, which just passes the data through, and Anthropic, which uses the LiteLLM adapter. Next, we'll add a llama.cpp normalizer to get rid of the `` tags and convert them into a properly formatted OpenAI message. --- src/codegate/__init__.py | 4 +- src/codegate/cli.py | 2 +- src/codegate/codegate_logging.py | 2 +- src/codegate/config.py | 2 +- src/codegate/providers/anthropic/adapter.py | 47 ++++------ src/codegate/providers/anthropic/provider.py | 14 +-- src/codegate/providers/base.py | 57 ++++++++---- src/codegate/providers/completion/__init__.py | 5 ++ src/codegate/providers/completion/base.py | 33 ++----- src/codegate/providers/formatting/__init__.py | 5 ++ .../providers/formatting/input_pipeline.py | 10 +-- .../providers/litellmshim/__init__.py | 4 +- src/codegate/providers/litellmshim/adapter.py | 63 ++++++++++++- .../providers/litellmshim/generators.py | 22 +---- .../providers/litellmshim/litellmshim.py | 46 ++-------- src/codegate/providers/llamacpp/adapter.py | 32 ------- .../providers/llamacpp/completion_handler.py | 59 ++++++------ src/codegate/providers/llamacpp/provider.py | 14 +-- src/codegate/providers/normalizer/__init__.py | 6 ++ src/codegate/providers/normalizer/base.py | 58 ++++++++++++ src/codegate/providers/openai/adapter.py | 89 ++++++++++++------- src/codegate/providers/openai/provider.py | 14 +-- src/codegate/server.py | 2 +- tests/providers/anthropic/test_adapter.py | 40 +++------ .../providers/litellmshim/test_litellmshim.py | 36 ++++---- tests/providers/test_registry.py | 61 +++++++++++-- tests/test_logging.py | 8 +- 27 files changed, 415 insertions(+), 320 deletions(-) delete mode 100644 src/codegate/providers/llamacpp/adapter.py create mode 100644 src/codegate/providers/normalizer/__init__.py create mode 100644 src/codegate/providers/normalizer/base.py diff --git a/src/codegate/__init__.py b/src/codegate/__init__.py index 15535fbb..042fba14 100644 --- a/src/codegate/__init__.py +++ b/src/codegate/__init__.py @@ -1,10 +1,10 @@ """Codegate - A Generative AI security gateway.""" -from importlib import metadata import logging as python_logging +from importlib import metadata +from codegate.codegate_logging import LogFormat, LogLevel, setup_logging from codegate.config import Config -from 
codegate.codegate_logging import setup_logging, LogFormat, LogLevel from codegate.exceptions import ConfigurationError try: diff --git a/src/codegate/cli.py b/src/codegate/cli.py index acd28530..8688947b 100644 --- a/src/codegate/cli.py +++ b/src/codegate/cli.py @@ -6,8 +6,8 @@ import click +from codegate.codegate_logging import LogFormat, LogLevel, setup_logging from codegate.config import Config, ConfigurationError -from codegate.codegate_logging import setup_logging, LogFormat, LogLevel from codegate.server import init_app diff --git a/src/codegate/codegate_logging.py b/src/codegate/codegate_logging.py index 9656eadc..a57a1579 100644 --- a/src/codegate/codegate_logging.py +++ b/src/codegate/codegate_logging.py @@ -1,8 +1,8 @@ import datetime -from enum import Enum import json import logging import sys +from enum import Enum from typing import Any, Optional diff --git a/src/codegate/config.py b/src/codegate/config.py index 3d39134c..5dda8b47 100644 --- a/src/codegate/config.py +++ b/src/codegate/config.py @@ -7,7 +7,7 @@ import yaml -from codegate.codegate_logging import setup_logging, LogFormat, LogLevel +from codegate.codegate_logging import LogFormat, LogLevel, setup_logging from codegate.exceptions import ConfigurationError from codegate.prompts import PromptConfig diff --git a/src/codegate/providers/anthropic/adapter.py b/src/codegate/providers/anthropic/adapter.py index ee9221d4..01149e65 100644 --- a/src/codegate/providers/anthropic/adapter.py +++ b/src/codegate/providers/anthropic/adapter.py @@ -1,46 +1,29 @@ -from typing import Any, Dict, Optional - -from litellm import AdapterCompletionStreamWrapper, ChatCompletionRequest, ModelResponse from litellm.adapters.anthropic_adapter import ( AnthropicAdapter as LitellmAnthropicAdapter, ) -from litellm.types.llms.anthropic import AnthropicResponse -from codegate.providers.base import StreamGenerator -from codegate.providers.litellmshim import anthropic_stream_generator, BaseAdapter +from codegate.providers.litellmshim.adapter import ( + LiteLLMAdapterInputNormalizer, + LiteLLMAdapterOutputNormalizer, +) -class AnthropicAdapter(BaseAdapter): +class AnthropicInputNormalizer(LiteLLMAdapterInputNormalizer): """ LiteLLM's adapter class interface is used to translate between the Anthropic data format and the underlying model. The AnthropicAdapter class contains the actual implementation of the interface methods, we just forward the calls to it. """ - def __init__(self, stream_generator: StreamGenerator = anthropic_stream_generator): - self.litellm_anthropic_adapter = LitellmAnthropicAdapter() - super().__init__(stream_generator) + def __init__(self): + super().__init__(LitellmAnthropicAdapter()) - def translate_completion_input_params( - self, - completion_request: Dict, - ) -> Optional[ChatCompletionRequest]: - return self.litellm_anthropic_adapter.translate_completion_input_params( - completion_request - ) - - def translate_completion_output_params( - self, response: ModelResponse - ) -> Optional[AnthropicResponse]: - return self.litellm_anthropic_adapter.translate_completion_output_params( - response - ) +class AnthropicOutputNormalizer(LiteLLMAdapterOutputNormalizer): + """ + LiteLLM's adapter class interface is used to translate between the Anthropic data + format and the underlying model. The AnthropicAdapter class contains the actual + implementation of the interface methods, we just forward the calls to it. 
+ """ - def translate_completion_output_params_streaming( - self, completion_stream: Any - ) -> AdapterCompletionStreamWrapper | None: - return ( - self.litellm_anthropic_adapter.translate_completion_output_params_streaming( - completion_stream - ) - ) + def __init__(self): + super().__init__(LitellmAnthropicAdapter()) diff --git a/src/codegate/providers/anthropic/provider.py b/src/codegate/providers/anthropic/provider.py index 1b39ee07..a16c5921 100644 --- a/src/codegate/providers/anthropic/provider.py +++ b/src/codegate/providers/anthropic/provider.py @@ -2,16 +2,20 @@ from fastapi import Header, HTTPException, Request +from codegate.providers.anthropic.adapter import AnthropicInputNormalizer, AnthropicOutputNormalizer from codegate.providers.base import BaseProvider -from codegate.providers.litellmshim import LiteLLmShim -from codegate.providers.anthropic.adapter import AnthropicAdapter +from codegate.providers.litellmshim import LiteLLmShim, anthropic_stream_generator class AnthropicProvider(BaseProvider): def __init__(self, pipeline_processor=None): - adapter = AnthropicAdapter() - completion_handler = LiteLLmShim(adapter) - super().__init__(completion_handler, pipeline_processor) + completion_handler = LiteLLmShim(stream_generator=anthropic_stream_generator) + super().__init__( + AnthropicInputNormalizer(), + AnthropicOutputNormalizer(), + completion_handler, + pipeline_processor, + ) @property def provider_route_name(self) -> str: diff --git a/src/codegate/providers/base.py b/src/codegate/providers/base.py index 3fe29cc8..940d8ea8 100644 --- a/src/codegate/providers/base.py +++ b/src/codegate/providers/base.py @@ -3,11 +3,12 @@ from fastapi import APIRouter from litellm import ModelResponse +from litellm.types.llms.openai import ChatCompletionRequest +from codegate.pipeline.base import PipelineResult, SequentialPipelineProcessor from codegate.providers.completion.base import BaseCompletionHandler from codegate.providers.formatting.input_pipeline import PipelineResponseFormatter - -from ..pipeline.base import SequentialPipelineProcessor +from codegate.providers.normalizer.base import ModelInputNormalizer, ModelOutputNormalizer StreamGenerator = Callable[[AsyncIterator[Any]], AsyncIterator[str]] @@ -19,14 +20,20 @@ class BaseProvider(ABC): def __init__( self, + input_normalizer: ModelInputNormalizer, + output_normalizer: ModelOutputNormalizer, completion_handler: BaseCompletionHandler, pipeline_processor: Optional[SequentialPipelineProcessor] = None ): self.router = APIRouter() self._completion_handler = completion_handler + self._input_normalizer = input_normalizer + self._output_normalizer = output_normalizer self._pipeline_processor = pipeline_processor + self._pipeline_response_formatter = \ - PipelineResponseFormatter(completion_handler) + PipelineResponseFormatter(output_normalizer) + self._setup_routes() @abstractmethod @@ -38,8 +45,23 @@ def _setup_routes(self) -> None: def provider_route_name(self) -> str: pass + async def _run_input_pipeline( + self, + normalized_request: ChatCompletionRequest, + ) -> PipelineResult: + if self._pipeline_processor is None: + return PipelineResult(request=normalized_request) + + result = await self._pipeline_processor.process_request(normalized_request) + + # TODO(jakub): handle this by returning a message to the client + if result.error_message: + raise Exception(result.error_message) + + return result + async def complete( - self, data: Dict, api_key: str, + self, data: Dict, api_key: Optional[str], ) -> Union[ModelResponse, 
AsyncIterator[ModelResponse]]: """ Main completion flow with pipeline integration @@ -52,31 +74,28 @@ async def complete( - Execute the completion and translate the response back to the provider-specific format """ - completion_request = self._completion_handler.translate_request(data, api_key) + normalized_request = self._input_normalizer.normalize(data) streaming = data.get("stream", False) - if self._pipeline_processor is not None: - result = await self._pipeline_processor.process_request(completion_request) - - if result.error_message: - raise Exception(result.error_message) + input_pipeline_result = await self._run_input_pipeline(normalized_request) + if input_pipeline_result.response: + return self._pipeline_response_formatter.handle_pipeline_response( + input_pipeline_result.response, streaming) - if result.response: - return self._pipeline_response_formatter.handle_pipeline_response( - result.response, streaming) - - completion_request = result.request + provider_request = self._input_normalizer.denormalize(input_pipeline_result.request) # Execute the completion and translate the response # This gives us either a single response or a stream of responses # based on the streaming flag - raw_response = await self._completion_handler.execute_completion( - completion_request, + model_response = await self._completion_handler.execute_completion( + provider_request, + api_key=api_key, stream=streaming ) + if not streaming: - return self._completion_handler.translate_response(raw_response) - return self._completion_handler.translate_streaming_response(raw_response) + return self._output_normalizer.denormalize(model_response) + return self._output_normalizer.denormalize_streaming(model_response) def get_routes(self) -> APIRouter: return self.router diff --git a/src/codegate/providers/completion/__init__.py b/src/codegate/providers/completion/__init__.py index e69de29b..80a0fefd 100644 --- a/src/codegate/providers/completion/__init__.py +++ b/src/codegate/providers/completion/__init__.py @@ -0,0 +1,5 @@ +from codegate.providers.completion.base import BaseCompletionHandler + +__all__ = [ + "BaseCompletionHandler", +] diff --git a/src/codegate/providers/completion/base.py b/src/codegate/providers/completion/base.py index 906e20d4..be0d754f 100644 --- a/src/codegate/providers/completion/base.py +++ b/src/codegate/providers/completion/base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, AsyncIterator, Dict, Union +from typing import Any, AsyncIterator, Optional, Union from fastapi.responses import StreamingResponse from litellm import ChatCompletionRequest, ModelResponse @@ -11,39 +11,18 @@ class BaseCompletionHandler(ABC): and creating the streaming response. """ - @abstractmethod - def translate_request(self, data: Dict, api_key: str) -> ChatCompletionRequest: - """Convert raw request data into a ChatCompletionRequest""" - pass - @abstractmethod async def execute_completion( - self, - request: ChatCompletionRequest, - stream: bool = False + self, + request: ChatCompletionRequest, + api_key: Optional[str], + stream: bool = False, # TODO: remove this param? 
) -> Union[ModelResponse, AsyncIterator[ModelResponse]]: """Execute the completion request""" pass @abstractmethod def create_streaming_response( - self, stream: AsyncIterator[Any] + self, stream: AsyncIterator[Any] ) -> StreamingResponse: pass - - @abstractmethod - def translate_response( - self, - response: ModelResponse, - ) -> ModelResponse: - """Convert pipeline response to provider-specific format""" - pass - - @abstractmethod - def translate_streaming_response( - self, - response: AsyncIterator[ModelResponse], - ) -> AsyncIterator[ModelResponse]: - """Convert pipeline response to provider-specific format""" - pass - diff --git a/src/codegate/providers/formatting/__init__.py b/src/codegate/providers/formatting/__init__.py index e69de29b..13ba54a4 100644 --- a/src/codegate/providers/formatting/__init__.py +++ b/src/codegate/providers/formatting/__init__.py @@ -0,0 +1,5 @@ +from codegate.providers.formatting.input_pipeline import PipelineResponseFormatter + +__all__ = [ + "PipelineResponseFormatter", +] diff --git a/src/codegate/providers/formatting/input_pipeline.py b/src/codegate/providers/formatting/input_pipeline.py index 6cf54a8d..83b037d0 100644 --- a/src/codegate/providers/formatting/input_pipeline.py +++ b/src/codegate/providers/formatting/input_pipeline.py @@ -5,7 +5,7 @@ from litellm.types.utils import Delta, StreamingChoices from codegate.pipeline.base import PipelineResponse -from codegate.providers.completion.base import BaseCompletionHandler +from codegate.providers.normalizer.base import ModelOutputNormalizer def _create_stream_end_response(original_response: ModelResponse) -> ModelResponse: @@ -88,9 +88,9 @@ async def _convert_to_stream( class PipelineResponseFormatter: def __init__(self, - completion_handler: BaseCompletionHandler, + output_normalizer: ModelOutputNormalizer, ): - self._completion_handler = completion_handler + self._output_normalizer = output_normalizer def handle_pipeline_response( self, @@ -114,7 +114,7 @@ def handle_pipeline_response( if not streaming: # If we're not streaming, we just return the response translated # to the provider-specific format - return self._completion_handler.translate_response(model_response) + return self._output_normalizer.denormalize(model_response) # If we're streaming, we need to convert the response to a stream first # then feed the stream into the completion handler's conversion method @@ -123,7 +123,7 @@ def handle_pipeline_response( pipeline_response.step_name, pipeline_response.model ) - return self._completion_handler.translate_streaming_response( + return self._output_normalizer.denormalize_streaming( model_response_stream ) diff --git a/src/codegate/providers/litellmshim/__init__.py b/src/codegate/providers/litellmshim/__init__.py index ab470e3c..b2561059 100644 --- a/src/codegate/providers/litellmshim/__init__.py +++ b/src/codegate/providers/litellmshim/__init__.py @@ -1,13 +1,13 @@ from codegate.providers.litellmshim.adapter import BaseAdapter from codegate.providers.litellmshim.generators import ( - anthropic_stream_generator, sse_stream_generator, llamacpp_stream_generator + anthropic_stream_generator, + sse_stream_generator, ) from codegate.providers.litellmshim.litellmshim import LiteLLmShim __all__ = [ "sse_stream_generator", "anthropic_stream_generator", - "llamacpp_stream_generator", "LiteLLmShim", "BaseAdapter", ] diff --git a/src/codegate/providers/litellmshim/adapter.py b/src/codegate/providers/litellmshim/adapter.py index b1c349f0..0c11d2c3 100644 --- a/src/codegate/providers/litellmshim/adapter.py 
+++ b/src/codegate/providers/litellmshim/adapter.py @@ -1,9 +1,10 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, Optional +from typing import Any, AsyncIterable, AsyncIterator, Dict, Iterable, Iterator, Optional, Union from litellm import ChatCompletionRequest, ModelResponse from codegate.providers.base import StreamGenerator +from codegate.providers.normalizer.base import ModelInputNormalizer, ModelOutputNormalizer class BaseAdapter(ABC): @@ -23,7 +24,7 @@ def __init__(self, stream_generator: StreamGenerator): @abstractmethod def translate_completion_input_params( - self, kwargs: Dict + self, kwargs: Dict ) -> Optional[ChatCompletionRequest]: """Convert input parameters to LiteLLM's ChatCompletionRequest format""" pass @@ -35,10 +36,66 @@ def translate_completion_output_params(self, response: ModelResponse) -> Any: @abstractmethod def translate_completion_output_params_streaming( - self, completion_stream: Any + self, completion_stream: Any ) -> Any: """ Convert streaming response from LiteLLM format to a format that can be passed to a stream generator and to the client. """ pass + +class LiteLLMAdapterInputNormalizer(ModelInputNormalizer): + def __init__(self, adapter: BaseAdapter): + self._adapter = adapter + + def normalize(self, data: Dict) -> ChatCompletionRequest: + """ + Uses an LiteLLM adapter to translate the request data from the native + LLM format to the OpenAI API format used by LiteLLM internally. + """ + return self._adapter.translate_completion_input_params(data) + + def denormalize(self, data: ChatCompletionRequest) -> Dict: + """ + For LiteLLM, we don't have to de-normalize as the input format is + always ChatCompletionRequest which is a TypedDict which is a Dict + """ + return data + +class LiteLLMAdapterOutputNormalizer(ModelOutputNormalizer): + def __init__(self, adapter: BaseAdapter): + self._adapter = adapter + + def normalize_streaming( + self, + model_reply: Union[AsyncIterable[Any], Iterable[Any]], + ) -> Union[AsyncIterator[ModelResponse], Iterator[ModelResponse]]: + """ + Normalize the output stream. This is a pass-through for liteLLM output normalizer + as the liteLLM output is already in the normalized format. + """ + return model_reply + + def normalize(self, model_reply: Any) -> ModelResponse: + """ + Normalize the output data. This is a pass-through for liteLLM output normalizer + as the liteLLM output is already in the normalized format. 
+ """ + return model_reply + + def denormalize(self, normalized_reply: ModelResponse) -> Any: + """ + Denormalize the output data from the completion function to the format + expected by the client + """ + return self._adapter.translate_completion_output_params(normalized_reply) + + def denormalize_streaming( + self, + normalized_reply: Union[AsyncIterable[ModelResponse], Iterable[ModelResponse]], + ) -> Union[AsyncIterator[Any], Iterator[Any]]: + """ + Denormalize the output stream from the completion function to the format + expected by the client + """ + return self._adapter.translate_completion_output_params_streaming(normalized_reply) diff --git a/src/codegate/providers/litellmshim/generators.py b/src/codegate/providers/litellmshim/generators.py index c9ad8fc8..636203cc 100644 --- a/src/codegate/providers/litellmshim/generators.py +++ b/src/codegate/providers/litellmshim/generators.py @@ -1,6 +1,5 @@ import json -from typing import Any, AsyncIterator, Iterator -import asyncio +from typing import Any, AsyncIterator from pydantic import BaseModel @@ -37,21 +36,4 @@ async def anthropic_stream_generator(stream: AsyncIterator[Any]) -> AsyncIterato except Exception as e: yield f"event: {event_type}\ndata:{str(e)}\n\n" except Exception as e: - yield f"data: {str(e)}\n\n" - - -async def llamacpp_stream_generator(stream: Iterator[Any]) -> AsyncIterator[str]: - """OpenAI-style SSE format""" - try: - for chunk in stream: - if hasattr(chunk, "model_dump_json"): - chunk = chunk.model_dump_json(exclude_none=True, exclude_unset=True) - try: - yield f"data:{json.dumps(chunk)}\n\n" - await asyncio.sleep(0) - except Exception as e: - yield f"data:{str(e)}\n\n" - except Exception as e: - yield f"data: {str(e)}\n\n" - finally: - yield "data: [DONE]\n\n" + yield f"data: {str(e)}\n\n" \ No newline at end of file diff --git a/src/codegate/providers/litellmshim/litellmshim.py b/src/codegate/providers/litellmshim/litellmshim.py index 53f08443..5f9820e1 100644 --- a/src/codegate/providers/litellmshim/litellmshim.py +++ b/src/codegate/providers/litellmshim/litellmshim.py @@ -1,10 +1,9 @@ -from typing import Any, AsyncIterator, Dict, Union +from typing import Any, AsyncIterator, Optional, Union from fastapi.responses import StreamingResponse from litellm import ChatCompletionRequest, ModelResponse, acompletion -from codegate.providers.base import BaseCompletionHandler -from codegate.providers.litellmshim.adapter import BaseAdapter +from codegate.providers.base import BaseCompletionHandler, StreamGenerator class LiteLLmShim(BaseCompletionHandler): @@ -14,49 +13,20 @@ class LiteLLmShim(BaseCompletionHandler): LiteLLM API. """ - def __init__(self, adapter: BaseAdapter, completion_func=acompletion): - self._adapter = adapter + def __init__(self, stream_generator: StreamGenerator, completion_func=acompletion): + self._stream_generator = stream_generator self._completion_func = completion_func - def translate_request(self, data: Dict, api_key: str) -> ChatCompletionRequest: - """ - Uses the configured adapter to translate the request data from the native - LLM API format to the OpenAI API format used by LiteLLM internally. - - The OpenAPI format is also what our pipeline expects. 
- """ - data["api_key"] = api_key - completion_request = self._adapter.translate_completion_input_params(data) - if completion_request is None: - raise Exception("Couldn't translate the request") - return completion_request - - def translate_streaming_response( - self, - response: AsyncIterator[ModelResponse], - ) -> AsyncIterator[ModelResponse]: - """ - Convert pipeline or completion response to provider-specific stream - """ - return self._adapter.translate_completion_output_params_streaming(response) - - def translate_response( - self, - response: ModelResponse, - ) -> ModelResponse: - """ - Convert pipeline or completion response to provider-specific format - """ - return self._adapter.translate_completion_output_params(response) - async def execute_completion( self, request: ChatCompletionRequest, - stream: bool = False + api_key: Optional[str], + stream: bool = False, ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]: """ Execute the completion request with LiteLLM's API """ + request["api_key"] = api_key return await self._completion_func(**request) def create_streaming_response( @@ -67,7 +37,7 @@ def create_streaming_response( is the format that FastAPI expects for streaming responses. """ return StreamingResponse( - self._adapter.stream_generator(stream), + self._stream_generator(stream), headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", diff --git a/src/codegate/providers/llamacpp/adapter.py b/src/codegate/providers/llamacpp/adapter.py deleted file mode 100644 index b6f9d394..00000000 --- a/src/codegate/providers/llamacpp/adapter.py +++ /dev/null @@ -1,32 +0,0 @@ -from typing import Any, AsyncIterator, Dict, Optional - -from litellm import ChatCompletionRequest, ModelResponse - -from codegate.providers.base import StreamGenerator -from codegate.providers.litellmshim import llamacpp_stream_generator, BaseAdapter - - -class LlamaCppAdapter(BaseAdapter): - """ - This is just a wrapper around LiteLLM's adapter class interface that passes - through the input and output as-is - LiteLLM's API expects OpenAI's API - format. 
- """ - def __init__(self, stream_generator: StreamGenerator = llamacpp_stream_generator): - super().__init__(stream_generator) - - def translate_completion_input_params( - self, kwargs: Dict - ) -> Optional[ChatCompletionRequest]: - try: - return ChatCompletionRequest(**kwargs) - except Exception as e: - raise ValueError(f"Invalid completion parameters: {str(e)}") - - def translate_completion_output_params(self, response: ModelResponse) -> Any: - return response - - def translate_completion_output_params_streaming( - self, completion_stream: AsyncIterator[ModelResponse] - ) -> AsyncIterator[ModelResponse]: - return completion_stream diff --git a/src/codegate/providers/llamacpp/completion_handler.py b/src/codegate/providers/llamacpp/completion_handler.py index 822947eb..d7553853 100644 --- a/src/codegate/providers/llamacpp/completion_handler.py +++ b/src/codegate/providers/llamacpp/completion_handler.py @@ -1,48 +1,39 @@ -from typing import Any, AsyncIterator, Dict, Union +import json +import asyncio +from typing import Any, AsyncIterator, Iterator, Optional, Union from fastapi.responses import StreamingResponse from litellm import ChatCompletionRequest, ModelResponse -from codegate.providers.base import BaseCompletionHandler -from codegate.providers.llamacpp.adapter import BaseAdapter -from codegate.inference.inference_engine import LlamaCppInferenceEngine from codegate.config import Config +from codegate.inference.inference_engine import LlamaCppInferenceEngine +from codegate.providers.base import BaseCompletionHandler +async def llamacpp_stream_generator(stream: Iterator[Any]) -> AsyncIterator[str]: + """OpenAI-style SSE format""" + try: + for chunk in stream: + if hasattr(chunk, "model_dump_json"): + chunk = chunk.model_dump_json(exclude_none=True, exclude_unset=True) + try: + yield f"data:{json.dumps(chunk)}\n\n" + await asyncio.sleep(0) + except Exception as e: + yield f"data:{str(e)}\n\n" + except Exception as e: + yield f"data: {str(e)}\n\n" + finally: + yield "data: [DONE]\n\n" + class LlamaCppCompletionHandler(BaseCompletionHandler): - def __init__(self, adapter: BaseAdapter): - self._adapter = adapter + def __init__(self): self.inference_engine = LlamaCppInferenceEngine() - def translate_request(self, data: Dict, api_key: str) -> ChatCompletionRequest: - completion_request = self._adapter.translate_completion_input_params( - data) - if completion_request is None: - raise Exception("Couldn't translate the request") - - return ChatCompletionRequest(**completion_request) - - def translate_streaming_response( - self, - response: AsyncIterator[ModelResponse], - ) -> AsyncIterator[ModelResponse]: - """ - Convert pipeline or completion response to provider-specific stream - """ - return self._adapter.translate_completion_output_params_streaming(response) - - def translate_response( - self, - response: ModelResponse, - ) -> ModelResponse: - """ - Convert pipeline or completion response to provider-specific format - """ - return self._adapter.translate_completion_output_params(response) - async def execute_completion( self, request: ChatCompletionRequest, + api_key: Optional[str], stream: bool = False ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]: """ @@ -63,14 +54,14 @@ async def execute_completion( return response def create_streaming_response( - self, stream: AsyncIterator[Any] + self, stream: Iterator[Any] ) -> StreamingResponse: """ Create a streaming response from a stream generator. The StreamingResponse is the format that FastAPI expects for streaming responses. 
""" return StreamingResponse( - self._adapter.stream_generator(stream), + llamacpp_stream_generator(stream), headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", diff --git a/src/codegate/providers/llamacpp/provider.py b/src/codegate/providers/llamacpp/provider.py index a3227085..26291cdc 100644 --- a/src/codegate/providers/llamacpp/provider.py +++ b/src/codegate/providers/llamacpp/provider.py @@ -4,14 +4,18 @@ from codegate.providers.base import BaseProvider from codegate.providers.llamacpp.completion_handler import LlamaCppCompletionHandler -from codegate.providers.llamacpp.adapter import LlamaCppAdapter +from codegate.providers.llamacpp.normalizer import LLamaCppInputNormalizer, LLamaCppOutputNormalizer class LlamaCppProvider(BaseProvider): def __init__(self, pipeline_processor=None): - adapter = LlamaCppAdapter() - completion_handler = LlamaCppCompletionHandler(adapter) - super().__init__(completion_handler, pipeline_processor) + completion_handler = LlamaCppCompletionHandler() + super().__init__( + LLamaCppInputNormalizer(), + LLamaCppOutputNormalizer(), + completion_handler, + pipeline_processor, + ) @property def provider_route_name(self) -> str: @@ -30,5 +34,5 @@ async def create_completion( body = await request.body() data = json.loads(body) - stream = await self.complete(data, None) + stream = await self.complete(data, api_key=None) return self._completion_handler.create_streaming_response(stream) diff --git a/src/codegate/providers/normalizer/__init__.py b/src/codegate/providers/normalizer/__init__.py new file mode 100644 index 00000000..6d5ba244 --- /dev/null +++ b/src/codegate/providers/normalizer/__init__.py @@ -0,0 +1,6 @@ +from codegate.providers.normalizer.base import ModelInputNormalizer, ModelOutputNormalizer + +__all__ = [ + "ModelInputNormalizer", + "ModelOutputNormalizer", +] diff --git a/src/codegate/providers/normalizer/base.py b/src/codegate/providers/normalizer/base.py new file mode 100644 index 00000000..625842c9 --- /dev/null +++ b/src/codegate/providers/normalizer/base.py @@ -0,0 +1,58 @@ +from abc import ABC, abstractmethod +from typing import Any, AsyncIterable, AsyncIterator, Dict, Iterable, Iterator, Union + +from litellm import ChatCompletionRequest, ModelResponse + + +class ModelInputNormalizer(ABC): + """ + The normalizer class is responsible for normalizing the input data + before it is passed to the pipeline. It converts the input data (raw request) + to the format expected by the pipeline. + """ + + @abstractmethod + def normalize(self, data: Dict) -> ChatCompletionRequest: + """Normalize the input data""" + pass + + @abstractmethod + def denormalize(self, data: ChatCompletionRequest) -> Dict: + """Denormalize the input data""" + pass + + +class ModelOutputNormalizer(ABC): + """ + The output normalizer class is responsible for normalizing the output data + from a model to the format expected by the output pipeline. + + The normalize methods are not implemented yet - they will be when we get + around to implementing output pipelines. 
+ """ + + @abstractmethod + def normalize_streaming( + self, + model_reply: Union[AsyncIterable[Any], Iterable[Any]], + ) -> Union[AsyncIterator[ModelResponse], Iterator[ModelResponse]]: + """Normalize the output data""" + pass + + @abstractmethod + def normalize(self, model_reply: Any) -> ModelResponse: + """Normalize the output data""" + pass + + @abstractmethod + def denormalize(self, normalized_reply: ModelResponse) -> Any: + """Denormalize the output data""" + pass + + @abstractmethod + def denormalize_streaming( + self, + normalized_reply: Union[AsyncIterable[ModelResponse], Iterable[ModelResponse]], + ) -> Union[AsyncIterator[Any], Iterator[Any]]: + """Denormalize the output data""" + pass diff --git a/src/codegate/providers/openai/adapter.py b/src/codegate/providers/openai/adapter.py index c7f9b6a6..2810dff5 100644 --- a/src/codegate/providers/openai/adapter.py +++ b/src/codegate/providers/openai/adapter.py @@ -1,33 +1,56 @@ -from typing import Any, AsyncIterator, Dict, Optional - -from litellm import ChatCompletionRequest, ModelResponse - -from codegate.providers.base import StreamGenerator -from codegate.providers.litellmshim import sse_stream_generator, BaseAdapter - - -class OpenAIAdapter(BaseAdapter): - """ - This is just a wrapper around LiteLLM's adapter class interface that passes - through the input and output as-is - LiteLLM's API expects OpenAI's API - format. - """ - - def __init__(self, stream_generator: StreamGenerator = sse_stream_generator): - super().__init__(stream_generator) - - def translate_completion_input_params( - self, kwargs: Dict - ) -> Optional[ChatCompletionRequest]: - try: - return ChatCompletionRequest(**kwargs) - except Exception as e: - raise ValueError(f"Invalid completion parameters: {str(e)}") - - def translate_completion_output_params(self, response: ModelResponse) -> Any: - return response - - def translate_completion_output_params_streaming( - self, completion_stream: AsyncIterator[ModelResponse] - ) -> AsyncIterator[ModelResponse]: - return completion_stream +from typing import Any, Dict + +from litellm import ChatCompletionRequest + +from codegate.providers.normalizer.base import ModelInputNormalizer, ModelOutputNormalizer + + +class OpenAIInputNormalizer(ModelInputNormalizer): + def __init__(self): + super().__init__() + + def normalize(self, data: Dict) -> ChatCompletionRequest: + """ + No normalizing needed, already OpenAI format + """ + return ChatCompletionRequest(**data) + + def denormalize(self, data: ChatCompletionRequest) -> Dict: + """ + No denormalizing needed, already OpenAI format + """ + return data + +class OpenAIOutputNormalizer(ModelOutputNormalizer): + def __init__(self): + super().__init__() + + def normalize_streaming( + self, + model_reply: Any, + ) -> Any: + """ + No normalizing needed, already OpenAI format + """ + return model_reply + + def normalize(self, model_reply: Any) -> Any: + """ + No normalizing needed, already OpenAI format + """ + return model_reply + + def denormalize(self, normalized_reply: Any) -> Any: + """ + No denormalizing needed, already OpenAI format + """ + return normalized_reply + + def denormalize_streaming( + self, + normalized_reply: Any, + ) -> Any: + """ + No denormalizing needed, already OpenAI format + """ + return normalized_reply diff --git a/src/codegate/providers/openai/provider.py b/src/codegate/providers/openai/provider.py index 16167c95..cf08e99f 100644 --- a/src/codegate/providers/openai/provider.py +++ b/src/codegate/providers/openai/provider.py @@ -3,15 +3,19 @@ from fastapi 
import Header, HTTPException, Request from codegate.providers.base import BaseProvider -from codegate.providers.litellmshim import LiteLLmShim -from codegate.providers.openai.adapter import OpenAIAdapter +from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator +from codegate.providers.openai.adapter import OpenAIInputNormalizer, OpenAIOutputNormalizer class OpenAIProvider(BaseProvider): def __init__(self, pipeline_processor=None): - adapter = OpenAIAdapter() - completion_handler = LiteLLmShim(adapter) - super().__init__(completion_handler, pipeline_processor) + completion_handler = LiteLLmShim(stream_generator=sse_stream_generator) + super().__init__( + OpenAIInputNormalizer(), + OpenAIOutputNormalizer(), + completion_handler, + pipeline_processor, + ) @property def provider_route_name(self) -> str: diff --git a/src/codegate/server.py b/src/codegate/server.py index 0db158f7..359425a2 100644 --- a/src/codegate/server.py +++ b/src/codegate/server.py @@ -3,7 +3,7 @@ from fastapi import APIRouter, FastAPI from codegate import __description__, __version__ -from codegate.pipeline.base import SequentialPipelineProcessor, PipelineStep +from codegate.pipeline.base import PipelineStep, SequentialPipelineProcessor from codegate.pipeline.version.version import CodegateVersion from codegate.providers.anthropic.provider import AnthropicProvider from codegate.providers.llamacpp.provider import LlamaCppProvider diff --git a/tests/providers/anthropic/test_adapter.py b/tests/providers/anthropic/test_adapter.py index 8eab6667..0493a5f8 100644 --- a/tests/providers/anthropic/test_adapter.py +++ b/tests/providers/anthropic/test_adapter.py @@ -1,4 +1,4 @@ -from typing import AsyncIterator, Dict, List, Union +from typing import List, Union import pytest from litellm import ModelResponse @@ -12,24 +12,24 @@ ) from litellm.types.utils import Delta, StreamingChoices -from codegate.providers.anthropic.adapter import AnthropicAdapter +from codegate.providers.anthropic.adapter import AnthropicInputNormalizer, AnthropicOutputNormalizer @pytest.fixture -def adapter(): - return AnthropicAdapter() +def input_normalizer(): + return AnthropicInputNormalizer() -def test_translate_completion_input_params(adapter): +def test_normalize_anthropic_input(input_normalizer): # Test input data completion_request = { "model": "claude-3-haiku-20240307", + "system": "You are an expert code reviewer", "max_tokens": 1024, "stream": True, "messages": [ { "role": "user", - "system": "You are an expert code reviewer", "content": [{"type": "text", "text": "Review this code"}], } ], @@ -37,6 +37,7 @@ def test_translate_completion_input_params(adapter): expected = { "max_tokens": 1024, "messages": [ + {"content": "You are an expert code reviewer", "role": "system"}, {"content": [{"text": "Review this code", "type": "text"}], "role": "user"} ], "model": "claude-3-haiku-20240307", @@ -44,12 +45,16 @@ def test_translate_completion_input_params(adapter): } # Get translation - result = adapter.translate_completion_input_params(completion_request) + result = input_normalizer.normalize(completion_request) assert result == expected +@pytest.fixture +def output_normalizer(): + return AnthropicOutputNormalizer() + @pytest.mark.asyncio -async def test_translate_completion_output_params_streaming(adapter): +async def test_normalize_anthropic_output_stream(output_normalizer): # Test stream data async def mock_stream(): messages = [ @@ -129,7 +134,7 @@ async def mock_stream(): dict(type="message_stop"), ] - stream = 
adapter.translate_completion_output_params_streaming(mock_stream()) + stream = output_normalizer.denormalize_streaming(mock_stream()) assert isinstance(stream, AnthropicStreamWrapper) # just so that we can zip over the expected chunks @@ -139,20 +144,3 @@ async def mock_stream(): for chunk, expected_chunk in zip(stream_list, expected): assert chunk == expected_chunk - - -def test_stream_generator_initialization(adapter): - # Verify the default stream generator is set - from codegate.providers.litellmshim import anthropic_stream_generator - - assert adapter.stream_generator == anthropic_stream_generator - - -def test_custom_stream_generator(): - # Test that we can inject a custom stream generator - async def custom_generator(stream: AsyncIterator[Dict]) -> AsyncIterator[str]: - async for chunk in stream: - yield "custom: " + str(chunk) - - adapter = AnthropicAdapter(stream_generator=custom_generator) - assert adapter.stream_generator == custom_generator diff --git a/tests/providers/litellmshim/test_litellmshim.py b/tests/providers/litellmshim/test_litellmshim.py index 0e524220..442177df 100644 --- a/tests/providers/litellmshim/test_litellmshim.py +++ b/tests/providers/litellmshim/test_litellmshim.py @@ -5,7 +5,7 @@ from fastapi.responses import StreamingResponse from litellm import ChatCompletionRequest, ModelResponse -from codegate.providers.litellmshim import BaseAdapter, LiteLLmShim +from codegate.providers.litellmshim import BaseAdapter, LiteLLmShim, sse_stream_generator class MockAdapter(BaseAdapter): @@ -38,24 +38,17 @@ async def modified_stream(): return modified_stream() -@pytest.fixture -def mock_adapter(): - return MockAdapter() - - -@pytest.fixture -def litellm_shim(mock_adapter): - return LiteLLmShim(mock_adapter) - - @pytest.mark.asyncio -async def test_complete_non_streaming(litellm_shim, mock_adapter): +async def test_complete_non_streaming(): # Mock response mock_response = ModelResponse(id="123", choices=[{"text": "test response"}]) mock_completion = AsyncMock(return_value=mock_response) # Create shim with mocked completion - litellm_shim = LiteLLmShim(mock_adapter, completion_func=mock_completion) + litellm_shim = LiteLLmShim( + stream_generator=sse_stream_generator, + completion_func=mock_completion + ) # Test data data = { @@ -64,7 +57,7 @@ async def test_complete_non_streaming(litellm_shim, mock_adapter): } # Execute - result = await litellm_shim.execute_completion(data) + result = await litellm_shim.execute_completion(data, api_key=None) # Verify assert result == mock_response @@ -81,8 +74,10 @@ async def mock_stream() -> AsyncIterator[ModelResponse]: yield ModelResponse(id="123", choices=[{"text": "chunk2"}]) mock_completion = AsyncMock(return_value=mock_stream()) - mock_adapter = MockAdapter() - litellm_shim = LiteLLmShim(mock_adapter, completion_func=mock_completion) + litellm_shim = LiteLLmShim( + stream_generator=sse_stream_generator, + completion_func=mock_completion + ) # Test data data = { @@ -92,7 +87,9 @@ async def mock_stream() -> AsyncIterator[ModelResponse]: } # Execute - result_stream = await litellm_shim.execute_completion(data) + result_stream = await litellm_shim.execute_completion( + ChatCompletionRequest(**data), + api_key=None) # Verify stream contents and adapter processing chunks = [] @@ -112,7 +109,7 @@ async def mock_stream() -> AsyncIterator[ModelResponse]: @pytest.mark.asyncio -async def test_create_streaming_response(litellm_shim): +async def test_create_streaming_response(): # Create a simple async generator that we know works async def 
mock_stream_gen(): for msg in ["Hello", "World"]: @@ -121,6 +118,7 @@ async def mock_stream_gen(): # Create and verify the generator generator = mock_stream_gen() + litellm_shim = LiteLLmShim(stream_generator=sse_stream_generator) response = litellm_shim.create_streaming_response(generator) # Verify response metadata @@ -128,4 +126,4 @@ async def mock_stream_gen(): assert response.status_code == 200 assert response.headers["Cache-Control"] == "no-cache" assert response.headers["Connection"] == "keep-alive" - assert response.headers["Transfer-Encoding"] == "chunked" \ No newline at end of file + assert response.headers["Transfer-Encoding"] == "chunked" diff --git a/tests/providers/test_registry.py b/tests/providers/test_registry.py index 209a3651..29e5a266 100644 --- a/tests/providers/test_registry.py +++ b/tests/providers/test_registry.py @@ -1,10 +1,21 @@ -from typing import Any, AsyncIterator, Dict +from typing import ( + Any, + AsyncIterable, + AsyncIterator, + Dict, + Iterable, + Iterator, + Optional, + Union, +) import pytest from fastapi import FastAPI from fastapi.responses import StreamingResponse +from litellm import ChatCompletionRequest, ModelResponse from codegate.providers.base import BaseCompletionHandler, BaseProvider +from codegate.providers.normalizer import ModelInputNormalizer, ModelOutputNormalizer from codegate.providers.registry import ProviderRegistry @@ -26,7 +37,8 @@ def translate_streaming_response( def execute_completion( self, - request: Any, + request: ChatCompletionRequest, + api_key: Optional[str], stream: bool = False, ) -> Any: pass @@ -37,8 +49,41 @@ def create_streaming_response( ) -> StreamingResponse: return StreamingResponse(stream) +class MockInputNormalizer(ModelInputNormalizer): + def normalize(self, data: Dict) -> Dict: + return data + + def denormalize(self, data: Dict) -> Dict: + return data + +class MockOutputNormalizer(ModelOutputNormalizer): + def normalize_streaming( + self, + model_reply: Union[AsyncIterable[Any], Iterable[Any]], + ) -> Union[AsyncIterator[ModelResponse], Iterator[ModelResponse]]: + pass + + def normalize(self, model_reply: Any) -> ModelResponse: + pass + + def denormalize(self, normalized_reply: ModelResponse) -> Any: + pass + + def denormalize_streaming( + self, + normalized_reply: Union[AsyncIterable[ModelResponse], Iterable[ModelResponse]], + ) -> Union[AsyncIterator[Any], Iterator[Any]]: + pass class MockProvider(BaseProvider): + def __init__( + self, + ): + super().__init__( + MockInputNormalizer(), + MockOutputNormalizer(), + MockCompletionHandler(), + None) @property def provider_route_name(self) -> str: @@ -65,24 +110,24 @@ def registry(app): return ProviderRegistry(app) -def test_add_provider(registry, mock_completion_handler): - provider = MockProvider(mock_completion_handler) +def test_add_provider(registry): + provider = MockProvider() registry.add_provider("test", provider) assert "test" in registry.providers assert registry.providers["test"] == provider -def test_get_provider(registry, mock_completion_handler): - provider = MockProvider(mock_completion_handler) +def test_get_provider(registry): + provider = MockProvider() registry.add_provider("test", provider) assert registry.get_provider("test") == provider assert registry.get_provider("nonexistent") is None -def test_provider_routes_added(app, registry, mock_completion_handler): - provider = MockProvider(mock_completion_handler) +def test_provider_routes_added(app, registry): + provider = MockProvider() registry.add_provider("test", provider) routes = 
[route for route in app.routes if route.path == "/mock_provider/test"] diff --git a/tests/test_logging.py b/tests/test_logging.py index 97f906b8..d2160de9 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -2,7 +2,13 @@ import logging from io import StringIO -from codegate.codegate_logging import JSONFormatter, TextFormatter, setup_logging, LogFormat, LogLevel +from codegate.codegate_logging import ( + JSONFormatter, + LogFormat, + LogLevel, + TextFormatter, + setup_logging, +) def test_json_formatter(): From bffce8ac195e911d24e120e2e3052c59b03b9c51 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Wed, 27 Nov 2024 15:35:30 +0100 Subject: [PATCH 2/2] Run make all to reformat the code --- scripts/import_packages.py | 21 ++++--- src/codegate/config.py | 4 +- src/codegate/pipeline/base.py | 18 +++--- src/codegate/pipeline/version/version.py | 4 +- src/codegate/prompts.py | 4 +- src/codegate/providers/anthropic/adapter.py | 1 + src/codegate/providers/base.py | 23 +++---- src/codegate/providers/completion/base.py | 6 +- .../providers/formatting/input_pipeline.py | 62 +++++++------------ src/codegate/providers/litellmshim/adapter.py | 18 +++--- .../providers/litellmshim/generators.py | 2 +- .../providers/litellmshim/litellmshim.py | 4 +- .../providers/llamacpp/completion_handler.py | 10 +-- src/codegate/providers/openai/adapter.py | 9 +-- src/codegate/providers/openai/provider.py | 4 +- tests/providers/anthropic/test_adapter.py | 3 +- .../providers/litellmshim/test_litellmshim.py | 10 ++- tests/providers/test_registry.py | 21 ++++--- tests/test_cli_prompts.py | 4 +- tests/test_prompts.py | 4 +- 20 files changed, 102 insertions(+), 130 deletions(-) diff --git a/scripts/import_packages.py b/scripts/import_packages.py index c66a271f..f980d17d 100644 --- a/scripts/import_packages.py +++ b/scripts/import_packages.py @@ -13,8 +13,7 @@ class PackageImporter: def __init__(self): self.client = weaviate.WeaviateClient( embedded_options=EmbeddedOptions( - persistence_data_path="./weaviate_data", - grpc_port=50052 + persistence_data_path="./weaviate_data", grpc_port=50052 ) ) self.json_files = [ @@ -46,13 +45,13 @@ def generate_vector_string(self, package): "npm": "JavaScript package available on NPM", "go": "Go package", "crates": "Rust package available on Crates", - "java": "Java package" + "java": "Java package", } status_messages = { "archived": "However, this package is found to be archived and no longer maintained.", "deprecated": "However, this package is found to be deprecated and no longer " "recommended for use.", - "malicious": "However, this package is found to be malicious." 
+ "malicious": "However, this package is found to be malicious.", } vector_str += f" is a {type_map.get(package['type'], 'unknown type')} " package_url = f"https://trustypkg.dev/{package['type']}/{package['name']}" @@ -75,8 +74,9 @@ async def add_data(self): packages_dict = { f"{package.properties['name']}/{package.properties['type']}": { "status": package.properties["status"], - "description": package.properties["description"] - } for package in existing_packages + "description": package.properties["description"], + } + for package in existing_packages } for json_file in self.json_files: @@ -85,12 +85,12 @@ async def add_data(self): packages_to_insert = [] for line in f: package = json.loads(line) - package["status"] = json_file.split('/')[-1].split('.')[0] + package["status"] = json_file.split("/")[-1].split(".")[0] key = f"{package['name']}/{package['type']}" if key in packages_dict and packages_dict[key] == { "status": package["status"], - "description": package["description"] + "description": package["description"], }: print("Package already exists", key) continue @@ -102,8 +102,9 @@ async def add_data(self): # Synchronous batch insert after preparing all data with collection.batch.dynamic() as batch: for package, vector in packages_to_insert: - batch.add_object(properties=package, vector=vector, - uuid=generate_uuid5(package)) + batch.add_object( + properties=package, vector=vector, uuid=generate_uuid5(package) + ) async def run_import(self): self.setup_schema() diff --git a/src/codegate/config.py b/src/codegate/config.py index 5dda8b47..e63e5fc1 100644 --- a/src/codegate/config.py +++ b/src/codegate/config.py @@ -52,9 +52,7 @@ def __post_init__(self) -> None: @staticmethod def _load_default_prompts() -> PromptConfig: """Load default prompts from prompts/default.yaml.""" - default_prompts_path = ( - Path(__file__).parent.parent.parent / "prompts" / "default.yaml" - ) + default_prompts_path = Path(__file__).parent.parent.parent / "prompts" / "default.yaml" try: return PromptConfig.from_file(default_prompts_path) except Exception as e: diff --git a/src/codegate/pipeline/base.py b/src/codegate/pipeline/base.py index d0e77602..b8875dda 100644 --- a/src/codegate/pipeline/base.py +++ b/src/codegate/pipeline/base.py @@ -14,6 +14,7 @@ class CodeSnippet: language: The programming language identifier (e.g., 'python', 'javascript') code: The actual code content """ + language: str code: str @@ -24,6 +25,7 @@ def __post_init__(self): raise ValueError("Code must not be empty") self.language = self.language.strip().lower() + @dataclass class PipelineContext: code_snippets: List[CodeSnippet] = field(default_factory=list) @@ -35,13 +37,16 @@ def add_code_snippet(self, snippet: CodeSnippet): def get_snippets_by_language(self, language: str) -> List[CodeSnippet]: return [s for s in self.code_snippets if s.language.lower() == language.lower()] + @dataclass class PipelineResponse: """Response generated by a pipeline step""" + content: str step_name: str # The name of the pipeline step that generated this response model: str # Taken from the original request's model field + @dataclass class PipelineResult: """ @@ -49,6 +54,7 @@ class PipelineResult: Either contains a modified request to continue processing, or a response to return to the client. 
""" + request: Optional[ChatCompletionRequest] = None response: Optional[PipelineResponse] = None error_message: Optional[str] = None @@ -79,8 +85,8 @@ def name(self) -> str: @staticmethod def get_last_user_message( - request: ChatCompletionRequest, - ) -> Optional[tuple[str, int]]: + request: ChatCompletionRequest, + ) -> Optional[tuple[str, int]]: """ Get the last user message and its index from the request. @@ -122,9 +128,7 @@ def get_last_user_message( @abstractmethod async def process( - self, - request: ChatCompletionRequest, - context: PipelineContext + self, request: ChatCompletionRequest, context: PipelineContext ) -> PipelineResult: """Process a request and return either modified request or response stream""" pass @@ -135,8 +139,8 @@ def __init__(self, pipeline_steps: List[PipelineStep]): self.pipeline_steps = pipeline_steps async def process_request( - self, - request: ChatCompletionRequest, + self, + request: ChatCompletionRequest, ) -> PipelineResult: """ Process a request through all pipeline steps diff --git a/src/codegate/pipeline/version/version.py b/src/codegate/pipeline/version/version.py index 314c831f..9f809ace 100644 --- a/src/codegate/pipeline/version/version.py +++ b/src/codegate/pipeline/version/version.py @@ -23,9 +23,7 @@ def name(self) -> str: return "codegate-version" async def process( - self, - request: ChatCompletionRequest, - context: PipelineContext + self, request: ChatCompletionRequest, context: PipelineContext ) -> PipelineResult: """ Checks if the last user message contains "codegate-version" and diff --git a/src/codegate/prompts.py b/src/codegate/prompts.py index a656155d..63405a08 100644 --- a/src/codegate/prompts.py +++ b/src/codegate/prompts.py @@ -44,9 +44,7 @@ def from_file(cls, prompt_path: Union[str, Path]) -> "PromptConfig": # Validate all values are strings for key, value in prompt_data.items(): if not isinstance(value, str): - raise ConfigurationError( - f"Prompt '{key}' must be a string, got {type(value)}" - ) + raise ConfigurationError(f"Prompt '{key}' must be a string, got {type(value)}") return cls(prompts=prompt_data) except yaml.YAMLError as e: diff --git a/src/codegate/providers/anthropic/adapter.py b/src/codegate/providers/anthropic/adapter.py index 01149e65..6f89bd4a 100644 --- a/src/codegate/providers/anthropic/adapter.py +++ b/src/codegate/providers/anthropic/adapter.py @@ -18,6 +18,7 @@ class AnthropicInputNormalizer(LiteLLMAdapterInputNormalizer): def __init__(self): super().__init__(LitellmAnthropicAdapter()) + class AnthropicOutputNormalizer(LiteLLMAdapterOutputNormalizer): """ LiteLLM's adapter class interface is used to translate between the Anthropic data diff --git a/src/codegate/providers/base.py b/src/codegate/providers/base.py index 940d8ea8..940d93b2 100644 --- a/src/codegate/providers/base.py +++ b/src/codegate/providers/base.py @@ -12,6 +12,7 @@ StreamGenerator = Callable[[AsyncIterator[Any]], AsyncIterator[str]] + class BaseProvider(ABC): """ The provider class is responsible for defining the API routes and @@ -23,7 +24,7 @@ def __init__( input_normalizer: ModelInputNormalizer, output_normalizer: ModelOutputNormalizer, completion_handler: BaseCompletionHandler, - pipeline_processor: Optional[SequentialPipelineProcessor] = None + pipeline_processor: Optional[SequentialPipelineProcessor] = None, ): self.router = APIRouter() self._completion_handler = completion_handler @@ -31,8 +32,7 @@ def __init__( self._output_normalizer = output_normalizer self._pipeline_processor = pipeline_processor - 
self._pipeline_response_formatter = \ - PipelineResponseFormatter(output_normalizer) + self._pipeline_response_formatter = PipelineResponseFormatter(output_normalizer) self._setup_routes() @@ -46,8 +46,8 @@ def provider_route_name(self) -> str: pass async def _run_input_pipeline( - self, - normalized_request: ChatCompletionRequest, + self, + normalized_request: ChatCompletionRequest, ) -> PipelineResult: if self._pipeline_processor is None: return PipelineResult(request=normalized_request) @@ -61,8 +61,10 @@ async def _run_input_pipeline( return result async def complete( - self, data: Dict, api_key: Optional[str], - ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]: + self, + data: Dict, + api_key: Optional[str], + ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]: """ Main completion flow with pipeline integration @@ -80,7 +82,8 @@ async def complete( input_pipeline_result = await self._run_input_pipeline(normalized_request) if input_pipeline_result.response: return self._pipeline_response_formatter.handle_pipeline_response( - input_pipeline_result.response, streaming) + input_pipeline_result.response, streaming + ) provider_request = self._input_normalizer.denormalize(input_pipeline_result.request) @@ -88,9 +91,7 @@ async def complete( # This gives us either a single response or a stream of responses # based on the streaming flag model_response = await self._completion_handler.execute_completion( - provider_request, - api_key=api_key, - stream=streaming + provider_request, api_key=api_key, stream=streaming ) if not streaming: diff --git a/src/codegate/providers/completion/base.py b/src/codegate/providers/completion/base.py index be0d754f..2bba9bc2 100644 --- a/src/codegate/providers/completion/base.py +++ b/src/codegate/providers/completion/base.py @@ -16,13 +16,11 @@ async def execute_completion( self, request: ChatCompletionRequest, api_key: Optional[str], - stream: bool = False, # TODO: remove this param? + stream: bool = False, # TODO: remove this param? 
) -> Union[ModelResponse, AsyncIterator[ModelResponse]]: """Execute the completion request""" pass @abstractmethod - def create_streaming_response( - self, stream: AsyncIterator[Any] - ) -> StreamingResponse: + def create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse: pass diff --git a/src/codegate/providers/formatting/input_pipeline.py b/src/codegate/providers/formatting/input_pipeline.py index 83b037d0..01d5ef7a 100644 --- a/src/codegate/providers/formatting/input_pipeline.py +++ b/src/codegate/providers/formatting/input_pipeline.py @@ -14,24 +14,21 @@ def _create_stream_end_response(original_response: ModelResponse) -> ModelRespon id=original_response.id, choices=[ StreamingChoices( - finish_reason="stop", - index=0, - delta=Delta( - content="", - role=None - ), - logprobs=None + finish_reason="stop", index=0, delta=Delta(content="", role=None), logprobs=None ) ], created=original_response.created, model=original_response.model, - object="chat.completion.chunk" + object="chat.completion.chunk", ) def _create_model_response( - content: str, step_name: str, model: str, streaming: bool, - ) -> ModelResponse: + content: str, + step_name: str, + model: str, + streaming: bool, +) -> ModelResponse: """ Create a ModelResponse in either streaming or non-streaming format This is required because the ModelResponse format is different for streaming @@ -47,33 +44,28 @@ def _create_model_response( StreamingChoices( finish_reason=None, index=0, - delta=Delta( - content=content, - role="assistant" - ), - logprobs=None + delta=Delta(content=content, role="assistant"), + logprobs=None, ) ], created=created, model=model, - object="chat.completion.chunk" + object="chat.completion.chunk", ) else: return ModelResponse( id=response_id, - choices=[{ - "text": content, - "index": 0, - "finish_reason": None - }], + choices=[{"text": content, "index": 0, "finish_reason": None}], created=created, - model=model + model=model, ) async def _convert_to_stream( - content: str, step_name: str, model: str, - ) -> AsyncIterator[ModelResponse]: + content: str, + step_name: str, + model: str, +) -> AsyncIterator[ModelResponse]: """ Converts a single completion response, provided by our pipeline as a shortcut to a streaming response. 
The streaming response has two chunks: the first @@ -87,15 +79,14 @@ async def _convert_to_stream( class PipelineResponseFormatter: - def __init__(self, - output_normalizer: ModelOutputNormalizer, - ): + def __init__( + self, + output_normalizer: ModelOutputNormalizer, + ): self._output_normalizer = output_normalizer def handle_pipeline_response( - self, - pipeline_response: PipelineResponse, - streaming: bool + self, pipeline_response: PipelineResponse, streaming: bool ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]: """ Convert pipeline response to appropriate format based on streaming flag @@ -109,7 +100,7 @@ def handle_pipeline_response( pipeline_response.content, pipeline_response.step_name, pipeline_response.model, - streaming=streaming + streaming=streaming, ) if not streaming: # If we're not streaming, we just return the response translated @@ -119,11 +110,6 @@ def handle_pipeline_response( # If we're streaming, we need to convert the response to a stream first # then feed the stream into the completion handler's conversion method model_response_stream = _convert_to_stream( - pipeline_response.content, - pipeline_response.step_name, - pipeline_response.model + pipeline_response.content, pipeline_response.step_name, pipeline_response.model ) - return self._output_normalizer.denormalize_streaming( - model_response_stream - ) - + return self._output_normalizer.denormalize_streaming(model_response_stream) diff --git a/src/codegate/providers/litellmshim/adapter.py b/src/codegate/providers/litellmshim/adapter.py index 0c11d2c3..c0b1a6a9 100644 --- a/src/codegate/providers/litellmshim/adapter.py +++ b/src/codegate/providers/litellmshim/adapter.py @@ -23,9 +23,7 @@ def __init__(self, stream_generator: StreamGenerator): self.stream_generator = stream_generator @abstractmethod - def translate_completion_input_params( - self, kwargs: Dict - ) -> Optional[ChatCompletionRequest]: + def translate_completion_input_params(self, kwargs: Dict) -> Optional[ChatCompletionRequest]: """Convert input parameters to LiteLLM's ChatCompletionRequest format""" pass @@ -35,15 +33,14 @@ def translate_completion_output_params(self, response: ModelResponse) -> Any: pass @abstractmethod - def translate_completion_output_params_streaming( - self, completion_stream: Any - ) -> Any: + def translate_completion_output_params_streaming(self, completion_stream: Any) -> Any: """ Convert streaming response from LiteLLM format to a format that can be passed to a stream generator and to the client. """ pass + class LiteLLMAdapterInputNormalizer(ModelInputNormalizer): def __init__(self, adapter: BaseAdapter): self._adapter = adapter @@ -62,13 +59,14 @@ def denormalize(self, data: ChatCompletionRequest) -> Dict: """ return data + class LiteLLMAdapterOutputNormalizer(ModelOutputNormalizer): def __init__(self, adapter: BaseAdapter): self._adapter = adapter def normalize_streaming( - self, - model_reply: Union[AsyncIterable[Any], Iterable[Any]], + self, + model_reply: Union[AsyncIterable[Any], Iterable[Any]], ) -> Union[AsyncIterator[ModelResponse], Iterator[ModelResponse]]: """ Normalize the output stream. 
This is a pass-through for liteLLM output normalizer @@ -91,8 +89,8 @@ def denormalize(self, normalized_reply: ModelResponse) -> Any: return self._adapter.translate_completion_output_params(normalized_reply) def denormalize_streaming( - self, - normalized_reply: Union[AsyncIterable[ModelResponse], Iterable[ModelResponse]], + self, + normalized_reply: Union[AsyncIterable[ModelResponse], Iterable[ModelResponse]], ) -> Union[AsyncIterator[Any], Iterator[Any]]: """ Denormalize the output stream from the completion function to the format diff --git a/src/codegate/providers/litellmshim/generators.py b/src/codegate/providers/litellmshim/generators.py index 636203cc..306f1900 100644 --- a/src/codegate/providers/litellmshim/generators.py +++ b/src/codegate/providers/litellmshim/generators.py @@ -36,4 +36,4 @@ async def anthropic_stream_generator(stream: AsyncIterator[Any]) -> AsyncIterato except Exception as e: yield f"event: {event_type}\ndata:{str(e)}\n\n" except Exception as e: - yield f"data: {str(e)}\n\n" \ No newline at end of file + yield f"data: {str(e)}\n\n" diff --git a/src/codegate/providers/litellmshim/litellmshim.py b/src/codegate/providers/litellmshim/litellmshim.py index 5f9820e1..1b4dcdf5 100644 --- a/src/codegate/providers/litellmshim/litellmshim.py +++ b/src/codegate/providers/litellmshim/litellmshim.py @@ -29,9 +29,7 @@ async def execute_completion( request["api_key"] = api_key return await self._completion_func(**request) - def create_streaming_response( - self, stream: AsyncIterator[Any] - ) -> StreamingResponse: + def create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse: """ Create a streaming response from a stream generator. The StreamingResponse is the format that FastAPI expects for streaming responses. diff --git a/src/codegate/providers/llamacpp/completion_handler.py b/src/codegate/providers/llamacpp/completion_handler.py index d7553853..40046975 100644 --- a/src/codegate/providers/llamacpp/completion_handler.py +++ b/src/codegate/providers/llamacpp/completion_handler.py @@ -26,15 +26,13 @@ async def llamacpp_stream_generator(stream: Iterator[Any]) -> AsyncIterator[str] finally: yield "data: [DONE]\n\n" + class LlamaCppCompletionHandler(BaseCompletionHandler): def __init__(self): self.inference_engine = LlamaCppInferenceEngine() async def execute_completion( - self, - request: ChatCompletionRequest, - api_key: Optional[str], - stream: bool = False + self, request: ChatCompletionRequest, api_key: Optional[str], stream: bool = False ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]: """ Execute the completion request with inference engine API @@ -53,9 +51,7 @@ async def execute_completion( **request) return response - def create_streaming_response( - self, stream: Iterator[Any] - ) -> StreamingResponse: + def create_streaming_response(self, stream: Iterator[Any]) -> StreamingResponse: """ Create a streaming response from a stream generator. The StreamingResponse is the format that FastAPI expects for streaming responses. 
diff --git a/src/codegate/providers/openai/adapter.py b/src/codegate/providers/openai/adapter.py index 2810dff5..b5f4565a 100644 --- a/src/codegate/providers/openai/adapter.py +++ b/src/codegate/providers/openai/adapter.py @@ -21,13 +21,14 @@ def denormalize(self, data: ChatCompletionRequest) -> Dict: """ return data + class OpenAIOutputNormalizer(ModelOutputNormalizer): def __init__(self): super().__init__() def normalize_streaming( - self, - model_reply: Any, + self, + model_reply: Any, ) -> Any: """ No normalizing needed, already OpenAI format @@ -47,8 +48,8 @@ def denormalize(self, normalized_reply: Any) -> Any: return normalized_reply def denormalize_streaming( - self, - normalized_reply: Any, + self, + normalized_reply: Any, ) -> Any: """ No denormalizing needed, already OpenAI format diff --git a/src/codegate/providers/openai/provider.py b/src/codegate/providers/openai/provider.py index cf08e99f..6d1e6c1d 100644 --- a/src/codegate/providers/openai/provider.py +++ b/src/codegate/providers/openai/provider.py @@ -34,9 +34,7 @@ async def create_completion( authorization: str = Header(..., description="Bearer token"), ): if not authorization.startswith("Bearer "): - raise HTTPException( - status_code=401, detail="Invalid authorization header" - ) + raise HTTPException(status_code=401, detail="Invalid authorization header") api_key = authorization.split(" ")[1] body = await request.body() diff --git a/tests/providers/anthropic/test_adapter.py b/tests/providers/anthropic/test_adapter.py index 0493a5f8..9bb81e54 100644 --- a/tests/providers/anthropic/test_adapter.py +++ b/tests/providers/anthropic/test_adapter.py @@ -38,7 +38,7 @@ def test_normalize_anthropic_input(input_normalizer): "max_tokens": 1024, "messages": [ {"content": "You are an expert code reviewer", "role": "system"}, - {"content": [{"text": "Review this code", "type": "text"}], "role": "user"} + {"content": [{"text": "Review this code", "type": "text"}], "role": "user"}, ], "model": "claude-3-haiku-20240307", "stream": True, @@ -48,6 +48,7 @@ def test_normalize_anthropic_input(input_normalizer): result = input_normalizer.normalize(completion_request) assert result == expected + @pytest.fixture def output_normalizer(): return AnthropicOutputNormalizer() diff --git a/tests/providers/litellmshim/test_litellmshim.py b/tests/providers/litellmshim/test_litellmshim.py index 442177df..73889a34 100644 --- a/tests/providers/litellmshim/test_litellmshim.py +++ b/tests/providers/litellmshim/test_litellmshim.py @@ -46,8 +46,7 @@ async def test_complete_non_streaming(): # Create shim with mocked completion litellm_shim = LiteLLmShim( - stream_generator=sse_stream_generator, - completion_func=mock_completion + stream_generator=sse_stream_generator, completion_func=mock_completion ) # Test data @@ -75,8 +74,7 @@ async def mock_stream() -> AsyncIterator[ModelResponse]: mock_completion = AsyncMock(return_value=mock_stream()) litellm_shim = LiteLLmShim( - stream_generator=sse_stream_generator, - completion_func=mock_completion + stream_generator=sse_stream_generator, completion_func=mock_completion ) # Test data @@ -88,8 +86,8 @@ async def mock_stream() -> AsyncIterator[ModelResponse]: # Execute result_stream = await litellm_shim.execute_completion( - ChatCompletionRequest(**data), - api_key=None) + ChatCompletionRequest(**data), api_key=None + ) # Verify stream contents and adapter processing chunks = [] diff --git a/tests/providers/test_registry.py b/tests/providers/test_registry.py index 29e5a266..8c957f13 100644 --- 
a/tests/providers/test_registry.py +++ b/tests/providers/test_registry.py @@ -49,6 +49,7 @@ def create_streaming_response( ) -> StreamingResponse: return StreamingResponse(stream) + class MockInputNormalizer(ModelInputNormalizer): def normalize(self, data: Dict) -> Dict: return data @@ -56,10 +57,11 @@ def normalize(self, data: Dict) -> Dict: def denormalize(self, data: Dict) -> Dict: return data + class MockOutputNormalizer(ModelOutputNormalizer): def normalize_streaming( - self, - model_reply: Union[AsyncIterable[Any], Iterable[Any]], + self, + model_reply: Union[AsyncIterable[Any], Iterable[Any]], ) -> Union[AsyncIterator[ModelResponse], Iterator[ModelResponse]]: pass @@ -70,24 +72,23 @@ def denormalize(self, normalized_reply: ModelResponse) -> Any: pass def denormalize_streaming( - self, - normalized_reply: Union[AsyncIterable[ModelResponse], Iterable[ModelResponse]], + self, + normalized_reply: Union[AsyncIterable[ModelResponse], Iterable[ModelResponse]], ) -> Union[AsyncIterator[Any], Iterator[Any]]: pass + class MockProvider(BaseProvider): def __init__( - self, + self, ): super().__init__( - MockInputNormalizer(), - MockOutputNormalizer(), - MockCompletionHandler(), - None) + MockInputNormalizer(), MockOutputNormalizer(), MockCompletionHandler(), None + ) @property def provider_route_name(self) -> str: - return 'mock_provider' + return "mock_provider" def _setup_routes(self) -> None: @self.router.get(f"/{self.provider_route_name}/test") diff --git a/tests/test_cli_prompts.py b/tests/test_cli_prompts.py index 88c743f6..2b5029a8 100644 --- a/tests/test_cli_prompts.py +++ b/tests/test_cli_prompts.py @@ -72,9 +72,7 @@ def test_serve_with_prompts(temp_prompts_file): """Test the serve command with prompts file.""" runner = CliRunner() # Use --help to avoid actually starting the server - result = runner.invoke( - cli, ["serve", "--prompts", str(temp_prompts_file), "--help"] - ) + result = runner.invoke(cli, ["serve", "--prompts", str(temp_prompts_file), "--help"]) assert result.exit_code == 0 assert "Path to YAML prompts file" in result.output diff --git a/tests/test_prompts.py b/tests/test_prompts.py index 5fef36b0..e28863f0 100644 --- a/tests/test_prompts.py +++ b/tests/test_prompts.py @@ -147,9 +147,7 @@ def test_environment_variable_override(temp_env_prompts_file, monkeypatch): assert config.prompts.another_env == "Another environment prompt" -def test_cli_override_takes_precedence( - temp_prompts_file, temp_env_prompts_file, monkeypatch -): +def test_cli_override_takes_precedence(temp_prompts_file, temp_env_prompts_file, monkeypatch): """Test that CLI prompts override config and environment.""" # Set environment variable monkeypatch.setenv("CODEGATE_PROMPTS_FILE", str(temp_env_prompts_file))