This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Add FIM functionality for VLLM provider #132

Merged (1 commit) on Nov 29, 2024
src/codegate/providers/anthropic/completion_handler.py (3 changes: 2 additions & 1 deletion)

@@ -15,6 +15,7 @@ async def execute_completion(
         request: ChatCompletionRequest,
         api_key: Optional[str],
         stream: bool = False,
+        is_fim_request: bool = False,
     ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]:
         """
         Ensures the model name is prefixed with 'anthropic/' to explicitly route to Anthropic's API.
@@ -30,4 +31,4 @@ async def execute_completion(
         model_in_request = request["model"]
         if not model_in_request.startswith("anthropic/"):
             request["model"] = f"anthropic/{model_in_request}"
-        return await super().execute_completion(request, api_key, stream)
+        return await super().execute_completion(request, api_key, stream, is_fim_request)

src/codegate/providers/base.py (2 changes: 1 addition & 1 deletion)

@@ -161,7 +161,7 @@ async def complete(
         # This gives us either a single response or a stream of responses
         # based on the streaming flag
         model_response = await self._completion_handler.execute_completion(
-            provider_request, api_key=api_key, stream=streaming
+            provider_request, api_key=api_key, stream=streaming, is_fim_request=is_fim_request
         )
         if not streaming:
             normalized_response = self._output_normalizer.normalize(model_response)

src/codegate/providers/completion/base.py (1 change: 1 addition & 0 deletions)

@@ -17,6 +17,7 @@ async def execute_completion(
         request: ChatCompletionRequest,
         api_key: Optional[str],
         stream: bool = False,  # TODO: remove this param?
+        is_fim_request: bool = False,
     ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]:
         """Execute the completion request"""
         pass

src/codegate/providers/litellmshim/litellmshim.py (17 changes: 15 additions & 2 deletions)

@@ -1,4 +1,4 @@
-from typing import Any, AsyncIterator, Optional, Union
+from typing import Any, AsyncIterator, Callable, Optional, Union
 
 from fastapi.responses import StreamingResponse
 from litellm import ChatCompletionRequest, ModelResponse, acompletion
@@ -13,20 +13,33 @@ class LiteLLmShim(BaseCompletionHandler):
     LiteLLM API.
     """
 
-    def __init__(self, stream_generator: StreamGenerator, completion_func=acompletion):
+    def __init__(
+        self,
+        stream_generator: StreamGenerator,
+        completion_func: Callable = acompletion,
+        fim_completion_func: Optional[Callable] = None,
+    ):
         self._stream_generator = stream_generator
         self._completion_func = completion_func
+        # Use the same function for FIM completion if one is not specified
+        if fim_completion_func is None:
+            self._fim_completion_func = completion_func
+        else:
+            self._fim_completion_func = fim_completion_func
 
     async def execute_completion(
         self,
         request: ChatCompletionRequest,
         api_key: Optional[str],
         stream: bool = False,
+        is_fim_request: bool = False,
     ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]:
         """
         Execute the completion request with LiteLLM's API
         """
         request["api_key"] = api_key
+        if is_fim_request:
+            return await self._fim_completion_func(**request)
         return await self._completion_func(**request)
 
     def create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse:

src/codegate/providers/llamacpp/completion_handler.py (8 changes: 6 additions & 2 deletions)

@@ -47,14 +47,18 @@ def __init__(self):
         self.inference_engine = LlamaCppInferenceEngine()
 
     async def execute_completion(
-        self, request: ChatCompletionRequest, api_key: Optional[str], stream: bool = False
+        self,
+        request: ChatCompletionRequest,
+        api_key: Optional[str],
+        stream: bool = False,
+        is_fim_request: bool = False,
     ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]:
         """
         Execute the completion request with inference engine API
         """
         model_path = f"{Config.get_config().model_base_path}/{request['model']}.gguf"
 
-        if "prompt" in request:
+        if is_fim_request:
             response = await self.inference_engine.complete(
                 model_path,
                 Config.get_config().chat_model_n_ctx,

src/codegate/providers/vllm/provider.py (5 changes: 4 additions & 1 deletion)

@@ -2,6 +2,7 @@
 from typing import Optional
 
 from fastapi import Header, HTTPException, Request
+from litellm import atext_completion
 
 from codegate.config import Config
 from codegate.providers.base import BaseProvider, SequentialPipelineProcessor
@@ -15,7 +16,9 @@ def __init__(
         pipeline_processor: Optional[SequentialPipelineProcessor] = None,
         fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None,
     ):
-        completion_handler = LiteLLmShim(stream_generator=sse_stream_generator)
+        completion_handler = LiteLLmShim(
+            stream_generator=sse_stream_generator, fim_completion_func=atext_completion
+        )

Comment on lines +19 to +21
Contributor (Author):

This tells litellm to use atext_completion instead of acompletion when the request is FIM. Continue sends us a prompt instead of messages, and atext_completion is able to handle that.
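As a minimal illustration of the two request shapes being discussed (not code from this PR; the model identifier, endpoint, and FIM marker format are hypothetical and vary by deployment):

```python
from litellm import acompletion, atext_completion


async def demo_requests():
    # Chat-style request: acompletion expects a list of role/content messages.
    chat_response = await acompletion(
        model="openai/my-coder-model",  # hypothetical OpenAI-compatible model id served by vLLM
        api_base="http://localhost:8000/v1",  # hypothetical vLLM endpoint
        messages=[{"role": "user", "content": "Write a hello-world function"}],
    )

    # FIM-style request as sent by Continue: a single raw prompt string,
    # which atext_completion accepts directly.
    fim_response = await atext_completion(
        model="openai/my-coder-model",  # hypothetical OpenAI-compatible model id served by vLLM
        api_base="http://localhost:8000/v1",  # hypothetical vLLM endpoint
        prompt="<fim_prefix>def add(a, b):\n    <fim_suffix>\n<fim_middle>",  # marker format depends on the model
    )
    return chat_response, fim_response
```

The shim added in this PR simply picks between these two call paths based on is_fim_request.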

Contributor (Author):

The downside of using atext_completion is that it returns a TextCompletionResponse instead of a ModelResponse like our regular acompletion. I checked, and the parameters of both objects are almost identical, so hopefully it doesn't cause too many problems.
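For reference, a tiny hypothetical helper showing where the main shape difference surfaces; the attribute paths follow the OpenAI-style response objects that LiteLLM mirrors, and the helper itself is not part of this PR:

```python
def extract_generated_text(response) -> str:
    """Illustrative only: pull the generated text out of either response shape."""
    choice = response.choices[0]
    # Chat-style ModelResponse exposes message.content; TextCompletionResponse exposes .text.
    if getattr(choice, "message", None) is not None:
        return choice.message.content
    return choice.text
```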

Would that have an impact on the pipeline processing, @jhrozek?

Contributor:

I will do some testing. The pipeline does expect a ModelResponse; maybe we could normalize TextCompletionResponse into ModelResponse.
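A rough sketch of what such a normalization could look like, purely as a starting point; the helper name and the exact constructor arguments are assumptions, not code from this PR:

```python
from litellm import ModelResponse


def normalize_text_completion(text_response) -> ModelResponse:
    """Hypothetical helper: map a TextCompletionResponse onto the chat-style
    ModelResponse shape that the output pipeline expects."""
    return ModelResponse(
        id=text_response.id,
        created=text_response.created,
        model=text_response.model,
        choices=[
            {
                "index": choice.index,
                "finish_reason": choice.finish_reason,
                # Text completions expose `text`; wrap it as an assistant message
                # so downstream code can read `choices[0].message.content`.
                "message": {"role": "assistant", "content": choice.text},
            }
            for choice in text_response.choices
        ],
        usage=getattr(text_response, "usage", None),
    )
```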

Contributor:

OK, I had another look. The biggest difference between atext_completion and acompletion in LiteLLM is that acompletion receives a conversation with multiple messages and roles, while atext_completion just receives a prompt.

If this is OK with you, let's go ahead. I'm not sure whether FIM will use any system prompts or such.

Contributor (Author):

I will merge and attempt to do the normalization of TextCompletionResponse to ModelResponse in a separate PR. I agree it is better if we try to keep things as normalized as possible.

         super().__init__(
             VLLMInputNormalizer(),
             VLLMOutputNormalizer(),