From bbe3d79f4ecb4efb94cba5cd6ee758374166c5ea Mon Sep 17 00:00:00 2001
From: Juan Antonio Osorio
Date: Mon, 10 Feb 2025 18:04:20 +0200
Subject: [PATCH 1/5] Pass down base URL and API key to completion handler

In order to make Ollama work with API Keys, I needed to change the
completion handler to take a base URL and also leverage a given API key
(if available).

Signed-off-by: Juan Antonio Osorio
Co-Authored-by: Alejandro Ponce de Leon
---
 .../providers/anthropic/completion_handler.py |  1 +
 src/codegate/providers/base.py                |  1 +
 src/codegate/providers/completion/base.py     |  1 +
 .../providers/litellmshim/litellmshim.py      |  2 ++
 .../providers/ollama/completion_handler.py    | 19 ++++++++++++++-----
 src/codegate/providers/ollama/provider.py     |  4 ++--
 6 files changed, 21 insertions(+), 7 deletions(-)

diff --git a/src/codegate/providers/anthropic/completion_handler.py b/src/codegate/providers/anthropic/completion_handler.py
index da7e6537..12c007f4 100644
--- a/src/codegate/providers/anthropic/completion_handler.py
+++ b/src/codegate/providers/anthropic/completion_handler.py
@@ -13,6 +13,7 @@ class AnthropicCompletion(LiteLLmShim):
     async def execute_completion(
         self,
         request: ChatCompletionRequest,
+        base_url: Optional[str],
         api_key: Optional[str],
         stream: bool = False,
         is_fim_request: bool = False,
diff --git a/src/codegate/providers/base.py b/src/codegate/providers/base.py
index 0c20bab8..3edced69 100644
--- a/src/codegate/providers/base.py
+++ b/src/codegate/providers/base.py
@@ -257,6 +257,7 @@ async def complete(
         # based on the streaming flag
         model_response = await self._completion_handler.execute_completion(
             provider_request,
+            base_url=data.get("base_url"),
             api_key=api_key,
             stream=streaming,
             is_fim_request=is_fim_request,
diff --git a/src/codegate/providers/completion/base.py b/src/codegate/providers/completion/base.py
index 41c59082..084f6fc7 100644
--- a/src/codegate/providers/completion/base.py
+++ b/src/codegate/providers/completion/base.py
@@ -19,6 +19,7 @@ class BaseCompletionHandler(ABC):
     async def execute_completion(
         self,
         request: ChatCompletionRequest,
+        base_url: Optional[str],
         api_key: Optional[str],
         stream: bool = False,  # TODO: remove this param?
         is_fim_request: bool = False,
diff --git a/src/codegate/providers/litellmshim/litellmshim.py b/src/codegate/providers/litellmshim/litellmshim.py
index 888da7f4..37693f1d 100644
--- a/src/codegate/providers/litellmshim/litellmshim.py
+++ b/src/codegate/providers/litellmshim/litellmshim.py
@@ -41,6 +41,7 @@ def __init__(
     async def execute_completion(
         self,
         request: ChatCompletionRequest,
+        base_url: Optional[str],
         api_key: Optional[str],
         stream: bool = False,
         is_fim_request: bool = False,
@@ -49,6 +50,7 @@ async def execute_completion(
         Execute the completion request with LiteLLM's API
         """
         request["api_key"] = api_key
+        request["base_url"] = base_url
         if is_fim_request:
             return await self._fim_completion_func(**request)
         return await self._completion_func(**request)
diff --git a/src/codegate/providers/ollama/completion_handler.py b/src/codegate/providers/ollama/completion_handler.py
index 829b0161..b10c41f6 100644
--- a/src/codegate/providers/ollama/completion_handler.py
+++ b/src/codegate/providers/ollama/completion_handler.py
@@ -82,17 +82,25 @@ async def ollama_stream_generator( # noqa: C901
 
 
 class OllamaShim(BaseCompletionHandler):
-    def __init__(self, base_url):
-        self.client = AsyncClient(host=base_url, timeout=30)
-
     async def execute_completion(
         self,
         request: ChatCompletionRequest,
+        base_url: Optional[str],
         api_key: Optional[str],
         stream: bool = False,
         is_fim_request: bool = False,
     ) -> Union[ChatResponse, GenerateResponse]:
         """Stream response directly from Ollama API."""
+        if not base_url:
+            raise ValueError("base_url is required for Ollama")
+
+        # TODO: Add CodeGate user agent.
+        headers = dict()
+        if api_key:
+            headers["Authorization"] = f"Bearer {api_key}"
+
+        client = AsyncClient(host=base_url, timeout=300, headers=headers)
+
         try:
             if is_fim_request:
                 prompt = ""
@@ -103,7 +111,7 @@ async def execute_completion(
                 if not prompt:
                     raise ValueError("No user message found in FIM request")
 
-                response = await self.client.generate(
+                response = await client.generate(
                     model=request["model"],
                     prompt=prompt,
                     raw=request.get("raw", False),
@@ -112,7 +120,7 @@ async def execute_completion(
                     options=request["options"],  # type: ignore
                 )
             else:
-                response = await self.client.chat(
+                response = await client.chat(
                     model=request["model"],
                     messages=request["messages"],
                     stream=stream,  # type: ignore
@@ -123,6 +131,7 @@ async def execute_completion(
             logger.error(f"Error in Ollama completion: {str(e)}")
             raise e
 
+
     def _create_streaming_response(
         self,
         stream: AsyncIterator[ChatResponse],
diff --git a/src/codegate/providers/ollama/provider.py b/src/codegate/providers/ollama/provider.py
index 4f5cd654..2bcd6d30 100644
--- a/src/codegate/providers/ollama/provider.py
+++ b/src/codegate/providers/ollama/provider.py
@@ -28,7 +28,7 @@ def __init__(
         else:
             provided_urls = config.provider_urls
         self.base_url = provided_urls.get("ollama", "http://localhost:11434/")
-        completion_handler = OllamaShim(self.base_url)
+        completion_handler = OllamaShim()
         super().__init__(
             OllamaInputNormalizer(),
             OllamaOutputNormalizer(),
@@ -68,7 +68,7 @@ async def process_request(
         try:
             stream = await self.complete(
                 data,
-                api_key=None,
+                api_key=api_key,
                 is_fim_request=is_fim_request,
                 client_type=client_type,
             )

From 9b44f5eba12f7fa80db95104edf19fd3686b6a80 Mon Sep 17 00:00:00 2001
From: Alejandro Ponce
Date: Tue, 11 Feb 2025 09:59:18 +0200
Subject: [PATCH 2/5] fix unit tests and formatting

---
 .../providers/ollama/completion_handler.py    |  1 -
 .../providers/litellmshim/test_litellmshim.py |  4 +--
 .../ollama/test_ollama_completion_handler.py  | 31 +++++++++++++------
 3 files changed, 24 insertions(+), 12 deletions(-)

diff --git a/src/codegate/providers/ollama/completion_handler.py b/src/codegate/providers/ollama/completion_handler.py
index b10c41f6..73f53037 100644
--- a/src/codegate/providers/ollama/completion_handler.py
+++ b/src/codegate/providers/ollama/completion_handler.py
@@ -131,7 +131,6 @@ async def execute_completion(
             logger.error(f"Error in Ollama completion: {str(e)}")
             raise e
 
-
     def _create_streaming_response(
         self,
         stream: AsyncIterator[ChatResponse],
diff --git a/tests/providers/litellmshim/test_litellmshim.py b/tests/providers/litellmshim/test_litellmshim.py
index 87b75803..d381cdaa 100644
--- a/tests/providers/litellmshim/test_litellmshim.py
+++ b/tests/providers/litellmshim/test_litellmshim.py
@@ -56,7 +56,7 @@ async def test_complete_non_streaming():
     }
 
     # Execute
-    result = await litellm_shim.execute_completion(data, api_key=None)
+    result = await litellm_shim.execute_completion(data, base_url=None, api_key=None)
 
     # Verify
     assert result == mock_response
@@ -86,7 +86,7 @@ async def mock_stream() -> AsyncIterator[ModelResponse]:
 
     # Execute
     result_stream = await litellm_shim.execute_completion(
-        ChatCompletionRequest(**data), api_key=None
+        ChatCompletionRequest(**data), base_url=None, api_key=None
     )
 
     # Verify stream contents and adapter processing
diff --git a/tests/providers/ollama/test_ollama_completion_handler.py b/tests/providers/ollama/test_ollama_completion_handler.py
index df0eb149..7341dfe3 100644
--- a/tests/providers/ollama/test_ollama_completion_handler.py
+++ b/tests/providers/ollama/test_ollama_completion_handler.py
@@ -1,4 +1,4 @@
-from unittest.mock import AsyncMock, MagicMock
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
 from litellm import ChatCompletionRequest
@@ -19,8 +19,7 @@ def mock_client():
 
 @pytest.fixture
 def handler(mock_client):
-    ollama_shim = OllamaShim("http://ollama:11434")
-    ollama_shim.client = mock_client
+    ollama_shim = OllamaShim()
     return ollama_shim
 
 
@@ -31,11 +30,18 @@ def chat_request():
     )
 
 
+@patch("codegate.providers.ollama.completion_handler.AsyncClient.generate", new_callable=AsyncMock)
 @pytest.mark.asyncio
-async def test_execute_completion_is_fim_request(handler, chat_request):
+async def test_execute_completion_is_fim_request(mock_client_generate, handler, chat_request):
     chat_request["messages"][0]["content"] = "FIM prompt"
-    await handler.execute_completion(chat_request, api_key=None, stream=False, is_fim_request=True)
-    handler.client.generate.assert_called_once_with(
+    await handler.execute_completion(
+        chat_request,
+        base_url="http://ollama:11434",
+        api_key=None,
+        stream=False,
+        is_fim_request=True,
+    )
+    mock_client_generate.assert_called_once_with(
         model=chat_request["model"],
         prompt="FIM prompt",
         stream=False,
@@ -45,10 +51,17 @@ async def test_execute_completion_is_fim_request(handler, chat_request):
     )
 
 
+@patch("codegate.providers.ollama.completion_handler.AsyncClient.chat", new_callable=AsyncMock)
 @pytest.mark.asyncio
-async def test_execute_completion_not_is_fim_request(handler, chat_request):
-    await handler.execute_completion(chat_request, api_key=None, stream=False, is_fim_request=False)
-    handler.client.chat.assert_called_once_with(
+async def test_execute_completion_not_is_fim_request(mock_client_chat, handler, chat_request):
+    await handler.execute_completion(
+        chat_request,
+        base_url="http://ollama:11434",
+        api_key=None,
+        stream=False,
+        is_fim_request=False,
+    )
+    mock_client_chat.assert_called_once_with(
         model=chat_request["model"],
         messages=chat_request["messages"],
         stream=False,

From 05ba85bc95bbdba09b8d108b42ee4adc42cfcf96 Mon Sep 17 00:00:00 2001
From: Alejandro Ponce
Date: Tue, 11 Feb 2025 10:44:26 +0200
Subject: [PATCH 3/5] fix integration tests

---
 src/codegate/providers/anthropic/completion_handler.py | 8 +++++++-
 src/codegate/providers/anthropic/provider.py           | 5 +++--
 src/codegate/providers/llamacpp/completion_handler.py  | 1 +
 3 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/src/codegate/providers/anthropic/completion_handler.py b/src/codegate/providers/anthropic/completion_handler.py
index 12c007f4..8d23ee21 100644
--- a/src/codegate/providers/anthropic/completion_handler.py
+++ b/src/codegate/providers/anthropic/completion_handler.py
@@ -32,4 +32,10 @@ async def execute_completion(
         model_in_request = request["model"]
         if not model_in_request.startswith("anthropic/"):
             request["model"] = f"anthropic/{model_in_request}"
-        return await super().execute_completion(request, api_key, stream, is_fim_request)
+        return await super().execute_completion(
+            request=request,
+            api_key=api_key,
+            stream=stream,
+            is_fim_request=is_fim_request,
+            base_url=request.get("base_url"),
+        )
diff --git a/src/codegate/providers/anthropic/provider.py b/src/codegate/providers/anthropic/provider.py
index 252a6947..454018fd 100644
--- a/src/codegate/providers/anthropic/provider.py
+++ b/src/codegate/providers/anthropic/provider.py
@@ -14,6 +14,8 @@
 from codegate.providers.fim_analyzer import FIMAnalyzer
 from codegate.providers.litellmshim import anthropic_stream_generator
 
+logger = structlog.get_logger("codegate")
+
 
 class AnthropicProvider(BaseProvider):
     def __init__(
@@ -67,8 +69,7 @@ async def process_request(
             #  check if we have an status code there
             if hasattr(e, "status_code"):
                 # log the exception
-                logger = structlog.get_logger("codegate")
-                logger.error("Error in AnthropicProvider completion", error=str(e))
+                logger.exception("Error in AnthropicProvider completion")
                 raise HTTPException(status_code=e.status_code, detail=str(e))  # type: ignore
             else:
                 # just continue raising the exception
diff --git a/src/codegate/providers/llamacpp/completion_handler.py b/src/codegate/providers/llamacpp/completion_handler.py
index e699b2ff..ef34610a 100644
--- a/src/codegate/providers/llamacpp/completion_handler.py
+++ b/src/codegate/providers/llamacpp/completion_handler.py
@@ -50,6 +50,7 @@ def __init__(self):
     async def execute_completion(
         self,
         request: ChatCompletionRequest,
+        base_url: Optional[str],
         api_key: Optional[str],
         stream: bool = False,
         is_fim_request: bool = False,

From d5281d26833e95f6c31db1df82ca635715b4a20c Mon Sep 17 00:00:00 2001
From: Juan Antonio Osorio
Date: Tue, 11 Feb 2025 10:51:16 +0200
Subject: [PATCH 4/5] Pass API key to ollama calls

Signed-off-by: Juan Antonio Osorio
---
 src/codegate/providers/ollama/provider.py | 36 +++++++++++++++++++----
 1 file changed, 31 insertions(+), 5 deletions(-)

diff --git a/src/codegate/providers/ollama/provider.py b/src/codegate/providers/ollama/provider.py
index 2bcd6d30..3953acc5 100644
--- a/src/codegate/providers/ollama/provider.py
+++ b/src/codegate/providers/ollama/provider.py
@@ -3,7 +3,7 @@
 
 import httpx
 import structlog
-from fastapi import HTTPException, Request
+from fastapi import Header, HTTPException, Request
 
 from codegate.clients.clients import ClientType
 from codegate.clients.detector import DetectClient
@@ -103,21 +103,28 @@ async def get_tags(request: Request):
             return response.json()
 
         @self.router.post(f"/{self.provider_route_name}/api/show")
-        async def show_model(request: Request):
+        async def show_model(
+            request: Request,
+            authorization: str = Header(..., description="Bearer token"),
+        ):
             """
             route for /api/show that responds outside of the pipeline
             /api/show displays model is used to get the model information
             https://github.com/ollama/ollama/blob/main/docs/api.md#show-model-information
             """
+            api_key = _api_key_from_optional_header_value(authorization)
             body = await request.body()
             body_json = json.loads(body)
             if "name" not in body_json:
                 raise HTTPException(status_code=400, detail="model is required in the request body")
             async with httpx.AsyncClient() as client:
+                headers = {"Content-Type": "application/json; charset=utf-8"}
+                if api_key:
+                    headers["Authorization"] = api_key
                 response = await client.post(
                     f"{self.base_url}/api/show",
                     content=body,
-                    headers={"Content-Type": "application/json; charset=utf-8"},
+                    headers=headers,
                 )
                 return response.json()
 
@@ -131,7 +138,11 @@ async def show_model(request: Request):
         @self.router.post(f"/{self.provider_route_name}/v1/chat/completions")
         @self.router.post(f"/{self.provider_route_name}/v1/generate")
         @DetectClient()
-        async def create_completion(request: Request):
+        async def create_completion(
+            request: Request,
+            authorization: str = Header(..., description="Bearer token"),
+        ):
+            api_key = _api_key_from_optional_header_value(authorization)
             body = await request.body()
             data = json.loads(body)
 
@@ -141,7 +152,22 @@ async def create_completion(request: Request):
             is_fim_request = FIMAnalyzer.is_fim_request(request.url.path, data)
             return await self.process_request(
                 data,
-                None,
+                api_key,
                 is_fim_request,
                 request.state.detected_client,
             )
+
+
+def _api_key_from_optional_header_value(val: str) -> str:
+    # The header is optional, so if we don't
+    # have it, let's just return None
+    if not val:
+        return None
+
+    # The header value should be "Bearer "
+    if not val.startswith("Bearer "):
+        raise HTTPException(status_code=401, detail="Invalid authorization header")
+    vals = val.split(" ")
+    if len(vals) != 2:
+        raise HTTPException(status_code=401, detail="Invalid authorization header")
+    return vals[1]

From 37695478dccb6994a931c992f875ed2d39276f89 Mon Sep 17 00:00:00 2001
From: Juan Antonio Osorio
Date: Tue, 11 Feb 2025 12:56:11 +0200
Subject: [PATCH 5/5] Handle optional authorization header in ollama endpoints.

Signed-off-by: Juan Antonio Osorio
---
 src/codegate/providers/ollama/provider.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/codegate/providers/ollama/provider.py b/src/codegate/providers/ollama/provider.py
index 3953acc5..c1e30909 100644
--- a/src/codegate/providers/ollama/provider.py
+++ b/src/codegate/providers/ollama/provider.py
@@ -1,5 +1,5 @@
 import json
-from typing import List
+from typing import List, Optional
 
 import httpx
 import structlog
@@ -105,7 +105,7 @@ async def get_tags(request: Request):
         @self.router.post(f"/{self.provider_route_name}/api/show")
         async def show_model(
             request: Request,
-            authorization: str = Header(..., description="Bearer token"),
+            authorization: str | None = Header(None, description="Bearer token"),
         ):
             """
             route for /api/show that responds outside of the pipeline
@@ -140,7 +140,7 @@ async def show_model(
         @DetectClient()
         async def create_completion(
             request: Request,
-            authorization: str = Header(..., description="Bearer token"),
+            authorization: str | None = Header(None, description="Bearer token"),
         ):
             api_key = _api_key_from_optional_header_value(authorization)
             body = await request.body()
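
Below is a minimal usage sketch, not part of the patch series above, showing how the reworked Ollama completion handler could be driven once these commits are applied. The base URL, model name, and token values are placeholders, and the sketch assumes an Ollama server is reachable at that address.

# Illustrative sketch only -- not one of the commits above.
import asyncio

from litellm import ChatCompletionRequest

from codegate.providers.ollama.completion_handler import OllamaShim


async def main() -> None:
    # After patch 1 the shim no longer holds a client or base URL of its own.
    handler = OllamaShim()
    request = ChatCompletionRequest(
        model="some-model",  # placeholder model name
        messages=[{"role": "user", "content": "Hello"}],
        options={},
    )
    # base_url is now passed per call; when api_key is set, the handler sends
    # it to Ollama as an "Authorization: Bearer <key>" header.
    response = await handler.execute_completion(
        request,
        base_url="http://localhost:11434",
        api_key="example-token",  # placeholder token
        stream=False,
        is_fim_request=False,
    )
    print(response)


asyncio.run(main())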