This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 4d47146

Implement provider interface and OpenAI and Anthropic providers (#66)
Implements an interface for building CodeGate providers. For now we use LiteLLM, but we might switch in the future, so the LiteLLM calls are abstracted away in the `litellmshim` module in two classes:

- `BaseAdapter`, which provides a common interface for reusing LiteLLM adapters.
- `LiteLLmShim`, which actually calls into LiteLLM's completion API, invoking the adapter before the call to convert the request into LiteLLM's format, and again afterwards to convert the response back.

Using those interfaces, implements an OpenAI and an Anthropic provider. With this patch, codegate can pass requests through to OpenAI and Anthropic. Next, we'll build a pipeline interface to modify the inputs and outputs.
1 parent dbb814a commit 4d47146
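
How the pieces fit together, as a minimal sketch: providers expose an APIRouter that the FastAPI app mounts. The /openai and /anthropic prefixes are an assumption here; the actual wiring lives in init_app (server.py), which this page does not show.

# Assumed wiring: each provider builds its own APIRouter in _setup_routes(),
# and the server only mounts it. Prefixes are illustrative, not from the diff.
from fastapi import FastAPI

from codegate.providers import AnthropicProvider, OpenAIProvider

app = FastAPI()
app.include_router(OpenAIProvider().get_routes(), prefix="/openai")
app.include_router(AnthropicProvider().get_routes(), prefix="/anthropic")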

File tree

17 files changed: +422 -8 lines

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ dev = [
     "bandit>=1.7.10",
     "build>=1.0.0",
     "wheel>=0.40.0",
+    "litellm>=1.52.11",
 ]

 [build-system]

src/codegate/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -4,8 +4,10 @@

 try:
     __version__ = metadata.version("codegate")
+    __description__ = metadata.metadata("codegate")["Summary"]
 except metadata.PackageNotFoundError:  # pragma: no cover
     __version__ = "unknown"
+    __description__ = "codegate"

 from .config import Config, ConfigurationError
 from .logging import setup_logging

src/codegate/cli.py

Lines changed: 15 additions & 8 deletions
@@ -9,6 +9,7 @@

 from .config import Config, ConfigurationError, LogFormat, LogLevel
 from .logging import setup_logging
+from .server import init_app


 def validate_port(ctx: click.Context, param: click.Parameter, value: int) -> int:
@@ -65,7 +66,6 @@ def serve(
     config: Optional[Path],
 ) -> None:
     """Start the codegate server."""
-
     try:
         # Load configuration with priority resolution
         cfg = Config.load(
@@ -79,11 +79,6 @@ def serve(
         setup_logging(cfg.log_level, cfg.log_format)
         logger = logging.getLogger(__name__)

-        logger.info("This is an info message")
-        logger.debug("This is a debug message")
-        logger.error("This is an error message")
-        logger.warning("This is a warning message")
-
         logger.info(
             "Starting server",
             extra={
@@ -94,13 +89,25 @@ def serve(
             },
         )

-        # TODO: Jakub Implement actual server logic here
-        logger.info("Server started successfully")
+        app = init_app()
+
+        import uvicorn
+
+        uvicorn.run(
+            app,
+            host=cfg.host,
+            port=cfg.port,
+            log_level=cfg.log_level.value.lower(),
+            log_config=None,  # Default logging configuration
+        )

+    except KeyboardInterrupt:
+        logger.info("Shutting down server")
     except ConfigurationError as e:
         click.echo(f"Configuration error: {e}", err=True)
         sys.exit(1)
     except Exception as e:
+        logger.exception("Unexpected error occurred")
         click.echo(f"Error: {e}", err=True)
         sys.exit(1)

src/codegate/providers/__init__.py

Lines changed: 11 additions & 0 deletions
from .anthropic.provider import AnthropicProvider
from .base import BaseProvider
from .openai.provider import OpenAIProvider
from .registry import ProviderRegistry

__all__ = [
    "BaseProvider",
    "ProviderRegistry",
    "OpenAIProvider",
    "AnthropicProvider",
]

src/codegate/providers/anthropic/__init__.py

Whitespace-only changes.

src/codegate/providers/anthropic/adapter.py

Lines changed: 46 additions & 0 deletions
from typing import Any, Dict, Optional

from litellm import AdapterCompletionStreamWrapper, ChatCompletionRequest, ModelResponse
from litellm.adapters.anthropic_adapter import (
    AnthropicAdapter as LitellmAnthropicAdapter,
)
from litellm.types.llms.anthropic import AnthropicResponse

from ..base import StreamGenerator
from ..litellmshim import anthropic_stream_generator
from ..litellmshim.litellmshim import BaseAdapter


class AnthropicAdapter(BaseAdapter):
    """
    LiteLLM's adapter class interface is used to translate between the Anthropic
    data format and the underlying model. LiteLLM's own AnthropicAdapter contains
    the actual implementation of the interface methods; we just forward the calls
    to it.
    """
    def __init__(self, stream_generator: StreamGenerator = anthropic_stream_generator):
        self.litellm_anthropic_adapter = LitellmAnthropicAdapter()
        super().__init__(stream_generator)

    def translate_completion_input_params(
        self,
        completion_request: Dict,
    ) -> Optional[ChatCompletionRequest]:
        return self.litellm_anthropic_adapter.translate_completion_input_params(
            completion_request
        )

    def translate_completion_output_params(
        self, response: ModelResponse
    ) -> Optional[AnthropicResponse]:
        return self.litellm_anthropic_adapter.translate_completion_output_params(
            response
        )

    def translate_completion_output_params_streaming(
        self, completion_stream: Any
    ) -> AdapterCompletionStreamWrapper | None:
        return (
            self.litellm_anthropic_adapter.translate_completion_output_params_streaming(
                completion_stream
            )
        )

src/codegate/providers/anthropic/provider.py

Lines changed: 34 additions & 0 deletions
import json

from fastapi import Header, HTTPException, Request

from ..base import BaseProvider
from ..litellmshim.litellmshim import LiteLLmShim
from .adapter import AnthropicAdapter


class AnthropicProvider(BaseProvider):
    def __init__(self):
        adapter = AnthropicAdapter()
        completion_handler = LiteLLmShim(adapter)
        super().__init__(completion_handler)

    def _setup_routes(self):
        """
        Sets up the /messages route for the provider as expected by the Anthropic
        API. Extracts the API key from the "x-api-key" header and passes it to the
        completion handler.
        """
        @self.router.post("/messages")
        async def create_message(
            request: Request,
            x_api_key: str = Header(None),
        ):
            if x_api_key == "":
                raise HTTPException(status_code=401, detail="No API key provided")

            body = await request.body()
            data = json.loads(body)

            stream = await self.complete(data, x_api_key)
            return self._completion_handler.create_streaming_response(stream)
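
For reference, a minimal client call against the /messages route above. The host, port, and bare /messages path are assumptions; they depend on the server wiring and configuration, not on this diff.

# Hedged sketch of a client request to the route defined above; port 8989
# and the unprefixed /messages path are assumptions about the server wiring.
import requests

resp = requests.post(
    "http://localhost:8989/messages",
    headers={"x-api-key": "my-anthropic-key"},  # placeholder key
    json={
        "model": "claude-3-5-sonnet-20241022",
        "max_tokens": 1024,
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
    },
    stream=True,
)
for line in resp.iter_lines():
    if line:
        print(line.decode())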

src/codegate/providers/base.py

Lines changed: 46 additions & 0 deletions
from abc import ABC, abstractmethod
from typing import Any, AsyncIterator, Callable, Dict

from fastapi import APIRouter
from fastapi.responses import StreamingResponse

StreamGenerator = Callable[[AsyncIterator[Any]], AsyncIterator[str]]


class BaseCompletionHandler(ABC):
    """
    The completion handler is responsible for executing the completion request
    and creating the streaming response.
    """

    @abstractmethod
    async def complete(self, data: Dict, api_key: str) -> AsyncIterator[Any]:
        pass

    @abstractmethod
    def create_streaming_response(
        self, stream: AsyncIterator[Any]
    ) -> StreamingResponse:
        pass


class BaseProvider(ABC):
    """
    The provider class is responsible for defining the API routes and
    calling the completion method using the completion handler.
    """

    def __init__(self, completion_handler: BaseCompletionHandler):
        self.router = APIRouter()
        self._completion_handler = completion_handler
        self._setup_routes()

    @abstractmethod
    def _setup_routes(self) -> None:
        pass

    async def complete(self, data: Dict, api_key: str) -> AsyncIterator[Any]:
        return await self._completion_handler.complete(data, api_key)

    def get_routes(self) -> APIRouter:
        return self.router
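
To make the contract concrete, a minimal hypothetical provider built from these pieces. MinimalProvider and its route are illustrative only; BaseProvider, LiteLLmShim, and OpenAIAdapter are from this commit, with module paths as reconstructed from the diff.

# Hypothetical minimal provider; the class name and route are illustrative.
import json

from fastapi import Request

from codegate.providers.base import BaseProvider
from codegate.providers.litellmshim import LiteLLmShim
from codegate.providers.openai.adapter import OpenAIAdapter


class MinimalProvider(BaseProvider):
    def __init__(self):
        # The pass-through adapter plus the LiteLLM shim is the smallest handler.
        super().__init__(LiteLLmShim(OpenAIAdapter()))

    def _setup_routes(self) -> None:
        @self.router.post("/chat/completions")
        async def completions(request: Request):
            data = json.loads(await request.body())
            stream = await self.complete(data, api_key="sk-placeholder")
            return self._completion_handler.create_streaming_response(stream)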

src/codegate/providers/litellmshim/__init__.py

Lines changed: 10 additions & 0 deletions
from .adapter import BaseAdapter
from .generators import anthropic_stream_generator, sse_stream_generator
from .litellmshim import LiteLLmShim

__all__ = [
    "sse_stream_generator",
    "anthropic_stream_generator",
    "LiteLLmShim",
    "BaseAdapter",
]

src/codegate/providers/litellmshim/adapter.py

Lines changed: 44 additions & 0 deletions
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

from litellm import ChatCompletionRequest, ModelResponse

from codegate.providers.base import StreamGenerator


class BaseAdapter(ABC):
    """
    The adapter class is responsible for translating input and output
    parameters between the provider-specific on-the-wire API and the
    underlying model. We use LiteLLM's ChatCompletionRequest and ModelResponse
    as our data model.

    The methods in this class implement LiteLLM's Adapter interface rather
    than one of our own. This allows us to use LiteLLM's adapter classes as
    drop-in replacements for our own adapters.
    """

    def __init__(self, stream_generator: StreamGenerator):
        self.stream_generator = stream_generator

    @abstractmethod
    def translate_completion_input_params(
        self, kwargs: Dict
    ) -> Optional[ChatCompletionRequest]:
        """Convert input parameters to LiteLLM's ChatCompletionRequest format"""
        pass

    @abstractmethod
    def translate_completion_output_params(self, response: ModelResponse) -> Any:
        """Convert a non-streaming response from LiteLLM's ModelResponse format"""
        pass

    @abstractmethod
    def translate_completion_output_params_streaming(
        self, completion_stream: Any
    ) -> Any:
        """
        Convert a streaming response from LiteLLM's format to a format that
        can be passed to a stream generator and on to the client.
        """
        pass

src/codegate/providers/litellmshim/generators.py

Lines changed: 34 additions & 0 deletions
import json
from typing import Any, AsyncIterator

# Since different providers typically use one of these formats for streaming
# responses, we have a single stream generator for each format that is then
# plugged into the adapter.


async def sse_stream_generator(stream: AsyncIterator[Any]) -> AsyncIterator[str]:
    """OpenAI-style SSE format"""
    try:
        async for chunk in stream:
            if hasattr(chunk, "model_dump_json"):
                chunk = chunk.model_dump_json(exclude_none=True, exclude_unset=True)
            try:
                yield f"data:{chunk}\n\n"
            except Exception as e:
                yield f"data:{str(e)}\n\n"
    except Exception as e:
        yield f"data: {str(e)}\n\n"
    finally:
        yield "data: [DONE]\n\n"


async def anthropic_stream_generator(stream: AsyncIterator[Any]) -> AsyncIterator[str]:
    """Anthropic-style SSE format"""
    try:
        async for chunk in stream:
            event_type = chunk.get("type")
            try:
                yield f"event: {event_type}\ndata:{json.dumps(chunk)}\n\n"
            except Exception as e:
                yield f"event: {event_type}\ndata:{str(e)}\n\n"
    except Exception as e:
        yield f"data: {str(e)}\n\n"

src/codegate/providers/litellmshim/litellmshim.py

Lines changed: 51 additions & 0 deletions
from typing import Any, AsyncIterator, Dict

from fastapi.responses import StreamingResponse
from litellm import ModelResponse, acompletion

from ..base import BaseCompletionHandler
from .adapter import BaseAdapter


class LiteLLmShim(BaseCompletionHandler):
    """
    LiteLLM Shim is a wrapper around LiteLLM's API that allows us to use it with
    our own completion handler interface without exposing the underlying
    LiteLLM API.
    """
    def __init__(self, adapter: BaseAdapter):
        self._adapter = adapter

    async def complete(self, data: Dict, api_key: str) -> AsyncIterator[Any]:
        """
        Translate the input parameters to LiteLLM's format using the adapter and
        call the LiteLLM API. Then translate the response back to our format using
        the adapter.
        """
        data["api_key"] = api_key
        completion_request = self._adapter.translate_completion_input_params(data)
        if completion_request is None:
            raise Exception("Couldn't translate the request")

        response = await acompletion(**completion_request)

        if isinstance(response, ModelResponse):
            return self._adapter.translate_completion_output_params(response)
        return self._adapter.translate_completion_output_params_streaming(response)

    def create_streaming_response(
        self, stream: AsyncIterator[Any]
    ) -> StreamingResponse:
        """
        Create a streaming response from a stream generator. The StreamingResponse
        is the format that FastAPI expects for streaming responses.
        """
        return StreamingResponse(
            self._adapter.stream_generator(stream),
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Transfer-Encoding": "chunked",
            },
            status_code=200,
        )
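
A direct usage sketch of the shim outside any provider; the model name and API key are placeholders, and the OpenAIAdapter import path is reconstructed from this diff.

# Placeholder model and key; this would hit the real OpenAI API if run.
import asyncio

from codegate.providers.litellmshim import LiteLLmShim
from codegate.providers.openai.adapter import OpenAIAdapter


async def main():
    shim = LiteLLmShim(OpenAIAdapter())
    stream = await shim.complete(
        {
            "model": "gpt-4o-mini",
            "messages": [{"role": "user", "content": "Hi"}],
            "stream": True,
        },
        api_key="sk-placeholder",
    )
    async for chunk in stream:  # LiteLLM yields streaming chunks here
        print(chunk)


asyncio.run(main())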

src/codegate/providers/openai/__init__.py

Whitespace-only changes.

src/codegate/providers/openai/adapter.py

Lines changed: 33 additions & 0 deletions
from typing import Any, AsyncIterator, Dict, Optional

from litellm import ChatCompletionRequest, ModelResponse

from ..base import StreamGenerator
from ..litellmshim import sse_stream_generator
from ..litellmshim.litellmshim import BaseAdapter


class OpenAIAdapter(BaseAdapter):
    """
    This is just a wrapper around LiteLLM's adapter class interface that passes
    through the input and output as-is - LiteLLM's API expects OpenAI's API
    format.
    """
    def __init__(self, stream_generator: StreamGenerator = sse_stream_generator):
        super().__init__(stream_generator)

    def translate_completion_input_params(
        self, kwargs: Dict
    ) -> Optional[ChatCompletionRequest]:
        try:
            return ChatCompletionRequest(**kwargs)
        except Exception as e:
            raise ValueError(f"Invalid completion parameters: {str(e)}")

    def translate_completion_output_params(self, response: ModelResponse) -> Any:
        return response

    def translate_completion_output_params_streaming(
        self, completion_stream: AsyncIterator[ModelResponse]
    ) -> AsyncIterator[ModelResponse]:
        return completion_stream
