Pass down base URL and API key to completion handler #1002

Merged · 5 commits · Feb 11, 2025
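At a glance: every completion handler's execute_completion now takes a base_url argument next to api_key, and the base provider forwards both from the incoming request data instead of fixing them at handler construction time. A minimal sketch of the new call shape, inferred from the diffs below (forward_completion, handler, and data are illustrative placeholders, not code from this PR):

```python
from typing import Optional

from litellm import ChatCompletionRequest


async def forward_completion(handler, data: dict, api_key: Optional[str]):
    # handler: any BaseCompletionHandler implementation (LiteLLmShim, OllamaShim, ...)
    request = ChatCompletionRequest(**data)
    # base_url travels with each call rather than being baked into the handler.
    return await handler.execute_completion(
        request,
        base_url=data.get("base_url"),
        api_key=api_key,
        stream=data.get("stream", False),
        is_fim_request=False,
    )
```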
9 changes: 8 additions & 1 deletion src/codegate/providers/anthropic/completion_handler.py
@@ -13,6 +13,7 @@ class AnthropicCompletion(LiteLLmShim):
async def execute_completion(
self,
request: ChatCompletionRequest,
base_url: Optional[str],
api_key: Optional[str],
stream: bool = False,
is_fim_request: bool = False,
@@ -31,4 +32,10 @@ async def execute_completion(
model_in_request = request["model"]
if not model_in_request.startswith("anthropic/"):
request["model"] = f"anthropic/{model_in_request}"
return await super().execute_completion(request, api_key, stream, is_fim_request)
return await super().execute_completion(
request=request,
api_key=api_key,
stream=stream,
is_fim_request=is_fim_request,
base_url=request.get("base_url"),
)
5 changes: 3 additions & 2 deletions src/codegate/providers/anthropic/provider.py
@@ -14,6 +14,8 @@
from codegate.providers.fim_analyzer import FIMAnalyzer
from codegate.providers.litellmshim import anthropic_stream_generator

logger = structlog.get_logger("codegate")


class AnthropicProvider(BaseProvider):
def __init__(
@@ -67,8 +69,7 @@ async def process_request(
#  check if we have an status code there
if hasattr(e, "status_code"):
# log the exception
logger = structlog.get_logger("codegate")
logger.error("Error in AnthropicProvider completion", error=str(e))
logger.exception("Error in AnthropicProvider completion")
raise HTTPException(status_code=e.status_code, detail=str(e)) # type: ignore
else:
# just continue raising the exception
1 change: 1 addition & 0 deletions src/codegate/providers/base.py
@@ -257,6 +257,7 @@ async def complete(
# based on the streaming flag
model_response = await self._completion_handler.execute_completion(
provider_request,
base_url=data.get("base_url"),
api_key=api_key,
stream=streaming,
is_fim_request=is_fim_request,
1 change: 1 addition & 0 deletions src/codegate/providers/completion/base.py
@@ -19,6 +19,7 @@ class BaseCompletionHandler(ABC):
async def execute_completion(
self,
request: ChatCompletionRequest,
base_url: Optional[str],
api_key: Optional[str],
stream: bool = False, # TODO: remove this param?
is_fim_request: bool = False,
2 changes: 2 additions & 0 deletions src/codegate/providers/litellmshim/litellmshim.py
@@ -41,6 +41,7 @@ def __init__(
async def execute_completion(
self,
request: ChatCompletionRequest,
base_url: Optional[str],
api_key: Optional[str],
stream: bool = False,
is_fim_request: bool = False,
@@ -49,6 +50,7 @@
Execute the completion request with LiteLLM's API
"""
request["api_key"] = api_key
request["base_url"] = base_url
if is_fim_request:
return await self._fim_completion_func(**request)
return await self._completion_func(**request)
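Since the shim merges both values into the request dict before invoking LiteLLM, the effective call is roughly the sketch below, assuming the configured completion function is litellm.acompletion (which accepts api_key and base_url keyword arguments); this is an illustration, not code from the PR:

```python
import litellm


async def call_via_litellm(request: dict, api_key: str | None, base_url: str | None):
    # Mirrors LiteLLmShim.execute_completion after this change: the credentials
    # and endpoint ride along with the rest of the request keyword arguments.
    return await litellm.acompletion(**{**request, "api_key": api_key, "base_url": base_url})
```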
1 change: 1 addition & 0 deletions src/codegate/providers/llamacpp/completion_handler.py
@@ -50,6 +50,7 @@ def __init__(self):
async def execute_completion(
self,
request: ChatCompletionRequest,
base_url: Optional[str],
api_key: Optional[str],
stream: bool = False,
is_fim_request: bool = False,
18 changes: 13 additions & 5 deletions src/codegate/providers/ollama/completion_handler.py
@@ -82,17 +82,25 @@ async def ollama_stream_generator(  # noqa: C901

class OllamaShim(BaseCompletionHandler):

def __init__(self, base_url):
self.client = AsyncClient(host=base_url, timeout=30)

async def execute_completion(
self,
request: ChatCompletionRequest,
base_url: Optional[str],
api_key: Optional[str],
stream: bool = False,
is_fim_request: bool = False,
) -> Union[ChatResponse, GenerateResponse]:
"""Stream response directly from Ollama API."""
if not base_url:
raise ValueError("base_url is required for Ollama")
Contributor: Could we default to localhost:11434?

Author: We should already get this default from the caller.

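The default the author refers to comes from the provider configuration rather than from the handler; a minimal sketch of that lookup, mirroring the OllamaProvider.__init__ change further down (resolve_ollama_base_url is an illustrative name, not part of this PR):

```python
DEFAULT_OLLAMA_URL = "http://localhost:11434/"


def resolve_ollama_base_url(provider_urls: dict | None) -> str:
    # Falls back to localhost when no URL is configured, so base_url should
    # already be populated by the time execute_completion runs.
    return (provider_urls or {}).get("ollama", DEFAULT_OLLAMA_URL)


# resolve_ollama_base_url(None) -> "http://localhost:11434/"
```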

# TODO: Add CodeGate user agent.
headers = dict()
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
Contributor: Does the bearer token get forwarded, or does it have some other meaning later on?

Author: It gets set so we can forward it.


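In other words, the key has no meaning inside codegate itself; it is only attached as a bearer header on the outgoing Ollama client so the upstream server (or a proxy in front of it) receives it. A rough sketch of that construction, mirroring the lines below (build_ollama_client is an illustrative helper, not part of this PR):

```python
from ollama import AsyncClient


def build_ollama_client(base_url: str, api_key: str | None) -> AsyncClient:
    headers: dict[str, str] = {}
    if api_key:
        # Forwarded verbatim; codegate does not validate or interpret the token.
        headers["Authorization"] = f"Bearer {api_key}"
    return AsyncClient(host=base_url, timeout=300, headers=headers)
```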
client = AsyncClient(host=base_url, timeout=300, headers=headers)

try:
if is_fim_request:
prompt = ""
@@ -103,7 +111,7 @@ async def execute_completion(
if not prompt:
raise ValueError("No user message found in FIM request")

response = await self.client.generate(
response = await client.generate(
model=request["model"],
prompt=prompt,
raw=request.get("raw", False),
@@ -112,7 +120,7 @@ async def execute_completion(
options=request["options"], # type: ignore
)
else:
response = await self.client.chat(
response = await client.chat(
model=request["model"],
messages=request["messages"],
stream=stream, # type: ignore
42 changes: 34 additions & 8 deletions src/codegate/providers/ollama/provider.py
@@ -1,9 +1,9 @@
import json
from typing import List
from typing import List, Optional

import httpx
import structlog
from fastapi import HTTPException, Request
from fastapi import Header, HTTPException, Request

from codegate.clients.clients import ClientType
from codegate.clients.detector import DetectClient
@@ -28,7 +28,7 @@ def __init__(
else:
provided_urls = config.provider_urls
self.base_url = provided_urls.get("ollama", "http://localhost:11434/")
completion_handler = OllamaShim(self.base_url)
completion_handler = OllamaShim()
super().__init__(
OllamaInputNormalizer(),
OllamaOutputNormalizer(),
@@ -68,7 +68,7 @@ async def process_request(
try:
stream = await self.complete(
data,
api_key=None,
api_key=api_key,
is_fim_request=is_fim_request,
client_type=client_type,
)
@@ -103,21 +103,28 @@ async def get_tags(request: Request):
return response.json()

@self.router.post(f"/{self.provider_route_name}/api/show")
async def show_model(request: Request):
async def show_model(
request: Request,
authorization: str | None = Header(None, description="Bearer token"),
):
"""
route for /api/show that responds outside of the pipeline
/api/show displays model is used to get the model information
https://github.com/ollama/ollama/blob/main/docs/api.md#show-model-information
"""
api_key = _api_key_from_optional_header_value(authorization)
body = await request.body()
body_json = json.loads(body)
if "name" not in body_json:
raise HTTPException(status_code=400, detail="model is required in the request body")
async with httpx.AsyncClient() as client:
headers = {"Content-Type": "application/json; charset=utf-8"}
if api_key:
headers["Authorization"] = api_key
response = await client.post(
f"{self.base_url}/api/show",
content=body,
headers={"Content-Type": "application/json; charset=utf-8"},
headers=headers,
)
return response.json()

@@ -131,7 +138,11 @@ async def show_model(request: Request):
@self.router.post(f"/{self.provider_route_name}/v1/chat/completions")
@self.router.post(f"/{self.provider_route_name}/v1/generate")
@DetectClient()
async def create_completion(request: Request):
async def create_completion(
request: Request,
authorization: str | None = Header(None, description="Bearer token"),
):
api_key = _api_key_from_optional_header_value(authorization)
body = await request.body()
data = json.loads(body)

@@ -141,7 +152,22 @@ async def create_completion(request: Request):
is_fim_request = FIMAnalyzer.is_fim_request(request.url.path, data)
return await self.process_request(
data,
None,
api_key,
is_fim_request,
request.state.detected_client,
)


def _api_key_from_optional_header_value(val: str) -> str:
# The header is optional, so if we don't
# have it, let's just return None
if not val:
return None

# The header value should be "Bearer <key>"
if not val.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Invalid authorization header")
vals = val.split(" ")
if len(vals) != 2:
raise HTTPException(status_code=401, detail="Invalid authorization header")
return vals[1]
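A quick illustration of how the new header helper is expected to behave (a hypothetical test sketch, not part of this PR's test changes):

```python
import pytest
from fastapi import HTTPException

from codegate.providers.ollama.provider import _api_key_from_optional_header_value


def test_api_key_from_optional_header_value():
    assert _api_key_from_optional_header_value("Bearer sk-123") == "sk-123"
    assert _api_key_from_optional_header_value(None) is None
    with pytest.raises(HTTPException):
        _api_key_from_optional_header_value("NotBearer sk-123")
```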
4 changes: 2 additions & 2 deletions tests/providers/litellmshim/test_litellmshim.py
@@ -56,7 +56,7 @@ async def test_complete_non_streaming():
}

# Execute
result = await litellm_shim.execute_completion(data, api_key=None)
result = await litellm_shim.execute_completion(data, base_url=None, api_key=None)

# Verify
assert result == mock_response
@@ -86,7 +86,7 @@ async def mock_stream() -> AsyncIterator[ModelResponse]:

# Execute
result_stream = await litellm_shim.execute_completion(
ChatCompletionRequest(**data), api_key=None
ChatCompletionRequest(**data), base_url=None, api_key=None
)

# Verify stream contents and adapter processing
31 changes: 22 additions & 9 deletions tests/providers/ollama/test_ollama_completion_handler.py
@@ -1,4 +1,4 @@
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from litellm import ChatCompletionRequest
@@ -19,8 +19,7 @@ def mock_client():

@pytest.fixture
def handler(mock_client):
ollama_shim = OllamaShim("http://ollama:11434")
ollama_shim.client = mock_client
ollama_shim = OllamaShim()
return ollama_shim


@@ -31,11 +30,18 @@ def chat_request():
)


@patch("codegate.providers.ollama.completion_handler.AsyncClient.generate", new_callable=AsyncMock)
@pytest.mark.asyncio
async def test_execute_completion_is_fim_request(handler, chat_request):
async def test_execute_completion_is_fim_request(mock_client_generate, handler, chat_request):
chat_request["messages"][0]["content"] = "FIM prompt"
await handler.execute_completion(chat_request, api_key=None, stream=False, is_fim_request=True)
handler.client.generate.assert_called_once_with(
await handler.execute_completion(
chat_request,
base_url="http://ollama:11434",
api_key=None,
stream=False,
is_fim_request=True,
)
mock_client_generate.assert_called_once_with(
model=chat_request["model"],
prompt="FIM prompt",
stream=False,
@@ -45,10 +51,17 @@ async def test_execute_completion_is_fim_request(handler, chat_request):
)


@patch("codegate.providers.ollama.completion_handler.AsyncClient.chat", new_callable=AsyncMock)
@pytest.mark.asyncio
async def test_execute_completion_not_is_fim_request(handler, chat_request):
await handler.execute_completion(chat_request, api_key=None, stream=False, is_fim_request=False)
handler.client.chat.assert_called_once_with(
async def test_execute_completion_not_is_fim_request(mock_client_chat, handler, chat_request):
await handler.execute_completion(
chat_request,
base_url="http://ollama:11434",
api_key=None,
stream=False,
is_fim_request=False,
)
mock_client_chat.assert_called_once_with(
model=chat_request["model"],
messages=chat_request["messages"],
stream=False,