
Commit 6938347

Pass down base URL and API key to completion handler (#1002)
* Pass down base URL and API key to completion handler

  In order to make Ollama work with API keys, I needed to change the
  completion handler to take a base URL and also leverage a given API key
  (if available).

  Signed-off-by: Juan Antonio Osorio <[email protected]>
  Co-authored-by: Alejandro Ponce de Leon <[email protected]>

* fix unit tests and formatting

* fix integration tests

* Pass API key to ollama calls

  Signed-off-by: Juan Antonio Osorio <[email protected]>

* Handle optional authorization header in ollama endpoints.

  Signed-off-by: Juan Antonio Osorio <[email protected]>

---------

Signed-off-by: Juan Antonio Osorio <[email protected]>
Co-authored-by: Alejandro Ponce de Leon <[email protected]>
1 parent 46a5fd0 commit 6938347

10 files changed, +87 -27 lines changed
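At a glance, the change lets credentials flow from the incoming request all the way to the upstream provider. A minimal sketch of what a client can now do against the proxied Ollama routes; the CodeGate address, route prefix, and model name below are assumptions for illustration, not values from this commit:

import httpx

# Hypothetical CodeGate address; the "ollama" prefix matches provider_route_name.
CODEGATE = "http://localhost:8989"

resp = httpx.post(
    f"{CODEGATE}/ollama/api/show",
    json={"name": "codellama:7b"},  # placeholder model name
    headers={"Authorization": "Bearer my-ollama-api-key"},  # now forwarded upstream
)
print(resp.json())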

src/codegate/providers/anthropic/completion_handler.py (+8 -1)

@@ -13,6 +13,7 @@ class AnthropicCompletion(LiteLLmShim):
     async def execute_completion(
         self,
         request: ChatCompletionRequest,
+        base_url: Optional[str],
         api_key: Optional[str],
         stream: bool = False,
         is_fim_request: bool = False,
@@ -31,4 +32,10 @@ async def execute_completion(
         model_in_request = request["model"]
         if not model_in_request.startswith("anthropic/"):
             request["model"] = f"anthropic/{model_in_request}"
-        return await super().execute_completion(request, api_key, stream, is_fim_request)
+        return await super().execute_completion(
+            request=request,
+            api_key=api_key,
+            stream=stream,
+            is_fim_request=is_fim_request,
+            base_url=request.get("base_url"),
+        )
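Note the wrinkle here: the override forwards base_url=request.get("base_url") rather than its own base_url parameter, so the value is read back out of the request dict instead of the argument passed in. A sketch of the resulting call shape, with placeholder model and key (not code from this commit):

from litellm import ChatCompletionRequest

async def demo(handler):  # handler: an AnthropicCompletion instance
    request = ChatCompletionRequest(
        model="claude-3-5-sonnet",  # placeholder; the shim adds the "anthropic/" prefix
        messages=[{"role": "user", "content": "hello"}],
    )
    return await handler.execute_completion(
        request,
        base_url=None,  # effectively ignored; the override re-reads request["base_url"]
        api_key="sk-ant-placeholder",
        stream=False,
        is_fim_request=False,
    )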

src/codegate/providers/anthropic/provider.py (+3 -2)

@@ -14,6 +14,8 @@
 from codegate.providers.fim_analyzer import FIMAnalyzer
 from codegate.providers.litellmshim import anthropic_stream_generator

+logger = structlog.get_logger("codegate")
+

 class AnthropicProvider(BaseProvider):
     def __init__(
@@ -67,8 +69,7 @@ async def process_request(
             # check if we have a status code there
             if hasattr(e, "status_code"):
                 # log the exception
-                logger = structlog.get_logger("codegate")
-                logger.error("Error in AnthropicProvider completion", error=str(e))
+                logger.exception("Error in AnthropicProvider completion")
                 raise HTTPException(status_code=e.status_code, detail=str(e))  # type: ignore
             else:
                 # just continue raising the exception
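Hoisting the logger to module scope avoids re-creating it on every error path, and logger.exception() records the active traceback without manually stringifying the error. A minimal sketch of the pattern:

import structlog

logger = structlog.get_logger("codegate")

def handle():
    try:
        raise ValueError("boom")
    except ValueError:
        # .exception() logs at error level and attaches the current traceback
        logger.exception("Error in AnthropicProvider completion")

handle()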

src/codegate/providers/base.py (+1 -0)

@@ -257,6 +257,7 @@ async def complete(
         # based on the streaming flag
         model_response = await self._completion_handler.execute_completion(
             provider_request,
+            base_url=data.get("base_url"),
             api_key=api_key,
             stream=streaming,
             is_fim_request=is_fim_request,

src/codegate/providers/completion/base.py (+1 -0)

@@ -19,6 +19,7 @@ class BaseCompletionHandler(ABC):
     async def execute_completion(
         self,
         request: ChatCompletionRequest,
+        base_url: Optional[str],
        api_key: Optional[str],
         stream: bool = False,  # TODO: remove this param?
         is_fim_request: bool = False,
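Since this is the abstract contract, every concrete handler in the tree (LiteLLM, llama.cpp, Ollama) must now accept the extra base_url parameter. A conforming stub might look like this; it is illustrative only and stands in for a real BaseCompletionHandler subclass, ignoring the class's other abstract methods for brevity:

from typing import Any, Optional

from litellm import ChatCompletionRequest

class EchoCompletionHandler:  # stand-in for BaseCompletionHandler(ABC)
    async def execute_completion(
        self,
        request: ChatCompletionRequest,
        base_url: Optional[str],
        api_key: Optional[str],
        stream: bool = False,
        is_fim_request: bool = False,
    ) -> Any:
        # No upstream call; just show what a handler now receives.
        return {"base_url": base_url, "has_key": api_key is not None}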

src/codegate/providers/litellmshim/litellmshim.py (+2 -0)

@@ -41,6 +41,7 @@ def __init__(
     async def execute_completion(
         self,
         request: ChatCompletionRequest,
+        base_url: Optional[str],
         api_key: Optional[str],
         stream: bool = False,
         is_fim_request: bool = False,
@@ -49,6 +50,7 @@ async def execute_completion(
         Execute the completion request with LiteLLM's API
         """
         request["api_key"] = api_key
+        request["base_url"] = base_url
         if is_fim_request:
             return await self._fim_completion_func(**request)
         return await self._completion_func(**request)
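The shim injects both values into the request dict and then splats it into the completion function, relying on litellm accepting api_key and base_url as keyword arguments. A rough standalone equivalent of the non-FIM path, with placeholder model, key, and URL:

import asyncio

import litellm

async def main():
    response = await litellm.acompletion(
        model="anthropic/claude-3-5-sonnet",  # placeholder model
        messages=[{"role": "user", "content": "hello"}],
        api_key="sk-placeholder",
        base_url=None,  # None falls back to the provider's default endpoint
    )
    print(response.choices[0].message.content)

asyncio.run(main())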

src/codegate/providers/llamacpp/completion_handler.py (+1 -0)

@@ -50,6 +50,7 @@ def __init__(self):
     async def execute_completion(
         self,
         request: ChatCompletionRequest,
+        base_url: Optional[str],
         api_key: Optional[str],
         stream: bool = False,
         is_fim_request: bool = False,

src/codegate/providers/ollama/completion_handler.py (+13 -5)

@@ -82,17 +82,25 @@ async def ollama_stream_generator(  # noqa: C901


 class OllamaShim(BaseCompletionHandler):

-    def __init__(self, base_url):
-        self.client = AsyncClient(host=base_url, timeout=30)
-
     async def execute_completion(
         self,
         request: ChatCompletionRequest,
+        base_url: Optional[str],
         api_key: Optional[str],
         stream: bool = False,
         is_fim_request: bool = False,
     ) -> Union[ChatResponse, GenerateResponse]:
         """Stream response directly from Ollama API."""
+        if not base_url:
+            raise ValueError("base_url is required for Ollama")
+
+        # TODO: Add CodeGate user agent.
+        headers = dict()
+        if api_key:
+            headers["Authorization"] = f"Bearer {api_key}"
+
+        client = AsyncClient(host=base_url, timeout=300, headers=headers)
+
         try:
             if is_fim_request:
                 prompt = ""
@@ -103,7 +111,7 @@ async def execute_completion(
                 if not prompt:
                     raise ValueError("No user message found in FIM request")

-                response = await self.client.generate(
+                response = await client.generate(
                     model=request["model"],
                     prompt=prompt,
                     raw=request.get("raw", False),
@@ -112,7 +120,7 @@ async def execute_completion(
                     options=request["options"],  # type: ignore
                 )
             else:
-                response = await self.client.chat(
+                response = await client.chat(
                     model=request["model"],
                     messages=request["messages"],
                     stream=stream,  # type: ignore
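Instead of one client built at construction time, the shim now builds a fresh AsyncClient per request, so the base URL and Authorization header can vary per call (the timeout also grows from 30 s to 300 s in the process). The ollama client passes extra keyword arguments such as headers through to its underlying httpx client. A standalone sketch with placeholder host, key, and model:

import asyncio

from ollama import AsyncClient

async def main():
    headers = {"Authorization": "Bearer my-ollama-key"}  # only set when a key exists
    client = AsyncClient(host="http://localhost:11434", timeout=300, headers=headers)
    response = await client.chat(
        model="codellama:7b",  # placeholder model
        messages=[{"role": "user", "content": "hello"}],
        stream=False,
    )
    print(response)

asyncio.run(main())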

src/codegate/providers/ollama/provider.py (+34 -8)

@@ -1,9 +1,9 @@
 import json
-from typing import List
+from typing import List, Optional

 import httpx
 import structlog
-from fastapi import HTTPException, Request
+from fastapi import Header, HTTPException, Request

 from codegate.clients.clients import ClientType
 from codegate.clients.detector import DetectClient
@@ -28,7 +28,7 @@ def __init__(
         else:
             provided_urls = config.provider_urls
         self.base_url = provided_urls.get("ollama", "http://localhost:11434/")
-        completion_handler = OllamaShim(self.base_url)
+        completion_handler = OllamaShim()
         super().__init__(
             OllamaInputNormalizer(),
             OllamaOutputNormalizer(),
@@ -68,7 +68,7 @@ async def process_request(
         try:
             stream = await self.complete(
                 data,
-                api_key=None,
+                api_key=api_key,
                 is_fim_request=is_fim_request,
                 client_type=client_type,
             )
@@ -103,21 +103,28 @@ async def get_tags(request: Request):
             return response.json()

         @self.router.post(f"/{self.provider_route_name}/api/show")
-        async def show_model(request: Request):
+        async def show_model(
+            request: Request,
+            authorization: str | None = Header(None, description="Bearer token"),
+        ):
             """
             Route for /api/show that responds outside of the pipeline.
             /api/show is used to get the model information.
             https://github.com/ollama/ollama/blob/main/docs/api.md#show-model-information
             """
+            api_key = _api_key_from_optional_header_value(authorization)
             body = await request.body()
             body_json = json.loads(body)
             if "name" not in body_json:
                 raise HTTPException(status_code=400, detail="model is required in the request body")
             async with httpx.AsyncClient() as client:
+                headers = {"Content-Type": "application/json; charset=utf-8"}
+                if api_key:
+                    headers["Authorization"] = api_key
                 response = await client.post(
                     f"{self.base_url}/api/show",
                     content=body,
-                    headers={"Content-Type": "application/json; charset=utf-8"},
+                    headers=headers,
                 )
                 return response.json()

@@ -131,7 +138,11 @@ async def show_model(request: Request):
         @self.router.post(f"/{self.provider_route_name}/v1/chat/completions")
         @self.router.post(f"/{self.provider_route_name}/v1/generate")
         @DetectClient()
-        async def create_completion(request: Request):
+        async def create_completion(
+            request: Request,
+            authorization: str | None = Header(None, description="Bearer token"),
+        ):
+            api_key = _api_key_from_optional_header_value(authorization)
             body = await request.body()
             data = json.loads(body)

@@ -141,7 +152,22 @@ async def create_completion(request: Request):
             is_fim_request = FIMAnalyzer.is_fim_request(request.url.path, data)
             return await self.process_request(
                 data,
-                None,
+                api_key,
                 is_fim_request,
                 request.state.detected_client,
             )
+
+
+def _api_key_from_optional_header_value(val: Optional[str]) -> Optional[str]:
+    # The header is optional; if we don't have it, just return None
+    if not val:
+        return None
+
+    # The header value should be "Bearer <key>"
+    if not val.startswith("Bearer "):
+        raise HTTPException(status_code=401, detail="Invalid authorization header")
+    vals = val.split(" ")
+    if len(vals) != 2:
+        raise HTTPException(status_code=401, detail="Invalid authorization header")
+    return vals[1]
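One asymmetry worth noticing: show_model forwards the bare extracted key as the Authorization value, while execute_completion re-prefixes it with "Bearer ". The helper itself behaves as follows; this is an illustrative check, not part of the commit's test suite:

from fastapi import HTTPException

from codegate.providers.ollama.provider import _api_key_from_optional_header_value

assert _api_key_from_optional_header_value(None) is None
assert _api_key_from_optional_header_value("Bearer abc123") == "abc123"

for bad in ("Token abc123", "Bearer too many parts"):
    try:
        _api_key_from_optional_header_value(bad)
        raise AssertionError("expected a 401")
    except HTTPException as exc:
        assert exc.status_code == 401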

tests/providers/litellmshim/test_litellmshim.py (+2 -2)

@@ -56,7 +56,7 @@ async def test_complete_non_streaming():
     }

     # Execute
-    result = await litellm_shim.execute_completion(data, api_key=None)
+    result = await litellm_shim.execute_completion(data, base_url=None, api_key=None)

     # Verify
     assert result == mock_response
@@ -86,7 +86,7 @@ async def mock_stream() -> AsyncIterator[ModelResponse]:

     # Execute
     result_stream = await litellm_shim.execute_completion(
-        ChatCompletionRequest(**data), api_key=None
+        ChatCompletionRequest(**data), base_url=None, api_key=None
     )

     # Verify stream contents and adapter processing

tests/providers/ollama/test_ollama_completion_handler.py (+22 -9)

@@ -1,4 +1,4 @@
-from unittest.mock import AsyncMock, MagicMock
+from unittest.mock import AsyncMock, MagicMock, patch

 import pytest
 from litellm import ChatCompletionRequest
@@ -19,8 +19,7 @@ def mock_client():

 @pytest.fixture
 def handler(mock_client):
-    ollama_shim = OllamaShim("http://ollama:11434")
-    ollama_shim.client = mock_client
+    ollama_shim = OllamaShim()
     return ollama_shim


@@ -31,11 +30,18 @@ def chat_request():
     )


+@patch("codegate.providers.ollama.completion_handler.AsyncClient.generate", new_callable=AsyncMock)
 @pytest.mark.asyncio
-async def test_execute_completion_is_fim_request(handler, chat_request):
+async def test_execute_completion_is_fim_request(mock_client_generate, handler, chat_request):
     chat_request["messages"][0]["content"] = "FIM prompt"
-    await handler.execute_completion(chat_request, api_key=None, stream=False, is_fim_request=True)
-    handler.client.generate.assert_called_once_with(
+    await handler.execute_completion(
+        chat_request,
+        base_url="http://ollama:11434",
+        api_key=None,
+        stream=False,
+        is_fim_request=True,
+    )
+    mock_client_generate.assert_called_once_with(
         model=chat_request["model"],
         prompt="FIM prompt",
         stream=False,
@@ -45,10 +51,17 @@ async def test_execute_completion_is_fim_request(handler, chat_request):
     )


+@patch("codegate.providers.ollama.completion_handler.AsyncClient.chat", new_callable=AsyncMock)
 @pytest.mark.asyncio
-async def test_execute_completion_not_is_fim_request(handler, chat_request):
-    await handler.execute_completion(chat_request, api_key=None, stream=False, is_fim_request=False)
-    handler.client.chat.assert_called_once_with(
+async def test_execute_completion_not_is_fim_request(mock_client_chat, handler, chat_request):
+    await handler.execute_completion(
+        chat_request,
+        base_url="http://ollama:11434",
+        api_key=None,
+        stream=False,
+        is_fim_request=False,
+    )
+    mock_client_chat.assert_called_once_with(
         model=chat_request["model"],
         messages=chat_request["messages"],
         stream=False,