Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Specify the provider in the URL to properly route traffic #96

Merged
merged 1 commit into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/codegate/providers/anthropic/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,18 @@ def __init__(self):
completion_handler = LiteLLmShim(adapter)
super().__init__(completion_handler)

@property
def provider_route_name(self) -> str:
return "anthropic"

Comment on lines +16 to +19

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any advantage over just hardcoding the name ? not enough to stop merging, just curious to learn

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just thought it was better to define provider_route_name as an abstracmethod so that we suddenly not forget to specify it for other providers.

def _setup_routes(self):
"""
Sets up the /messages route for the provider as expected by the Anthropic
API. Extracts the API key from the "x-api-key" header and passes it to the
completion handler.
"""

@self.router.post("/messages")
@self.router.post(f"/{self.provider_route_name}/messages")
async def create_message(
request: Request,
x_api_key: str = Header(None),
Expand Down
5 changes: 5 additions & 0 deletions src/codegate/providers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ def __init__(self, completion_handler: BaseCompletionHandler):
def _setup_routes(self) -> None:
pass

@property
@abstractmethod
def provider_route_name(self) -> str:
pass

async def complete(self, data: Dict, api_key: str) -> AsyncIterator[Any]:
return await self._completion_handler.complete(data, api_key)

Expand Down
2 changes: 1 addition & 1 deletion src/codegate/providers/llamacpp/completion_handler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, AsyncIterator, Dict

from litellm import ModelResponse
from fastapi.responses import StreamingResponse
from litellm import ModelResponse, acompletion

from codegate.providers.base import BaseCompletionHandler
from codegate.providers.llamacpp.adapter import BaseAdapter
Expand Down
8 changes: 6 additions & 2 deletions src/codegate/providers/llamacpp/provider.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json

from fastapi import Header, HTTPException, Request
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused imports

from fastapi import Request

from codegate.providers.base import BaseProvider
from codegate.providers.llamacpp.completion_handler import LlamaCppCompletionHandler
Expand All @@ -13,13 +13,17 @@ def __init__(self):
completion_handler = LlamaCppCompletionHandler(adapter)
super().__init__(completion_handler)

@property
def provider_route_name(self) -> str:
return "llamacpp"

def _setup_routes(self):
"""
Sets up the /chat route for the provider as expected by the
Llama API. Extracts the API key from the "Authorization" header and
passes it to the completion handler.
"""
@self.router.post("/completion")
@self.router.post(f"/{self.provider_route_name}/completion")
async def create_completion(
request: Request,
):
Expand Down
6 changes: 5 additions & 1 deletion src/codegate/providers/openai/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,18 @@ def __init__(self):
completion_handler = LiteLLmShim(adapter)
super().__init__(completion_handler)

@property
def provider_route_name(self) -> str:
return "openai"

def _setup_routes(self):
"""
Sets up the /chat/completions route for the provider as expected by the
OpenAI API. Extracts the API key from the "Authorization" header and
passes it to the completion handler.
"""

@self.router.post("/chat/completions")
@self.router.post(f"/{self.provider_route_name}/chat/completions")
async def create_completion(
request: Request,
authorization: str = Header(..., description="Bearer token"),
Expand Down
9 changes: 7 additions & 2 deletions tests/providers/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,13 @@ def create_streaming_response(


class MockProvider(BaseProvider):

@property
def provider_route_name(self) -> str:
return 'mock_provider'

def _setup_routes(self) -> None:
@self.router.get("/test")
@self.router.get(f"/{self.provider_route_name}/test")
def test_route():
return {"message": "test"}

Expand Down Expand Up @@ -61,5 +66,5 @@ def test_provider_routes_added(app, registry, mock_completion_handler):
provider = MockProvider(mock_completion_handler)
registry.add_provider("test", provider)

routes = [route for route in app.routes if route.path == "/test"]
routes = [route for route in app.routes if route.path == "/mock_provider/test"]
assert len(routes) == 1