diff --git a/src/codegate/providers/anthropic/completion_handler.py b/src/codegate/providers/anthropic/completion_handler.py
index 253e2970..da7e6537 100644
--- a/src/codegate/providers/anthropic/completion_handler.py
+++ b/src/codegate/providers/anthropic/completion_handler.py
@@ -15,6 +15,7 @@ async def execute_completion(
         request: ChatCompletionRequest,
         api_key: Optional[str],
         stream: bool = False,
+        is_fim_request: bool = False,
     ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]:
         """
         Ensures the model name is prefixed with 'anthropic/' to explicitly route to Anthropic's API.
@@ -30,4 +31,4 @@ async def execute_completion(
         model_in_request = request["model"]
         if not model_in_request.startswith("anthropic/"):
             request["model"] = f"anthropic/{model_in_request}"
-        return await super().execute_completion(request, api_key, stream)
+        return await super().execute_completion(request, api_key, stream, is_fim_request)
diff --git a/src/codegate/providers/base.py b/src/codegate/providers/base.py
index 7a597dca..8cc42f59 100644
--- a/src/codegate/providers/base.py
+++ b/src/codegate/providers/base.py
@@ -161,7 +161,7 @@ async def complete(
         # This gives us either a single response or a stream of responses
         # based on the streaming flag
         model_response = await self._completion_handler.execute_completion(
-            provider_request, api_key=api_key, stream=streaming
+            provider_request, api_key=api_key, stream=streaming, is_fim_request=is_fim_request
         )
         if not streaming:
             normalized_response = self._output_normalizer.normalize(model_response)
diff --git a/src/codegate/providers/completion/base.py b/src/codegate/providers/completion/base.py
index 2bba9bc2..1e86129c 100644
--- a/src/codegate/providers/completion/base.py
+++ b/src/codegate/providers/completion/base.py
@@ -17,6 +17,7 @@ async def execute_completion(
         request: ChatCompletionRequest,
         api_key: Optional[str],
         stream: bool = False,  # TODO: remove this param?
+        is_fim_request: bool = False,
     ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]:
         """Execute the completion request"""
         pass
diff --git a/src/codegate/providers/litellmshim/litellmshim.py b/src/codegate/providers/litellmshim/litellmshim.py
index 1b4dcdf5..1de5bb28 100644
--- a/src/codegate/providers/litellmshim/litellmshim.py
+++ b/src/codegate/providers/litellmshim/litellmshim.py
@@ -1,4 +1,4 @@
-from typing import Any, AsyncIterator, Optional, Union
+from typing import Any, AsyncIterator, Callable, Optional, Union
 
 from fastapi.responses import StreamingResponse
 from litellm import ChatCompletionRequest, ModelResponse, acompletion
@@ -13,20 +13,33 @@ class LiteLLmShim(BaseCompletionHandler):
     LiteLLM API.
     """
 
-    def __init__(self, stream_generator: StreamGenerator, completion_func=acompletion):
+    def __init__(
+        self,
+        stream_generator: StreamGenerator,
+        completion_func: Callable = acompletion,
+        fim_completion_func: Optional[Callable] = None,
+    ):
         self._stream_generator = stream_generator
         self._completion_func = completion_func
+        # Use the same function for FIM completion if one is not specified
+        if fim_completion_func is None:
+            self._fim_completion_func = completion_func
+        else:
+            self._fim_completion_func = fim_completion_func
 
     async def execute_completion(
         self,
         request: ChatCompletionRequest,
         api_key: Optional[str],
         stream: bool = False,
+        is_fim_request: bool = False,
     ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]:
         """
         Execute the completion request with LiteLLM's API
         """
         request["api_key"] = api_key
+        if is_fim_request:
+            return await self._fim_completion_func(**request)
         return await self._completion_func(**request)
 
     def create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse:
diff --git a/src/codegate/providers/llamacpp/completion_handler.py b/src/codegate/providers/llamacpp/completion_handler.py
index 9aa1c66d..f5e6fc1d 100644
--- a/src/codegate/providers/llamacpp/completion_handler.py
+++ b/src/codegate/providers/llamacpp/completion_handler.py
@@ -47,14 +47,18 @@ def __init__(self):
         self.inference_engine = LlamaCppInferenceEngine()
 
     async def execute_completion(
-        self, request: ChatCompletionRequest, api_key: Optional[str], stream: bool = False
+        self,
+        request: ChatCompletionRequest,
+        api_key: Optional[str],
+        stream: bool = False,
+        is_fim_request: bool = False,
     ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]:
         """
         Execute the completion request with inference engine API
         """
         model_path = f"{Config.get_config().model_base_path}/{request['model']}.gguf"
 
-        if "prompt" in request:
+        if is_fim_request:
             response = await self.inference_engine.complete(
                 model_path,
                 Config.get_config().chat_model_n_ctx,
diff --git a/src/codegate/providers/vllm/provider.py b/src/codegate/providers/vllm/provider.py
index 05c8f720..d2e710af 100644
--- a/src/codegate/providers/vllm/provider.py
+++ b/src/codegate/providers/vllm/provider.py
@@ -2,6 +2,7 @@ from typing import Optional
 
 from fastapi import Header, HTTPException, Request
+from litellm import atext_completion
 
 from codegate.config import Config
 from codegate.providers.base import BaseProvider, SequentialPipelineProcessor
@@ -15,7 +16,9 @@ def __init__(
         pipeline_processor: Optional[SequentialPipelineProcessor] = None,
         fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None,
     ):
-        completion_handler = LiteLLmShim(stream_generator=sse_stream_generator)
+        completion_handler = LiteLLmShim(
+            stream_generator=sse_stream_generator, fim_completion_func=atext_completion
+        )
        super().__init__(
             VLLMInputNormalizer(),
             VLLMOutputNormalizer(),