
Commit 1179eb4

Merge pull request #152 from stacklok/add-ollama-to-pipeline
Use pipelines in Ollama provider
2 parents 06d8baa + 2d97e2c commit 1179eb4

8 files changed: +123 -150 lines changed

src/codegate/cli.py

Lines changed: 9 additions & 0 deletions
@@ -107,6 +107,12 @@ def show_prompts(prompts: Optional[Path]) -> None:
     default=None,
     help="Anthropic provider URL (default: https://api.anthropic.com/v1)",
 )
+@click.option(
+    "--ollama-url",
+    type=str,
+    default=None,
+    help="Ollama provider URL (default: http://localhost:11434/api)",
+)
 def serve(
     port: Optional[int],
     host: Optional[str],
@@ -117,6 +123,7 @@ def serve(
     vllm_url: Optional[str],
     openai_url: Optional[str],
     anthropic_url: Optional[str],
+    ollama_url: Optional[str],
 ) -> None:
     """Start the codegate server."""
     logger = None
@@ -129,6 +136,8 @@ def serve(
         cli_provider_urls["openai"] = openai_url
     if anthropic_url:
         cli_provider_urls["anthropic"] = anthropic_url
+    if ollama_url:
+        cli_provider_urls["ollama"] = ollama_url
 
     # Load configuration with priority resolution
     cfg = Config.load(
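
For illustration, a minimal sketch of what the new flag does at runtime, following the logic in the hunk above (the URL value and surrounding code are assumptions for the example, not part of the commit):

    cli_provider_urls = {}
    ollama_url = "http://my-ollama-host:11434/api"  # value that --ollama-url would supply
    if ollama_url:
        cli_provider_urls["ollama"] = ollama_url
    # cli_provider_urls is then handed to Config.load(), whose priority resolution
    # lets the CLI value override the default declared in config.py.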

src/codegate/config.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@
     "openai": "https://api.openai.com/v1",
     "anthropic": "https://api.anthropic.com/v1",
     "vllm": "http://localhost:8000",  # Base URL without /v1 path
-    "ollama": "http://localhost:11434",  # Default Ollama server URL
+    "ollama": "http://localhost:11434/api",  # Default Ollama server URL
 }
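
The trailing /api on the new default matters because the Ollama completion handler (added below) appends the route name directly to the configured base URL; a small illustration of the resulting endpoints:

    base_url = "http://localhost:11434/api"  # new default from this hunk
    print(f"{base_url}/chat")      # -> http://localhost:11434/api/chat
    print(f"{base_url}/generate")  # -> http://localhost:11434/api/generate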
src/codegate/providers/base.py

Lines changed: 2 additions & 1 deletion
@@ -95,7 +95,8 @@ def _is_fim_request_url(self, request: Request) -> bool:
         if request_path.endswith("/chat/completions"):
             return False
 
-        if request_path.endswith("/completions"):
+        # /completions is for OpenAI standard. /api/generate is for ollama.
+        if request_path.endswith("/completions") or request_path.endswith("/api/generate"):
             return True
 
         return False
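
As a standalone sketch of the routing rule this hunk encodes (the function name and sample paths here are illustrative; the real check lives in _is_fim_request_url):

    def looks_like_fim(request_path: str) -> bool:
        # Chat endpoints are never treated as FIM
        if request_path.endswith("/chat/completions"):
            return False
        # OpenAI-style completions and Ollama's native generate endpoint are FIM candidates
        return request_path.endswith("/completions") or request_path.endswith("/api/generate")

    assert looks_like_fim("/ollama/api/generate") is True
    assert looks_like_fim("/ollama/chat/completions") is False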

src/codegate/providers/completion/base.py

Lines changed: 2 additions & 1 deletion
@@ -1,3 +1,4 @@
+import inspect
 from abc import ABC, abstractmethod
 from collections.abc import Iterator
 from typing import Any, AsyncIterator, Optional, Union
@@ -35,6 +36,6 @@ def create_response(self, response: Any) -> Union[JSONResponse, StreamingResponse]:
         """
         Create a FastAPI response from the completion response.
         """
-        if isinstance(response, Iterator):
+        if isinstance(response, Iterator) or inspect.isasyncgen(response):
             return self._create_streaming_response(response)
         return self._create_json_response(response)
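
The extra inspect.isasyncgen check is needed because async generators, such as the one returned by the new Ollama handler, are not instances of collections.abc.Iterator. A quick stdlib-only demonstration:

    import inspect
    from collections.abc import Iterator

    def sync_gen():
        yield "chunk"

    async def async_gen():
        yield "chunk"

    assert isinstance(sync_gen(), Iterator)       # matched by the old check
    assert not isinstance(async_gen(), Iterator)  # missed by the old check
    assert inspect.isasyncgen(async_gen())        # matched by the new check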

src/codegate/providers/ollama/adapter.py

Lines changed: 16 additions & 7 deletions
@@ -15,25 +15,34 @@ def normalize(self, data: Dict) -> ChatCompletionRequest:
         """
         # Make a copy of the data to avoid modifying the original
         normalized_data = data.copy()
+        normalized_data["options"] = data.get("options", {})
+
+        # Add any context or system prompt if provided
+        if "context" in data:
+            normalized_data["context"] = data["context"]
+        if "system" in data:
+            normalized_data["system"] = data["system"]
 
         # Format the model name
         if "model" in normalized_data:
-            normalized_data["model"] = normalized_data["model"].strip()
+            normalized_data["model"] = data["model"].strip()
 
         # Convert messages format if needed
-        if "messages" in normalized_data:
-            messages = normalized_data["messages"]
+        if "messages" in data:
+            messages = data["messages"]
             converted_messages = []
             for msg in messages:
-                if isinstance(msg.get("content"), list):
+                role = msg.get("role", "")
+                content = msg.get("content", "")
+                new_msg = {"role": role, "content": content}
+                if isinstance(content, list):
                     # Convert list format to string
                     content_parts = []
                     for part in msg["content"]:
                         if part.get("type") == "text":
                             content_parts.append(part["text"])
-                    msg = msg.copy()
-                    msg["content"] = " ".join(content_parts)
-                converted_messages.append(msg)
+                    new_msg["content"] = " ".join(content_parts)
+                converted_messages.append(new_msg)
             normalized_data["messages"] = converted_messages
 
         # Ensure the base_url ends with /api if provided
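
To make the normalization concrete, here is an illustrative input and the result implied by the logic above (the model name and message text are invented for the example):

    data = {
        "model": " codellama:7b-instruct ",
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Write a"},
                    {"type": "text", "text": "hello world function"},
                ],
            }
        ],
    }
    # OllamaInputNormalizer().normalize(data) would produce, per the diff:
    # {
    #     "model": "codellama:7b-instruct",   # whitespace stripped
    #     "options": {},                      # defaulted when absent
    #     "messages": [{"role": "user", "content": "Write a hello world function"}],
    # }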

src/codegate/providers/ollama/completion_handler.py

Lines changed: 74 additions & 0 deletions

@@ -0,0 +1,74 @@
+import asyncio
+import json
+from typing import Any, AsyncIterator, Optional
+
+import httpx
+import structlog
+from fastapi.responses import JSONResponse, StreamingResponse
+from litellm import ChatCompletionRequest
+
+from codegate.providers.base import BaseCompletionHandler
+
+logger = structlog.get_logger("codegate")
+
+
+async def get_async_ollama_response(client, request_url, data):
+    try:
+        async with client.stream("POST", request_url, json=data, timeout=30.0) as response:
+            response.raise_for_status()
+            async for line in response.aiter_lines():
+                if line.strip():
+                    try:
+                        # Parse the response to ensure it's valid JSON
+                        response_data = json.loads(line)
+                        # Add newline to ensure proper streaming
+                        yield line.encode("utf-8") + b"\n"
+                        # If this is the final response, break
+                        if response_data.get("done", False):
+                            break
+                        # Small delay to prevent overwhelming the client
+                        await asyncio.sleep(0.01)
+                    except json.JSONDecodeError:
+                        yield json.dumps({"error": "Invalid JSON response"}).encode("utf-8") + b"\n"
+                        break
+                    except Exception as e:
+                        yield json.dumps({"error": str(e)}).encode("utf-8") + b"\n"
+                        break
+    except Exception as e:
+        yield json.dumps({"error": f"Stream error: {str(e)}"}).encode("utf-8") + b"\n"
+
+
+class OllamaCompletionHandler(BaseCompletionHandler):
+    def __init__(self):
+        self.client = httpx.AsyncClient(timeout=30.0)
+        # Depends on whether the request is Chat or FIM
+        self._url_mapping = {False: "/chat", True: "/generate"}
+
+    async def execute_completion(
+        self,
+        request: ChatCompletionRequest,
+        api_key: Optional[str],
+        stream: bool = False,
+        is_fim_request: bool = False,
+    ) -> AsyncIterator:
+        """Stream response directly from Ollama API."""
+        request_path = self._url_mapping[is_fim_request]
+        request_url = f"{request['base_url']}{request_path}"
+        return get_async_ollama_response(self.client, request_url, request)
+
+    def _create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse:
+        """
+        Create a streaming response from a stream generator. The StreamingResponse
+        is the format that FastAPI expects for streaming responses.
+        """
+        return StreamingResponse(
+            stream,
+            media_type="application/x-ndjson",
+            headers={
+                "Cache-Control": "no-cache",
+                "Connection": "keep-alive",
+            },
+        )
+
+    def _create_json_response(self, response: Any) -> JSONResponse:
+        raise NotImplementedError("JSON Response in Ollama not implemented yet.")
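
A hedged sketch of driving get_async_ollama_response on its own, outside FastAPI (the model, prompt, and locally running Ollama server are assumptions for the example):

    import asyncio
    import httpx

    async def main():
        client = httpx.AsyncClient(timeout=30.0)
        data = {"model": "codellama:7b-instruct", "prompt": "def add(a, b):", "stream": True}
        async for chunk in get_async_ollama_response(
            client, "http://localhost:11434/api/generate", data
        ):
            print(chunk.decode("utf-8"), end="")
        await client.aclose()

    asyncio.run(main())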
Lines changed: 19 additions & 123 deletions
@@ -1,42 +1,12 @@
-import asyncio
 import json
 from typing import Optional
 
-import httpx
-from fastapi import Header, HTTPException, Request
-from fastapi.responses import StreamingResponse
+from fastapi import Request
 
 from codegate.config import Config
 from codegate.providers.base import BaseProvider, SequentialPipelineProcessor
-from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator
 from codegate.providers.ollama.adapter import OllamaInputNormalizer, OllamaOutputNormalizer
-
-
-async def stream_ollama_response(client: httpx.AsyncClient, url: str, data: dict):
-    """Stream response directly from Ollama API."""
-    try:
-        async with client.stream("POST", url, json=data, timeout=30.0) as response:
-            response.raise_for_status()
-            async for line in response.aiter_lines():
-                if line.strip():
-                    try:
-                        # Parse the response to ensure it's valid JSON
-                        response_data = json.loads(line)
-                        # Add newline to ensure proper streaming
-                        yield line.encode("utf-8") + b"\n"
-                        # If this is the final response, break
-                        if response_data.get("done", False):
-                            break
-                        # Small delay to prevent overwhelming the client
-                        await asyncio.sleep(0.01)
-                    except json.JSONDecodeError:
-                        yield json.dumps({"error": "Invalid JSON response"}).encode("utf-8") + b"\n"
-                        break
-                    except Exception as e:
-                        yield json.dumps({"error": str(e)}).encode("utf-8") + b"\n"
-                        break
-    except Exception as e:
-        yield json.dumps({"error": f"Stream error: {str(e)}"}).encode("utf-8") + b"\n"
+from codegate.providers.ollama.completion_handler import OllamaCompletionHandler
 
 
 class OllamaProvider(BaseProvider):
@@ -45,15 +15,21 @@ def __init__(
         pipeline_processor: Optional[SequentialPipelineProcessor] = None,
         fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None,
     ):
-        completion_handler = LiteLLmShim(stream_generator=sse_stream_generator)
+        completion_handler = OllamaCompletionHandler()
         super().__init__(
            OllamaInputNormalizer(),
            OllamaOutputNormalizer(),
            completion_handler,
            pipeline_processor,
            fim_pipeline_processor,
        )
-        self.client = httpx.AsyncClient(timeout=30.0)
+        # Get the Ollama base URL
+        config = Config.get_config()
+        if config is None:
+            provided_urls = {}
+        else:
+            provided_urls = config.provider_urls
+        self.base_url = provided_urls.get("ollama", "http://localhost:11434/api")
 
     @property
     def provider_route_name(self) -> str:
@@ -66,96 +42,16 @@ def _setup_routes(self):
 
         # Native Ollama API routes
         @self.router.post(f"/{self.provider_route_name}/api/chat")
-        async def ollama_chat(
-            request: Request,
-            authorization: str = Header(..., description="Bearer token"),
-        ):
-            if not authorization.startswith("Bearer "):
-                raise HTTPException(status_code=401, detail="Invalid authorization header")
-
-            _api_key = authorization.split(" ")[1]
-            body = await request.body()
-            data = json.loads(body)
-
-            # Get the Ollama base URL
-            config = Config.get_config()
-            base_url = config.provider_urls.get("ollama", "http://localhost:11434")
-
-            # Convert chat format to Ollama generate format
-            messages = []
-            for msg in data.get("messages", []):
-                role = msg.get("role", "")
-                content = msg.get("content", "")
-                if isinstance(content, list):
-                    # Handle list-based content format
-                    content = " ".join(
-                        part["text"] for part in content if part.get("type") == "text"
-                    )
-                messages.append({"role": role, "content": content})
-
-            ollama_data = {
-                "model": data.get("model", "").strip(),
-                "messages": messages,
-                "stream": True,
-                "options": data.get("options", {}),
-            }
-
-            # Stream response directly from Ollama
-            return StreamingResponse(
-                stream_ollama_response(self.client, f"{base_url}/api/chat", ollama_data),
-                media_type="application/x-ndjson",
-                headers={
-                    "Cache-Control": "no-cache",
-                    "Connection": "keep-alive",
-                },
-            )
-
         @self.router.post(f"/{self.provider_route_name}/api/generate")
-        async def ollama_generate(
-            request: Request,
-            authorization: str = Header(..., description="Bearer token"),
-        ):
-            if not authorization.startswith("Bearer "):
-                raise HTTPException(status_code=401, detail="Invalid authorization header")
-
-            _api_key = authorization.split(" ")[1]
-            body = await request.body()
-            data = json.loads(body)
-
-            # Get the Ollama base URL
-            config = Config.get_config()
-            base_url = config.provider_urls.get("ollama", "http://localhost:11434")
-
-            # Prepare generate request
-            ollama_data = {
-                "model": data.get("model", "").strip(),
-                "prompt": data.get("prompt", ""),
-                "stream": True,
-                "options": data.get("options", {}),
-            }
-
-            # Add any context or system prompt if provided
-            if "context" in data:
-                ollama_data["context"] = data["context"]
-            if "system" in data:
-                ollama_data["system"] = data["system"]
-
-            # Stream response directly from Ollama
-            return StreamingResponse(
-                stream_ollama_response(self.client, f"{base_url}/api/generate", ollama_data),
-                media_type="application/x-ndjson",
-                headers={
-                    "Cache-Control": "no-cache",
-                    "Connection": "keep-alive",
-                },
-            )
-
         # OpenAI-compatible routes for backward compatibility
         @self.router.post(f"/{self.provider_route_name}/chat/completions")
         @self.router.post(f"/{self.provider_route_name}/completions")
-        async def create_completion(
-            request: Request,
-            authorization: str = Header(..., description="Bearer token"),
-        ):
-            # Redirect to native Ollama endpoint
-            return await ollama_chat(request, authorization)
+        async def create_completion(request: Request):
+            body = await request.body()
+            data = json.loads(body)
+            if "base_url" not in data or not data["base_url"]:
+                data["base_url"] = self.base_url
+
+            is_fim_request = self._is_fim_request(request, data)
+            stream = await self.complete(data, None, is_fim_request=is_fim_request)
+            return self._completion_handler.create_response(stream)
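
With the provider now routed through the pipeline, the OpenAI-compatible route can be exercised roughly as follows (the codegate host and port are assumptions, not taken from this diff; the path matches the route registered above):

    import httpx

    with httpx.stream(
        "POST",
        "http://localhost:8989/ollama/chat/completions",
        json={
            "model": "codellama:7b-instruct",
            "messages": [{"role": "user", "content": "Hello"}],
        },
        timeout=30.0,
    ) as resp:
        for line in resp.iter_lines():
            print(line)  # NDJSON chunks relayed from Ollama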

tests/providers/ollama/test_ollama_provider.py

Lines changed: 0 additions & 17 deletions
@@ -74,7 +74,6 @@ def test_ollama_chat(mock_config, test_client):
     assert sent_data["model"] == "codellama:7b-instruct"
     assert sent_data["messages"] == data["messages"]
     assert sent_data["options"] == data["options"]
-    assert sent_data["stream"] is True
 
 
 @patch("codegate.config.Config.get_config", return_value=MockConfig())
@@ -120,7 +119,6 @@ def test_ollama_generate(mock_config, test_client):
     assert sent_data["options"] == data["options"]
     assert sent_data["context"] == data["context"]
     assert sent_data["system"] == data["system"]
-    assert sent_data["stream"] is True
 
 
 @patch("codegate.config.Config.get_config", return_value=MockConfig())
@@ -140,18 +138,3 @@ def test_ollama_error_handling(mock_config, test_client):
     content = response.content.decode().strip()
     assert "error" in content
     assert "Model not found" in content
-
-
-def test_ollama_auth_required(test_client):
-    """Test authentication requirement."""
-    data = {"model": "codellama:7b-instruct"}
-
-    # Test without auth header
-    response = test_client.post("/ollama/api/generate", json=data)
-    assert response.status_code == 422
-
-    # Test with invalid auth header
-    response = test_client.post(
-        "/ollama/api/generate", json=data, headers={"Authorization": "Invalid"}
-    )
-    assert response.status_code == 401
