diff --git a/Dockerfile b/Dockerfile
index 7d126919..306780b0 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -102,7 +102,7 @@ ENV PYTHONPATH=/app/src
 ENV CODEGATE_VLLM_URL=https://inference.codegate.ai
 ENV CODEGATE_OPENAI_URL=
 ENV CODEGATE_ANTHROPIC_URL=
-ENV CODEGATE_OLLAMA_URL=
+ENV CODEGATE_OLLAMA_URL=http://host.docker.internal:11434
 ENV CODEGATE_APP_LOG_LEVEL=WARNING
 ENV CODEGATE_LOG_FORMAT=TEXT
 
diff --git a/src/codegate/llm_utils/llmclient.py b/src/codegate/llm_utils/llmclient.py
index a8bcdd38..f2b301c4 100644
--- a/src/codegate/llm_utils/llmclient.py
+++ b/src/codegate/llm_utils/llmclient.py
@@ -2,7 +2,8 @@
 from typing import Any, Dict, Optional
 
 import structlog
-from litellm import acompletion, completion
+from litellm import acompletion
+from ollama import Client as OllamaClient
 
 from codegate.config import Config
 from codegate.inference import LlamaCppInferenceEngine
@@ -117,13 +118,14 @@ async def _complete_litellm(
 
         try:
             if provider == "ollama":
-                response = completion(
+                model = model.split("/")[-1]
+                response = OllamaClient(host=base_url).chat(
                     model=model,
                     messages=request["messages"],
-                    api_key=api_key,
-                    temperature=request["temperature"],
-                    base_url=base_url,
+                    format="json",
+                    options={"temperature": request["temperature"]},
                 )
+                content = response.message.content
             else:
                 response = await acompletion(
                     model=model,
@@ -133,7 +135,7 @@
                     base_url=base_url,
                     response_format=request["response_format"],
                 )
-            content = response["choices"][0]["message"]["content"]
+                content = response["choices"][0]["message"]["content"]
 
             # Clean up code blocks if present
             if content.startswith("```"):
@@ -142,5 +144,5 @@
             return json.loads(content)
 
         except Exception as e:
-            logger.error(f"LiteLLM completion failed {provider}/{model} ({content}): {e}")
+            logger.error(f"LiteLLM completion failed {model} ({content}): {e}")
             return {}
diff --git a/src/codegate/pipeline/extract_snippets/output.py b/src/codegate/pipeline/extract_snippets/output.py
index a9b67db8..cd7391f8 100644
--- a/src/codegate/pipeline/extract_snippets/output.py
+++ b/src/codegate/pipeline/extract_snippets/output.py
@@ -87,7 +87,7 @@ def _split_chunk_at_code_end(self, content: str) -> tuple[str, str]:
             if line.strip() == "```":
                 # Return content up to and including ```, and the rest
                 before = "\n".join(lines[: i + 1])
-                after = "\n".join(lines[i + 1:])
+                after = "\n".join(lines[i + 1 :])
                 return before, after
         return content, ""
 
diff --git a/src/codegate/providers/ollama/completion_handler.py b/src/codegate/providers/ollama/completion_handler.py
index d5ed5c59..49fbc103 100644
--- a/src/codegate/providers/ollama/completion_handler.py
+++ b/src/codegate/providers/ollama/completion_handler.py
@@ -5,7 +5,6 @@
 from litellm import ChatCompletionRequest
 from ollama import AsyncClient, ChatResponse, GenerateResponse
 
-from codegate.config import Config
 from codegate.providers.base import BaseCompletionHandler
 
 logger = structlog.get_logger("codegate")
@@ -27,13 +26,7 @@ async def ollama_stream_generator(
 
 class OllamaShim(BaseCompletionHandler):
 
-    def __init__(self):
-        config = Config.get_config()
-        if config is None:
-            provided_urls = {}
-        else:
-            provided_urls = config.provider_urls
-        base_url = provided_urls.get("ollama", "http://localhost:11434/")
+    def __init__(self, base_url):
         self.client = AsyncClient(host=base_url, timeout=300)
 
     async def execute_completion(
diff --git a/src/codegate/providers/ollama/provider.py b/src/codegate/providers/ollama/provider.py
index d36f4110..975d2295 100644
--- a/src/codegate/providers/ollama/provider.py
+++ b/src/codegate/providers/ollama/provider.py
@@ -3,6 +3,7 @@
 
 from fastapi import Request
 
+from codegate.config import Config
 from codegate.pipeline.base import SequentialPipelineProcessor
 from codegate.pipeline.output import OutputPipelineProcessor
 from codegate.providers.base import BaseProvider
@@ -18,7 +19,13 @@ def __init__(
         output_pipeline_processor: Optional[OutputPipelineProcessor] = None,
         fim_output_pipeline_processor: Optional[OutputPipelineProcessor] = None,
     ):
-        completion_handler = OllamaShim()
+        config = Config.get_config()
+        if config is None:
+            provided_urls = {}
+        else:
+            provided_urls = config.provider_urls
+        self.base_url = provided_urls.get("ollama", "http://localhost:11434/")
+        completion_handler = OllamaShim(self.base_url)
         super().__init__(
             OllamaInputNormalizer(),
             OllamaOutputNormalizer(),
@@ -46,6 +53,9 @@ def _setup_routes(self):
         async def create_completion(request: Request):
             body = await request.body()
             data = json.loads(body)
+            # `base_url` is used in the providers pipeline to do the packages lookup.
+            # Force it to be the one that comes in the configuration.
+            data["base_url"] = self.base_url
 
             is_fim_request = self._is_fim_request(request, data)
             stream = await self.complete(data, api_key=None, is_fim_request=is_fim_request)
diff --git a/tests/providers/ollama/test_ollama_completion_handler.py b/tests/providers/ollama/test_ollama_completion_handler.py
index 4e0908d0..cc32e915 100644
--- a/tests/providers/ollama/test_ollama_completion_handler.py
+++ b/tests/providers/ollama/test_ollama_completion_handler.py
@@ -19,7 +19,7 @@ def mock_client():
 
 @pytest.fixture
 def handler(mock_client):
-    ollama_shim = OllamaShim()
+    ollama_shim = OllamaShim("http://ollama:11434")
     ollama_shim.client = mock_client
     return ollama_shim
 
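
A minimal standalone sketch of the Ollama call path that _complete_litellm now uses, for trying the change outside CodeGate. The host and model name below are placeholders, not values taken from this diff; only the chat(format="json", options={"temperature": ...}) shape and the response.message.content access mirror the new code.

    from ollama import Client

    # Talk to Ollama directly, as the new ollama branch in LLMClient._complete_litellm
    # does, instead of routing through litellm's completion().
    client = Client(host="http://localhost:11434")  # placeholder host, not from the diff
    response = client.chat(
        model="llama3.2",  # placeholder model name, not from the diff
        messages=[{"role": "user", "content": "Return a JSON object with a single key 'ok'."}],
        format="json",  # ask Ollama for JSON output, matching the new code
        options={"temperature": 0},  # temperature goes via options, not a top-level kwarg
    )
    print(response.message.content)  # the new code passes this string to json.loads()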