This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Make sure we can speak with ollama localhosted from container #275

Merged
Merged 1 commit on Dec 11, 2024
2 changes: 1 addition & 1 deletion Dockerfile
@@ -102,7 +102,7 @@ ENV PYTHONPATH=/app/src
ENV CODEGATE_VLLM_URL=https://inference.codegate.ai
ENV CODEGATE_OPENAI_URL=
ENV CODEGATE_ANTHROPIC_URL=
-ENV CODEGATE_OLLAMA_URL=
+ENV CODEGATE_OLLAMA_URL=http://host.docker.internal:11434
ENV CODEGATE_APP_LOG_LEVEL=WARNING
ENV CODEGATE_LOG_FORMAT=TEXT

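The new default points the container at an Ollama server running on the Docker host. host.docker.internal resolves out of the box on Docker Desktop (macOS and Windows); on Linux you typically need to start the container with --add-host=host.docker.internal:host-gateway. A minimal sketch, assuming CODEGATE_OLLAMA_URL is set as in the Dockerfile above and that the host's Ollama server answers GET /api/tags (its standard model-list endpoint), to verify the container can actually reach it:

# Hedged connectivity check, run from inside the container.
import json
import os
import urllib.request

base_url = os.environ.get("CODEGATE_OLLAMA_URL", "http://host.docker.internal:11434")
with urllib.request.urlopen(f"{base_url.rstrip('/')}/api/tags", timeout=5) as resp:
    models = json.load(resp).get("models", [])
print(f"Reached Ollama at {base_url}; {len(models)} local model(s) available.")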
16 changes: 9 additions & 7 deletions src/codegate/llm_utils/llmclient.py
@@ -2,7 +2,8 @@
from typing import Any, Dict, Optional

import structlog
-from litellm import acompletion, completion
+from litellm import acompletion
+from ollama import Client as OllamaClient

from codegate.config import Config
from codegate.inference import LlamaCppInferenceEngine
@@ -117,13 +118,14 @@ async def _complete_litellm(

        try:
            if provider == "ollama":
-                response = completion(
+                model = model.split("/")[-1]
+                response = OllamaClient(host=base_url).chat(
                    model=model,
                    messages=request["messages"],
-                    api_key=api_key,
-                    temperature=request["temperature"],
-                    base_url=base_url,
+                    format="json",
+                    options={"temperature": request["temperature"]},
                )
+                content = response.message.content
            else:
                response = await acompletion(
                    model=model,
@@ -133,7 +135,7 @@
                    base_url=base_url,
                    response_format=request["response_format"],
                )
-            content = response["choices"][0]["message"]["content"]
+                content = response["choices"][0]["message"]["content"]

            # Clean up code blocks if present
            if content.startswith("```"):
@@ -142,5 +144,5 @@
            return json.loads(content)

        except Exception as e:
-            logger.error(f"LiteLLM completion failed {provider}/{model} ({content}): {e}")
+            logger.error(f"LiteLLM completion failed {model} ({content}): {e}")
            return {}
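The ollama branch above now talks to the server through the ollama Python client instead of litellm, requesting JSON output via format="json" and passing the temperature through options. A minimal standalone sketch of that call pattern; the model name and prompt are illustrative, and response.message.content assumes a recent ollama-python client that returns a ChatResponse object:

import json

from ollama import Client

client = Client(host="http://host.docker.internal:11434")
response = client.chat(
    model="llama2",  # illustrative; the real code derives the name from the request
    messages=[{"role": "user", "content": "Return a JSON object with a 'greeting' key."}],
    format="json",  # ask the server to emit valid JSON
    options={"temperature": 0},
)
content = response.message.content
print(json.loads(content))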
2 changes: 1 addition & 1 deletion src/codegate/pipeline/extract_snippets/output.py
@@ -87,7 +87,7 @@ def _split_chunk_at_code_end(self, content: str) -> tuple[str, str]:
            if line.strip() == "```":
                # Return content up to and including ```, and the rest
                before = "\n".join(lines[: i + 1])
-                after = "\n".join(lines[i + 1:])
+                after = "\n".join(lines[i + 1 :])
                return before, after
        return content, ""
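For context, _split_chunk_at_code_end cuts a streamed chunk at the first closing ``` fence; only the slice spacing changed here (black-style formatting). A standalone sketch of the same logic, assuming lines comes from content.split("\n") as in the surrounding method:

def split_chunk_at_code_end(content: str) -> tuple[str, str]:
    # Mirrors the method above: everything up to and including the closing
    # ``` fence goes into `before`, the remainder into `after`.
    lines = content.split("\n")
    for i, line in enumerate(lines):
        if line.strip() == "```":
            return "\n".join(lines[: i + 1]), "\n".join(lines[i + 1 :])
    return content, ""

before, after = split_chunk_at_code_end("```python\nprint('hi')\n```\ntrailing text")
# before -> "```python\nprint('hi')\n```", after -> "trailing text"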

9 changes: 1 addition & 8 deletions src/codegate/providers/ollama/completion_handler.py
@@ -5,7 +5,6 @@
from litellm import ChatCompletionRequest
from ollama import AsyncClient, ChatResponse, GenerateResponse

-from codegate.config import Config
from codegate.providers.base import BaseCompletionHandler

logger = structlog.get_logger("codegate")
@@ -27,13 +26,7 @@ async def ollama_stream_generator(

class OllamaShim(BaseCompletionHandler):

-    def __init__(self):
-        config = Config.get_config()
-        if config is None:
-            provided_urls = {}
-        else:
-            provided_urls = config.provider_urls
-        base_url = provided_urls.get("ollama", "http://localhost:11434/")
+    def __init__(self, base_url):
        self.client = AsyncClient(host=base_url, timeout=300)

    async def execute_completion(
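OllamaShim no longer looks up the URL from Config itself; the provider now injects it (see provider.py below). A minimal sketch of the new shape, assuming the ollama AsyncClient chat API; the class name and chat_once helper are simplified stand-ins, not the actual handler body (which is truncated above):

import asyncio

from ollama import AsyncClient


class ShimSketch:
    def __init__(self, base_url: str):
        # base_url is injected by the caller instead of being resolved here.
        self.client = AsyncClient(host=base_url, timeout=300)

    async def chat_once(self, model: str, prompt: str):
        return await self.client.chat(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            stream=False,
        )


async def main():
    shim = ShimSketch("http://host.docker.internal:11434")
    response = await shim.chat_once("llama2", "Say hello")  # model name illustrative
    print(response.message.content)


asyncio.run(main())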
12 changes: 11 additions & 1 deletion src/codegate/providers/ollama/provider.py
@@ -3,6 +3,7 @@

from fastapi import Request

+from codegate.config import Config
from codegate.pipeline.base import SequentialPipelineProcessor
from codegate.pipeline.output import OutputPipelineProcessor
from codegate.providers.base import BaseProvider
@@ -18,7 +19,13 @@ def __init__(
        output_pipeline_processor: Optional[OutputPipelineProcessor] = None,
        fim_output_pipeline_processor: Optional[OutputPipelineProcessor] = None,
    ):
-        completion_handler = OllamaShim()
+        config = Config.get_config()
+        if config is None:
+            provided_urls = {}
+        else:
+            provided_urls = config.provider_urls
+        self.base_url = provided_urls.get("ollama", "http://localhost:11434/")
+        completion_handler = OllamaShim(self.base_url)
        super().__init__(
            OllamaInputNormalizer(),
            OllamaOutputNormalizer(),
@@ -46,6 +53,9 @@ def _setup_routes(self):
        async def create_completion(request: Request):
            body = await request.body()
            data = json.loads(body)
+            # `base_url` is used in the providers pipeline to do the packages lookup.
+            # Force it to be the one that comes in the configuration.
+            data["base_url"] = self.base_url

            is_fim_request = self._is_fim_request(request, data)
            stream = await self.complete(data, api_key=None, is_fim_request=is_fim_request)
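With the provider resolving base_url once from configuration, every incoming request body gets that value stamped on it before the pipeline runs, so package lookups always target the configured Ollama instance rather than whatever URL the client sent. A small illustration of the effect; the values are examples only:

import json

configured_base_url = "http://host.docker.internal:11434"  # from CODEGATE_OLLAMA_URL

body = b'{"model": "llama2", "base_url": "http://client-supplied:11434"}'
data = json.loads(body)
data["base_url"] = configured_base_url  # configuration wins over the request body
print(data["base_url"])  # -> http://host.docker.internal:11434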
2 changes: 1 addition & 1 deletion tests/providers/ollama/test_ollama_completion_handler.py
@@ -19,7 +19,7 @@ def mock_client():

@pytest.fixture
def handler(mock_client):
-    ollama_shim = OllamaShim()
+    ollama_shim = OllamaShim("http://ollama:11434")
    ollama_shim.client = mock_client
    return ollama_shim
