This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit a7bebae

Merge pull request #241 from stacklok/ollama-litellm
Use ollama python client for completion
2 parents 1d21c1a + a49514c commit a7bebae

13 files changed, +672 -690 lines changed

poetry.lock

Lines changed: 443 additions & 426 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 2 additions & 1 deletion
@@ -6,7 +6,7 @@ readme = "README.md"
 authors = []
 
 [tool.poetry.dependencies]
-python = ">=3.11"
+python = ">=3.11,<4.0"
 click = ">=8.1.0"
 PyYAML = ">=6.0.1"
 fastapi = ">=0.115.5"
@@ -19,6 +19,7 @@ cryptography = "^44.0.0"
 sqlalchemy = "^2.0.28"
 greenlet = "^3.0.3"
 aiosqlite = "^0.20.0"
+ollama = ">=0.4.4"
 
 [tool.poetry.group.dev.dependencies]
 pytest = ">=7.4.0"

src/codegate/codegate_logging.py

Lines changed: 1 addition & 0 deletions
@@ -75,6 +75,7 @@ def setup_logging(
         structlog.processors.CallsiteParameterAdder(
             [
                 structlog.processors.CallsiteParameter.MODULE,
+                structlog.processors.CallsiteParameter.PATHNAME,
             ]
         ),
     ]
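Side note: with CallsiteParameter.PATHNAME added, every log event carries the path of the emitting file next to the module name. A minimal standalone sketch of the same processor setup, with an illustrative JSON renderer that is not part of this commit:

import structlog

structlog.configure(
    processors=[
        structlog.processors.CallsiteParameterAdder(
            [
                structlog.processors.CallsiteParameter.MODULE,
                structlog.processors.CallsiteParameter.PATHNAME,
            ]
        ),
        structlog.processors.JSONRenderer(),
    ]
)

logger = structlog.get_logger()
logger.info("server_started")  # the rendered event now includes "module" and "pathname" keys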

src/codegate/config.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@
     "openai": "https://api.openai.com/v1",
     "anthropic": "https://api.anthropic.com/v1",
     "vllm": "http://localhost:8000", # Base URL without /v1 path
-    "ollama": "http://localhost:11434/api", # Default Ollama server URL
+    "ollama": "http://localhost:11434", # Default Ollama server URL
 }
 
 
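The /api suffix is dropped because the ollama Python client added in this commit expects a bare host URL and appends its own /api/... endpoints. A minimal sketch of pointing the client at this default URL (the model name is a placeholder):

from ollama import Client

# The client takes the bare server URL; it adds /api/chat, /api/generate, etc. itself.
client = Client(host="http://localhost:11434")
response = client.chat(
    model="llama2",  # placeholder model name
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.message.content)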
src/codegate/llm_utils/llmclient.py

Lines changed: 21 additions & 12 deletions
@@ -2,7 +2,7 @@
 from typing import Any, Dict, Optional
 
 import structlog
-from litellm import acompletion
+from litellm import acompletion, completion
 
 from codegate.config import Config
 from codegate.inference import LlamaCppInferenceEngine
@@ -112,18 +112,27 @@ async def _complete_litellm(
             if not base_url.endswith("/v1"):
                 base_url = f"{base_url}/v1"
         else:
-            model = f"{provider}/{model}"
+            if not model.startswith(f"{provider}/"):
+                model = f"{provider}/{model}"
 
         try:
-            response = await acompletion(
-                model=model,
-                messages=request["messages"],
-                api_key=api_key,
-                temperature=request["temperature"],
-                base_url=base_url,
-                response_format=request["response_format"],
-            )
-
+            if provider == "ollama":
+                response = completion(
+                    model=model,
+                    messages=request["messages"],
+                    api_key=api_key,
+                    temperature=request["temperature"],
+                    base_url=base_url,
+                )
+            else:
+                response = await acompletion(
+                    model=model,
+                    messages=request["messages"],
+                    api_key=api_key,
+                    temperature=request["temperature"],
+                    base_url=base_url,
+                    response_format=request["response_format"],
+                )
             content = response["choices"][0]["message"]["content"]
 
             # Clean up code blocks if present
@@ -133,5 +142,5 @@ async def _complete_litellm(
             return json.loads(content)
 
         except Exception as e:
-            logger.error(f"LiteLLM completion failed: {e}")
+            logger.error(f"LiteLLM completion failed {provider}/{model} ({content}): {e}")
             return {}
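For context, litellm dispatches on the provider prefix in the model name, which is what the new startswith guard preserves, and the Ollama branch now goes through litellm's synchronous completion call. A rough standalone sketch of that call path (model name and prompt are placeholders):

from litellm import completion

# litellm routes on the "ollama/" prefix; the guard above avoids
# prefixing a model name that is already prefixed.
response = completion(
    model="ollama/llama2",  # placeholder model
    messages=[{"role": "user", "content": "Return a JSON object"}],
    base_url="http://localhost:11434",
    temperature=0,
)
print(response["choices"][0]["message"]["content"])

Because completion is synchronous, the surrounding coroutine waits for the whole request before resuming.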

src/codegate/pipeline/output.py

Lines changed: 0 additions & 1 deletion
@@ -134,7 +134,6 @@ async def process_stream(
 
                 # Yield all processed chunks
                 for c in current_chunks:
-                    logger.debug(f"Yielding chunk {c}")
                     self._store_chunk_content(c)
                     self._context.buffer.clear()
                     yield c

src/codegate/providers/base.py

Lines changed: 4 additions & 6 deletions
@@ -214,17 +214,18 @@ async def complete(
         provider-specific format
         """
         normalized_request = self._input_normalizer.normalize(data)
-        streaming = data.get("stream", False)
+        streaming = normalized_request.get("stream", False)
         prompt_db = await self._db_recorder.record_request(
             normalized_request, is_fim_request, self.provider_route_name
         )
 
+        prompt_db_id = prompt_db.id if prompt_db is not None else None
         input_pipeline_result = await self._run_input_pipeline(
             normalized_request,
             api_key,
             data.get("base_url"),
             is_fim_request,
-            prompt_id=prompt_db.id,
+            prompt_id=prompt_db_id,
         )
         if input_pipeline_result.response:
             await self._db_recorder.record_alerts(input_pipeline_result.context.alerts_raised)
@@ -239,7 +240,6 @@ async def complete(
         # Execute the completion and translate the response
         # This gives us either a single response or a stream of responses
         # based on the streaming flag
-        logger.info(f"Executing completion with {provider_request}")
         model_response = await self._completion_handler.execute_completion(
             provider_request, api_key=api_key, stream=streaming, is_fim_request=is_fim_request
         )
@@ -259,9 +259,7 @@
 
         model_response = self._db_recorder.record_output_stream(prompt_db, model_response)
         pipeline_output_stream = await self._run_output_stream_pipeline(
-            input_pipeline_result.context,
-            model_response,
-            is_fim_request=is_fim_request,
+            input_pipeline_result.context, model_response, is_fim_request=is_fim_request
         )
         return self._cleanup_after_streaming(pipeline_output_stream, input_pipeline_result.context)
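A likely reason for reading the streaming flag from the normalized request rather than the raw payload is that normalizers can change it; the Ollama input normalizer later in this commit forces stream=True. A toy illustration, assuming the normalizer is importable from the adapter module (import path assumed, not shown in this diff):

# Illustration only; the import path is an assumption.
from codegate.providers.ollama.adapter import OllamaInputNormalizer

raw_request = {"model": "llama2", "prompt": "Hello"}         # client sent no "stream" key
normalized = OllamaInputNormalizer().normalize(raw_request)  # normalizer forces stream=True

assert raw_request.get("stream", False) is False  # the old lookup would see False
assert normalized.get("stream", False) is True    # the new lookup sees the forced value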

Lines changed: 100 additions & 25 deletions
@@ -1,37 +1,33 @@
-from typing import Any, Dict
+import uuid
+from datetime import datetime, timezone
+from typing import Any, AsyncIterator, Dict, Union
 
-from litellm import ChatCompletionRequest
+from litellm import ChatCompletionRequest, ModelResponse
+from litellm.types.utils import Delta, StreamingChoices
+from ollama import ChatResponse, Message
 
 from codegate.providers.normalizer.base import ModelInputNormalizer, ModelOutputNormalizer
 
 
 class OllamaInputNormalizer(ModelInputNormalizer):
-    def __init__(self):
-        super().__init__()
 
     def normalize(self, data: Dict) -> ChatCompletionRequest:
         """
         Normalize the input data to the format expected by Ollama.
         """
         # Make a copy of the data to avoid modifying the original and normalize the message content
         normalized_data = self._normalize_content_messages(data)
+        normalized_data["model"] = data.get("model", "").strip()
         normalized_data["options"] = data.get("options", {})
 
-        # Add any context or system prompt if provided
-        if "context" in data:
-            normalized_data["context"] = data["context"]
-        if "system" in data:
-            normalized_data["system"] = data["system"]
+        if "prompt" in normalized_data:
+            normalized_data["messages"] = [
+                {"content": normalized_data.pop("prompt"), "role": "user"}
+            ]
 
-        # Format the model name
-        if "model" in normalized_data:
-            normalized_data["model"] = data["model"].strip()
-
-        # Ensure the base_url ends with /api if provided
-        if "base_url" in normalized_data:
-            base_url = normalized_data["base_url"].rstrip("/")
-            if not base_url.endswith("/api"):
-                normalized_data["base_url"] = f"{base_url}/api"
+        # In Ollama force the stream to be True. Continue is not setting this parameter and
+        # most of our functionality is for streaming completions.
+        normalized_data["stream"] = True
 
         return ChatCompletionRequest(**normalized_data)
 
@@ -42,18 +38,98 @@ def denormalize(self, data: ChatCompletionRequest) -> Dict:
         return data
 
 
+class OLlamaToModel(AsyncIterator[ModelResponse]):
+    def __init__(self, ollama_response: AsyncIterator[ChatResponse]):
+        self.ollama_response = ollama_response
+        self._aiter = ollama_response.__aiter__()
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        try:
+            chunk = await self._aiter.__anext__()
+            if not isinstance(chunk, ChatResponse):
+                return chunk
+
+            finish_reason = None
+            role = "assistant"
+
+            # Convert the datetime object to a timestamp in seconds
+            datetime_obj = datetime.fromisoformat(chunk.created_at)
+            timestamp_seconds = int(datetime_obj.timestamp())
+
+            if chunk.done:
+                finish_reason = "stop"
+                role = None
+
+            model_response = ModelResponse(
+                id=f"ollama-chat-{str(uuid.uuid4())}",
+                created=timestamp_seconds,
+                model=chunk.model,
+                object="chat.completion.chunk",
+                choices=[
+                    StreamingChoices(
+                        finish_reason=finish_reason,
+                        index=0,
+                        delta=Delta(content=chunk.message.content, role=role),
+                        logprobs=None,
+                    )
+                ],
+            )
+            return model_response
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+
+
+class ModelToOllama(AsyncIterator[ChatResponse]):
+
+    def __init__(self, normalized_reply: AsyncIterator[ModelResponse]):
+        self.normalized_reply = normalized_reply
+        self._aiter = normalized_reply.__aiter__()
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self) -> Union[ChatResponse]:
+        try:
+            chunk = await self._aiter.__anext__()
+            if not isinstance(chunk, ModelResponse):
+                return chunk
+            # Convert the timestamp to a datetime object
+            datetime_obj = datetime.fromtimestamp(chunk.created, tz=timezone.utc)
+            created_at = datetime_obj.isoformat()
+
+            message = chunk.choices[0].delta.content
+            done = False
+            if chunk.choices[0].finish_reason == "stop":
+                done = True
+                message = ""
+
+            # Convert the model response to an Ollama response
+            ollama_response = ChatResponse(
+                model=chunk.model,
+                created_at=created_at,
+                done=done,
+                message=Message(content=message, role="assistant"),
+            )
+            return ollama_response
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+
+
 class OllamaOutputNormalizer(ModelOutputNormalizer):
     def __init__(self):
         super().__init__()
 
     def normalize_streaming(
         self,
-        model_reply: Any,
-    ) -> Any:
+        model_reply: AsyncIterator[ChatResponse],
+    ) -> AsyncIterator[ModelResponse]:
         """
         Pass through Ollama response
         """
-        return model_reply
+        return OLlamaToModel(model_reply)
 
     def normalize(self, model_reply: Any) -> Any:
         """
@@ -68,10 +144,9 @@ def denormalize(self, normalized_reply: Any) -> Any:
         return normalized_reply
 
     def denormalize_streaming(
-        self,
-        normalized_reply: Any,
-    ) -> Any:
+        self, normalized_reply: AsyncIterator[ModelResponse]
+    ) -> AsyncIterator[ChatResponse]:
         """
         Pass through Ollama response
         """
-        return normalized_reply
+        return ModelToOllama(normalized_reply)
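To see the adapters in action, the OLlamaToModel wrapper above can be placed around a streaming response from the ollama client to obtain OpenAI-style chunks. A rough usage sketch, assuming a local Ollama server, a placeholder model name, and an assumed import path for the adapter module:

import asyncio

from ollama import AsyncClient

# Import path assumed; OLlamaToModel is the adapter defined in this diff.
from codegate.providers.ollama.adapter import OLlamaToModel


async def main() -> None:
    client = AsyncClient(host="http://localhost:11434")
    ollama_stream = await client.chat(
        model="llama2",  # placeholder model
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
    )
    # Each Ollama ChatResponse chunk is converted into a litellm ModelResponse chunk.
    async for chunk in OLlamaToModel(ollama_stream):
        delta = chunk.choices[0].delta.content
        if delta:
            print(delta, end="", flush=True)


asyncio.run(main())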
