Skip to content

Commit dd3db9f

Browse files
committed
Pass API key to ollama calls
Signed-off-by: Juan Antonio Osorio <[email protected]>
1 parent 56fad9f commit dd3db9f

File tree

1 file changed

+31
-5
lines changed

1 file changed

+31
-5
lines changed

src/codegate/providers/ollama/provider.py

+31-5
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import httpx
55
import structlog
6-
from fastapi import HTTPException, Request
6+
from fastapi import Header, HTTPException, Request
77

88
from codegate.clients.clients import ClientType
99
from codegate.clients.detector import DetectClient
@@ -103,21 +103,28 @@ async def get_tags(request: Request):
103103
return response.json()
104104

105105
@self.router.post(f"/{self.provider_route_name}/api/show")
106-
async def show_model(request: Request):
106+
async def show_model(
107+
request: Request,
108+
authorization: str = Header(..., description="Bearer token"),
109+
):
107110
"""
108111
route for /api/show that responds outside of the pipeline
109112
/api/show displays model is used to get the model information
110113
https://github.com/ollama/ollama/blob/main/docs/api.md#show-model-information
111114
"""
115+
api_key = _api_key_from_optional_header_value(authorization)
112116
body = await request.body()
113117
body_json = json.loads(body)
114118
if "name" not in body_json:
115119
raise HTTPException(status_code=400, detail="model is required in the request body")
116120
async with httpx.AsyncClient() as client:
121+
headers = {"Content-Type": "application/json; charset=utf-8"}
122+
if api_key:
123+
headers["Authorization"] = api_key
117124
response = await client.post(
118125
f"{self.base_url}/api/show",
119126
content=body,
120-
headers={"Content-Type": "application/json; charset=utf-8"},
127+
headers=headers,
121128
)
122129
return response.json()
123130

@@ -131,7 +138,11 @@ async def show_model(request: Request):
131138
@self.router.post(f"/{self.provider_route_name}/v1/chat/completions")
132139
@self.router.post(f"/{self.provider_route_name}/v1/generate")
133140
@DetectClient()
134-
async def create_completion(request: Request):
141+
async def create_completion(
142+
request: Request,
143+
authorization: str = Header(..., description="Bearer token"),
144+
):
145+
api_key = _api_key_from_optional_header_value(authorization)
135146
body = await request.body()
136147
data = json.loads(body)
137148

@@ -141,7 +152,22 @@ async def create_completion(request: Request):
141152
is_fim_request = FIMAnalyzer.is_fim_request(request.url.path, data)
142153
return await self.process_request(
143154
data,
144-
None,
155+
api_key,
145156
is_fim_request,
146157
request.state.detected_client,
147158
)
159+
160+
161+
def _api_key_from_optional_header_value(val: str) -> str:
162+
# The header is optional, so if we don't
163+
# have it, let's just return None
164+
if not val:
165+
return None
166+
167+
# The header value should be "Beaerer <key>"
168+
if not val.startswith("Bearer "):
169+
raise HTTPException(status_code=401, detail="Invalid authorization header")
170+
vals = val.split(" ")
171+
if len(vals) != 2:
172+
raise HTTPException(status_code=401, detail="Invalid authorization header")
173+
return vals[1]

0 commit comments

Comments
 (0)