Commit 0421bd8

Ported llamacpp.
1 parent c843f6b commit 0421bd8

5 files changed: +260 -36 lines


src/codegate/providers/llamacpp/completion_handler.py (+69 -31)

```diff
@@ -11,35 +11,46 @@
 from codegate.config import Config
 from codegate.inference.inference_engine import LlamaCppInferenceEngine
 from codegate.providers.base import BaseCompletionHandler
+from codegate.types.openai import (
+    stream_generator,
+    LegacyCompletion,
+    StreamingChatCompletion,
+)
+
+
+# async def llamacpp_stream_generator(
+#     stream: AsyncIterator[CreateChatCompletionStreamResponse],
+# ) -> AsyncIterator[str]:
+#     """OpenAI-style SSE format"""
+#     try:
+#         async for chunk in stream:
+#             chunk = json.dumps(chunk)
+#             try:
+#                 yield f"data:{chunk}\n\n"
+#             except Exception as e:
+#                 yield f"data:{str(e)}\n\n"
+#     except Exception as e:
+#         yield f"data: {str(e)}\n\n"
+#     finally:
+#         yield "data: [DONE]\n\n"
 
 
-async def llamacpp_stream_generator(
-    stream: AsyncIterator[CreateChatCompletionStreamResponse],
-) -> AsyncIterator[str]:
-    """OpenAI-style SSE format"""
-    try:
-        async for chunk in stream:
-            chunk = json.dumps(chunk)
-            try:
-                yield f"data:{chunk}\n\n"
-            except Exception as e:
-                yield f"data:{str(e)}\n\n"
-    except Exception as e:
-        yield f"data: {str(e)}\n\n"
-    finally:
-        yield "data: [DONE]\n\n"
-
-
-async def convert_to_async_iterator(
-    sync_iterator: Iterator[CreateChatCompletionStreamResponse],
-) -> AsyncIterator[CreateChatCompletionStreamResponse]:
+async def completion_to_async_iterator(
+    sync_iterator: Iterator[dict],
+) -> AsyncIterator[LegacyCompletion]:
     """
     Convert a synchronous iterator to an asynchronous iterator. This makes the logic easier
     because both the pipeline and the completion handler can use async iterators.
     """
     for item in sync_iterator:
-        yield item
-        await asyncio.sleep(0)
+        yield LegacyCompletion(**item)
+
+
+async def chat_to_async_iterator(
+    sync_iterator: Iterator[dict],
+) -> AsyncIterator[StreamingChatCompletion]:
+    for item in sync_iterator:
+        yield StreamingChatCompletion(**item)
 
 
 class LlamaCppCompletionHandler(BaseCompletionHandler):
@@ -57,33 +68,60 @@ async def execute_completion(
         """
         Execute the completion request with inference engine API
         """
-        model_path = f"{request['base_url']}/{request['model']}.gguf"
+        model_path = f"{base_url}/{request.get_model()}.gguf"
 
         # Create a copy of the request dict and remove stream_options
         # Reason - Request error as JSON:
         # {'error': "Llama.create_completion() got an unexpected keyword argument 'stream_options'"}
-        request_dict = dict(request)
-        request_dict.pop("stream_options", None)
-        # Remove base_url from the request dict. We use this field as a standard across
-        # all providers to specify the base URL of the model.
-        request_dict.pop("base_url", None)
-
         if is_fim_request:
+            request_dict = request.dict(exclude={
+                "best_of",
+                "frequency_pentalty",
+                "n",
+                "stream_options",
+                "user",
+            })
+
             response = await self.inference_engine.complete(
                 model_path,
                 Config.get_config().chat_model_n_ctx,
                 Config.get_config().chat_model_n_gpu_layers,
                 **request_dict,
             )
+
+            if stream:
+                return completion_to_async_iterator(response)
+            return LegacyCompletion(**response)
         else:
+            request_dict = request.dict(exclude={
+                "audio",
+                "frequency_pentalty",
+                "include_reasoning",
+                "metadata",
+                "max_completion_tokens",
+                "modalities",
+                "n",
+                "parallel_tool_calls",
+                "prediction",
+                "prompt",
+                "reasoning_effort",
+                "service_tier",
+                "store",
+                "stream_options",
+                "user",
+            })
+
             response = await self.inference_engine.chat(
                 model_path,
                 Config.get_config().chat_model_n_ctx,
                 Config.get_config().chat_model_n_gpu_layers,
                 **request_dict,
             )
 
-        return convert_to_async_iterator(response) if stream else response
+        if stream:
+            return chat_to_async_iterator(response)
+        else:
+            return StreamingChatCompletion(**response)
 
     def _create_streaming_response(
         self,
@@ -95,7 +133,7 @@ def _create_streaming_response(
         is the format that FastAPI expects for streaming responses.
         """
         return StreamingResponse(
-            llamacpp_stream_generator(stream),
+            stream_generator(stream),
             headers={
                 "Cache-Control": "no-cache",
                 "Connection": "keep-alive",
```

src/codegate/providers/llamacpp/provider.py (+28 -4)

```diff
@@ -12,6 +12,11 @@
 from codegate.providers.base import BaseProvider, ModelFetchError
 from codegate.providers.fim_analyzer import FIMAnalyzer
 from codegate.providers.llamacpp.completion_handler import LlamaCppCompletionHandler
+from codegate.types.openai import (
+    ChatCompletionRequest,
+    LegacyCompletionRequest,
+)
+
 
 logger = structlog.get_logger("codegate")
 
@@ -21,6 +26,10 @@ def __init__(
         self,
         pipeline_factory: PipelineFactory,
     ):
+        if self._get_base_url() != "":
+            self.base_url = self._get_base_url()
+        else:
+            self.base_url = "./codegate_volume/models"
         completion_handler = LlamaCppCompletionHandler()
         super().__init__(
             None,
@@ -83,17 +92,32 @@ def _setup_routes(self):
         """
 
         @self.router.post(f"/{self.provider_route_name}/completions")
+        @DetectClient()
+        async def create_completion(
+            request: Request,
+        ):
+            body = await request.body()
+            print(body)
+            req = LegacyCompletionRequest.model_validate_json(body)
+            is_fim_request = FIMAnalyzer.is_fim_request(request.url.path, req)
+            return await self.process_request(
+                req,
+                None,
+                is_fim_request,
+                request.state.detected_client,
+            )
+
         @self.router.post(f"/{self.provider_route_name}/chat/completions")
         @DetectClient()
         async def create_completion(
             request: Request,
         ):
             body = await request.body()
-            data = json.loads(body)
-            data["base_url"] = Config.get_config().model_base_path
-            is_fim_request = FIMAnalyzer.is_fim_request(request.url.path, data)
+            print(body)
+            req = ChatCompletionRequest.model_validate_json(body)
+            is_fim_request = FIMAnalyzer.is_fim_request(request.url.path, req)
             return await self.process_request(
-                data,
+                req,
                 None,
                 is_fim_request,
                 request.state.detected_client,
```
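
The new routes replace `json.loads` plus dict mutation with pydantic validation of the raw request body. A minimal sketch of that validation step, assuming the `codegate` package from this commit is importable; the model name and prompt below are invented:

```python
from codegate.types.openai import LegacyCompletionRequest

# Hypothetical legacy /completions payload (field values are made up).
body = b'{"model": "qwen2.5-coder", "prompt": "def fib(n):", "stream": true}'

# Same call the new /completions route makes before handing the request
# to FIMAnalyzer.is_fim_request() and process_request().
req = LegacyCompletionRequest.model_validate_json(body)

print(req.get_model())   # qwen2.5-coder
print(req.get_prompt())  # def fib(n):
print(req.get_stream())  # True
```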
New file (+17)

```diff
@@ -0,0 +1,17 @@
+from typing import (
+    Iterable,
+    List,
+)
+
+import pydantic
+
+
+class GenerateContentRequest(pydantic.BaseModel):
+    model: str | None = None
+    contents: List[Content] | None = None
+    config: Config | None = None
+
+    def get_messages(self) -> Iterable[Content]:
+        if self.content is not None:
+            for content in self.content:
+                yield content
```

src/codegate/types/openai/__init__.py (+10 -1)

```diff
@@ -22,7 +22,6 @@
     StreamingChatCompletion,
     ToolCall,
     Usage,
-    VllmMessageError,
 )
 
 from ._request_models import (
@@ -54,3 +53,13 @@
 from ._shared_models import (
     ServiceTier,
 )
+
+from ._legacy_models import (
+    LegacyCompletionRequest,
+    LegacyCompletionTokenDetails,
+    LegacyPromptTokenDetails,
+    LegacyUsage,
+    LegacyLogProbs,
+    LegacyMessage,
+    LegacyCompletion,
+)
```
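
With these re-exports in place, the handler and provider above pull everything they need from the package root rather than the private submodules. A sketch of the import surface this commit relies on, assuming the package is installed:

```python
# Names re-exported by codegate.types.openai after this change.
from codegate.types.openai import (
    ChatCompletionRequest,
    LegacyCompletion,
    LegacyCompletionRequest,
    StreamingChatCompletion,
    stream_generator,
)
```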
src/codegate/types/openai/_legacy_models.py (new file, +136)
```diff
@@ -0,0 +1,136 @@
+from typing import (
+    Any,
+    Iterable,
+    List,
+    Literal,
+)
+
+import pydantic
+
+from ._response_models import (
+    Usage,
+)
+from ._request_models import (
+    StreamOption,
+    Message,
+)
+
+
+class LegacyCompletionRequest(pydantic.BaseModel):
+    prompt: str | None = None
+    model: str
+    best_of: int | None = 1
+    echo: bool | None = False
+    frequency_pentalty: float | None = 0.0
+    logit_bias: dict | None = None
+    logprobs: int | None = None
+    max_tokens: int | None = None
+    n: int | None = None
+    presence_penalty: float | None = 0.0
+    seed: int | None = None
+    stop: str | List[Any] | None = None
+    stream: bool | None = False
+    stream_options: StreamOption | None = None
+    suffix: str | None = None
+    temperature: float | None = 1.0
+    top_p: float | None = 1.0
+    user: str | None = None
+
+    def get_stream(self) -> bool:
+        return self.stream
+
+    def get_model(self) -> str:
+        return self.model
+
+    def get_messages(self) -> Iterable[Message]:
+        yield self
+
+    def get_content(self) -> Iterable[Any]:
+        yield self
+
+    def get_text(self) -> str | None:
+        return self.prompt
+
+    def set_text(self, text) -> None:
+        self.prompt = text
+
+    def first_message(self) -> Message | None:
+        return self
+
+    def last_user_message(self) -> tuple[Message, int] | None:
+        return self, 0
+
+    def last_user_block(self) -> Iterable[tuple[Message, int]]:
+        yield self, 0
+
+    def get_system_prompt(self) -> Iterable[str]:
+        yield self.get_text()
+
+    def set_system_prompt(self, text) -> None:
+        self.set_text(text)
+
+    def add_system_prompt(self, text, sep="\n") -> None:
+        original = self.get_text()
+        self.set_text(f"{original}{sep}{text}")
+
+    def get_prompt(self, default=None):
+        if self.prompt is not None:
+            return self.get_text()
+        return default
+
+
+class LegacyCompletionTokenDetails(pydantic.BaseModel):
+    accepted_prediction_tokens: int
+    audio_tokens: int
+    reasoning_tokens: int
+
+
+class LegacyPromptTokenDetails(pydantic.BaseModel):
+    audio_tokens: int
+    cached_tokens: int
+
+
+class LegacyUsage(pydantic.BaseModel):
+    completion_tokens: int
+    prompt_tokens: int
+    total_tokens: int
+    completion_tokens_details: LegacyCompletionTokenDetails | None = None
+    prompt_tokens_details: LegacyPromptTokenDetails | None = None
+
+
+class LegacyLogProbs(pydantic.BaseModel):
+    text_offset: List[Any]
+    token_logprobs: List[Any]
+    tokens: List[Any]
+    top_logprobs: List[Any]
+
+
+class LegacyMessage(pydantic.BaseModel):
+    text: str
+    finish_reason: str | None = None
+    index: int = 0
+    logprobs: LegacyLogProbs | None = None
+
+    def get_text(self) -> str | None:
+        return self.text
+
+    def set_text(self, text) -> None:
+        self.text = text
+
+
+class LegacyCompletion(pydantic.BaseModel):
+    id: str
+    choices: List[LegacyMessage]
+    created: int
+    model: str
+    system_fingerprint: str | None = None
+    object: Literal["text_completion"] = "text_completion"
+    usage: Usage | None = None
+
+    def get_content(self) -> Iterable[LegacyMessage]:
+        for message in self.choices:
+            yield message
+
+    def set_text(self, text) -> None:
+        if self.choices:
+            self.choices[0].set_text(text)
```
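
These models are what the completion handler builds directly from llama.cpp's response dicts (`LegacyCompletion(**response)` in the non-streaming FIM path). A quick sketch of that round trip with an invented payload, assuming the package is importable:

```python
from codegate.types.openai import LegacyCompletion

# Invented dict shaped like a llama.cpp text-completion response.
raw = {
    "id": "cmpl-42",
    "created": 1700000000,
    "model": "qwen2.5-coder",
    "choices": [{"text": "print('hello')", "finish_reason": "stop", "index": 0}],
}

completion = LegacyCompletion(**raw)

# Pipelines iterate typed messages instead of raw dicts...
for message in completion.get_content():
    print(message.get_text())          # print('hello')

# ...and can rewrite the first choice in place.
completion.set_text("print('redacted')")
print(completion.choices[0].text)      # print('redacted')
```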
