
Commit 1d647ea

Normalize and denormalize llamacpp streaming reply
Originally, I wanted to add the normalizers to convert the `im_start`/`im_end` tags, but we worked around that by configuring llamacpp to use the OpenAI format. We will still need a normalizer for the vllm provider, though.

What we need right now is the denormalizer: the blocking pipeline returns a stream of `ModelResponse` objects, and the denormalizer converts them into the `CreateChatCompletionStreamResponse` structure that is then serialized to the client. This avoids the guessing and special casing that would otherwise be needed in `llamacpp_stream_generator`, which previously expected an `Iterator[CreateChatCompletionStreamResponse]`.

Another change that simplifies the logic: `llamacpp_stream_generator` now accepts an `AsyncIterator` instead of the plain `Iterator` that the llamacpp completion handler used to return, so the iterator coming out of the blocking pipeline can be passed straight through. On the completion side we add a simple sync-to-async wrapper.

Fixes: #94
1 parent 9954610 commit 1d647ea
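
For context, a minimal sketch of the streaming path this commit sets up: llamacpp chunks are normalized into `ModelResponse` objects, the (currently pass-through) output pipeline runs over that stream, and the denormalizer turns it back into `CreateChatCompletionStreamResponse` dicts ready for SSE serialization. Only the normalizer interface comes from the diffs below; `fake_llamacpp_stream` and its values are made up for illustration.

import asyncio
import json
from typing import AsyncIterator

from codegate.providers.llamacpp.normalizer import LLamaCppOutputNormalizer


async def fake_llamacpp_stream() -> AsyncIterator[dict]:
    # Stands in for the chunk iterator produced by the llamacpp completion handler.
    yield {
        "id": "chunk-1",
        "model": "llama-model",
        "object": "chat.completion.chunk",
        "created": 1234567,
        "choices": [{"index": 0, "delta": {"content": "Hi"}, "finish_reason": None}],
    }


async def main() -> None:
    normalizer = LLamaCppOutputNormalizer()

    # llamacpp chunks -> ModelResponse stream, the shape the pipeline works with
    normalized = normalizer.normalize_streaming(fake_llamacpp_stream())

    # ... an output pipeline would run here; today it is a pass-through ...

    # ModelResponse stream -> llamacpp-style chunks, ready to be framed as SSE
    denormalized = normalizer.denormalize_streaming(normalized)
    async for chunk in denormalized:
        print(f"data:{json.dumps(chunk)}")


asyncio.run(main())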

File tree: 4 files changed, +283 -17 lines


src/codegate/providers/base.py
Lines changed: 21 additions & 2 deletions

@@ -49,6 +49,20 @@ def _setup_routes(self) -> None:
     def provider_route_name(self) -> str:
         pass
 
+    async def _run_output_stream_pipeline(
+        self,
+        normalized_stream: AsyncIterator[ModelResponse],
+    ) -> AsyncIterator[ModelResponse]:
+        # we don't have a pipeline for output stream yet
+        return normalized_stream
+
+    def _run_output_pipeline(
+        self,
+        normalized_response: ModelResponse,
+    ) -> ModelResponse:
+        # we don't have a pipeline for output yet
+        return normalized_response
+
     async def _run_input_pipeline(
         self, normalized_request: ChatCompletionRequest, is_fim_request: bool
     ) -> PipelineResult:
@@ -149,8 +163,13 @@ async def complete(
             provider_request, api_key=api_key, stream=streaming
         )
         if not streaming:
-            return self._output_normalizer.denormalize(model_response)
-        return self._output_normalizer.denormalize_streaming(model_response)
+            normalized_response = self._output_normalizer.normalize(model_response)
+            pipeline_output = self._run_output_pipeline(normalized_response)
+            return self._output_normalizer.denormalize(pipeline_output)
+
+        normalized_stream = self._output_normalizer.normalize_streaming(model_response)
+        pipeline_output_stream = await self._run_output_stream_pipeline(normalized_stream)
+        return self._output_normalizer.denormalize_streaming(pipeline_output_stream)
 
     def get_routes(self) -> APIRouter:
         return self.router
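
The two `_run_output*` hooks above are deliberate pass-throughs: they reserve a place for an output pipeline without having to touch `complete` again later. As a rough illustration only (nothing like this exists in the commit), a stream step plugged into `_run_output_stream_pipeline` could look like the following; the redaction logic and the function name are hypothetical.

from typing import AsyncIterator

from litellm import ModelResponse


async def redact_marker_stream(
    normalized_stream: AsyncIterator[ModelResponse],
) -> AsyncIterator[ModelResponse]:
    # Hypothetical output-stream step: rewrite each normalized chunk before it
    # is denormalized and sent back to the client.
    async for chunk in normalized_stream:
        for choice in chunk.choices:
            if choice.delta and choice.delta.content:
                choice.delta.content = choice.delta.content.replace("SECRET", "***")
        yield chunk

A provider that overrides `_run_output_stream_pipeline` could return `redact_marker_stream(normalized_stream)` instead of passing the stream through unchanged.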

src/codegate/providers/llamacpp/completion_handler.py
Lines changed: 22 additions & 8 deletions

@@ -4,28 +4,41 @@
 
 from fastapi.responses import StreamingResponse
 from litellm import ChatCompletionRequest, ModelResponse
+from llama_cpp.llama_types import (
+    CreateChatCompletionStreamResponse,
+)
 
 from codegate.config import Config
 from codegate.inference.inference_engine import LlamaCppInferenceEngine
 from codegate.providers.base import BaseCompletionHandler
 
 
-async def llamacpp_stream_generator(stream: Iterator[Any]) -> AsyncIterator[str]:
+async def llamacpp_stream_generator(
+    stream: AsyncIterator[CreateChatCompletionStreamResponse],
+) -> AsyncIterator[str]:
     """OpenAI-style SSE format"""
     try:
-        for chunk in stream:
-            if hasattr(chunk, "model_dump_json"):
-                chunk = chunk.model_dump_json(exclude_none=True, exclude_unset=True)
+        async for chunk in stream:
+            chunk = json.dumps(chunk)
            try:
-                yield f"data:{json.dumps(chunk)}\n\n"
-                await asyncio.sleep(0)
+                yield f"data:{chunk}\n\n"
            except Exception as e:
                yield f"data:{str(e)}\n\n"
     except Exception as e:
         yield f"data: {str(e)}\n\n"
     finally:
         yield "data: [DONE]\n\n"
 
+async def convert_to_async_iterator(
+    sync_iterator: Iterator[CreateChatCompletionStreamResponse],
+) -> AsyncIterator[CreateChatCompletionStreamResponse]:
+    """
+    Convert a synchronous iterator to an asynchronous iterator. This makes the logic easier
+    because both the pipeline and the completion handler can use async iterators.
+    """
+    for item in sync_iterator:
+        yield item
+        await asyncio.sleep(0)
 
 class LlamaCppCompletionHandler(BaseCompletionHandler):
     def __init__(self):
@@ -53,9 +66,10 @@ async def execute_completion(
             Config.get_config().chat_model_n_gpu_layers,
             **request,
         )
-        return response
 
-    def create_streaming_response(self, stream: Iterator[Any]) -> StreamingResponse:
+        return convert_to_async_iterator(response) if stream else response
+
+    def create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse:
         """
         Create a streaming response from a stream generator. The StreamingResponse
         is the format that FastAPI expects for streaming responses.
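
To make the SSE framing concrete, here is a small example that drives `llamacpp_stream_generator` through the same `convert_to_async_iterator` wrapper that `execute_completion` now uses. The hand-written chunk dict is illustrative only.

import asyncio

from codegate.providers.llamacpp.completion_handler import (
    convert_to_async_iterator,
    llamacpp_stream_generator,
)

# A single llamacpp-style chunk, hand-written for illustration.
chunks = [
    {
        "id": "chunk-1",
        "model": "llama-model",
        "object": "chat.completion.chunk",
        "created": 1234567,
        "choices": [{"index": 0, "delta": {"content": "Hi"}, "finish_reason": None}],
    }
]


async def main() -> None:
    # Wrap the sync iterator the same way execute_completion does, then feed it
    # to the SSE generator.
    stream = convert_to_async_iterator(iter(chunks))
    async for sse_chunk in llamacpp_stream_generator(stream):
        print(sse_chunk, end="")
        # Prints data:{"id": "chunk-1", ...} followed by data: [DONE]


asyncio.run(main())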

src/codegate/providers/llamacpp/normalizer.py
Lines changed: 93 additions & 7 deletions

@@ -1,6 +1,13 @@
-from typing import Any, AsyncIterable, AsyncIterator, Dict, Iterable, Iterator, Union
+from typing import Any, AsyncIterable, AsyncIterator, Dict, Union
 
 from litellm import ChatCompletionRequest, ModelResponse
+from litellm.types.utils import Delta, StreamingChoices
+from llama_cpp.llama_types import (
+    ChatCompletionStreamResponseChoice,
+    ChatCompletionStreamResponseDelta,
+    ChatCompletionStreamResponseDeltaEmpty,
+    CreateChatCompletionStreamResponse,
+)
 
 from codegate.providers.normalizer import ModelInputNormalizer, ModelOutputNormalizer
 
@@ -31,17 +38,96 @@ def denormalize(self, data: ChatCompletionRequest) -> Dict:
         del data["messages"]
         return data
 
+class ModelToLlamaCpp(AsyncIterator[CreateChatCompletionStreamResponse]):
+    def __init__(self, normalized_reply: AsyncIterable[ModelResponse]):
+        self.normalized_reply = normalized_reply
+        self._aiter = normalized_reply.__aiter__()
+
+    def __aiter__(self):
+        return self
+
+    @staticmethod
+    def _create_delta(
+        choice_delta: Delta) -> Union[
+            ChatCompletionStreamResponseDelta,
+            ChatCompletionStreamResponseDeltaEmpty
+    ]:
+        if not choice_delta:
+            return ChatCompletionStreamResponseDeltaEmpty()
+        return ChatCompletionStreamResponseDelta(
+            content=choice_delta.content,
+            role=choice_delta.role,
+        )
+
+    async def __anext__(self) -> CreateChatCompletionStreamResponse:
+        try:
+            chunk = await self._aiter.__anext__()
+            return CreateChatCompletionStreamResponse(
+                id=chunk['id'],
+                model=chunk['model'],
+                object='chat.completion.chunk',
+                created=chunk['created'],
+                choices=[ ChatCompletionStreamResponseChoice(
+                    index=choice.index,
+                    delta=self._create_delta(choice.delta),
+                    finish_reason=choice.finish_reason,
+                    logprobs=None,
+                ) for choice in chunk['choices'] ]
+            )
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+
+class LlamaCppToModel(AsyncIterator[ModelResponse]):
+    def __init__(self, normalized_reply: AsyncIterable[CreateChatCompletionStreamResponse]):
+        self.normalized_reply = normalized_reply
+        self._aiter = normalized_reply.__aiter__()
+
+    def __aiter__(self):
+        return self
+
+    @staticmethod
+    def _create_delta(
+        choice_delta: Union[
+            ChatCompletionStreamResponseDelta,
+            ChatCompletionStreamResponseDeltaEmpty
+    ]) -> Delta:
+        if not choice_delta:  # Handles empty dict case
+            return Delta(content=None, role=None)
+        return Delta(
+            content=choice_delta.get('content'),
+            role=choice_delta.get('role')
+        )
+
+    async def __anext__(self) -> ModelResponse:
+        try:
+            chunk = await self._aiter.__anext__()
+            return ModelResponse(
+                id=chunk["id"],
+                choices=[
+                    StreamingChoices(
+                        finish_reason=choice.get("finish_reason", None),
+                        index=choice["index"],
+                        delta=self._create_delta(choice.get('delta')),
+                        logprobs=None,
+                    ) for choice in chunk["choices"]
+                ],
+                created=chunk["created"],
+                model=chunk["model"],
+                object=chunk["object"],
+            )
        except StopAsyncIteration:
            raise StopAsyncIteration
 
 class LLamaCppOutputNormalizer(ModelOutputNormalizer):
     def normalize_streaming(
         self,
-        model_reply: Union[AsyncIterable[Any], Iterable[Any]],
-    ) -> Union[AsyncIterator[ModelResponse], Iterator[ModelResponse]]:
+        llamacpp_stream: AsyncIterable[CreateChatCompletionStreamResponse],
+    ) -> AsyncIterator[ModelResponse]:
         """
         Normalize the output stream. This is a pass-through for liteLLM output normalizer
         as the liteLLM output is already in the normalized format.
         """
-        return model_reply
+        return LlamaCppToModel(llamacpp_stream)
 
     def normalize(self, model_reply: Any) -> ModelResponse:
         """
@@ -59,10 +145,10 @@ def denormalize(self, normalized_reply: ModelResponse) -> Any:
 
     def denormalize_streaming(
         self,
-        normalized_reply: Union[AsyncIterable[ModelResponse], Iterable[ModelResponse]],
-    ) -> Union[AsyncIterator[Any], Iterator[Any]]:
+        model_stream: AsyncIterable[ModelResponse],
+    ) -> AsyncIterator[CreateChatCompletionStreamResponse]:
         """
         Denormalize the output stream from the completion function to the format
         expected by the client
         """
-        return normalized_reply
+        return ModelToLlamaCpp(model_stream)
New file
Lines changed: 147 additions & 0 deletions

@@ -0,0 +1,147 @@
+import pytest
+from litellm import ModelResponse
+from litellm.types.utils import Delta, StreamingChoices
+from llama_cpp.llama_types import CreateChatCompletionStreamResponse
+
+from codegate.providers.llamacpp.normalizer import (
+    LLamaCppOutputNormalizer,
+)
+
+
+class TestLLamaCppStreamNormalizer:
+    @pytest.mark.asyncio
+    async def test_normalize_streaming(self):
+        """
+        Test the normalize_streaming method
+        Verify conversion from llama.cpp stream to ModelResponse stream
+        """
+        # Mock CreateChatCompletionStreamResponse stream
+        async def mock_llamacpp_stream():
+            responses = [
+                CreateChatCompletionStreamResponse(
+                    id="test_id1",
+                    model="llama-model",
+                    object="chat.completion.chunk",
+                    created=1234567,
+                    choices=[{
+                        "index": 0,
+                        "delta": {"content": "Hello"},
+                        "finish_reason": None
+                    }]
+                ),
+                CreateChatCompletionStreamResponse(
+                    id="test_id2",
+                    model="llama-model",
+                    object="chat.completion.chunk",
+                    created=1234568,
+                    choices=[{
+                        "index": 0,
+                        "delta": {"content": " World"},
+                        "finish_reason": "stop"
+                    }]
+                )
+            ]
+            for resp in responses:
+                yield resp
+
+        # Create normalizer and normalize stream
+        normalizer = LLamaCppOutputNormalizer()
+        normalized_stream = normalizer.normalize_streaming(mock_llamacpp_stream())
+
+        # Collect results
+        results = []
+        async for response in normalized_stream:
+            results.append(response)
+
+        # Assertions
+        assert len(results) == 2
+        assert all(isinstance(r, ModelResponse) for r in results)
+
+        # Check first chunk
+        assert results[0].choices[0].delta.content == "Hello"
+        assert results[0].choices[0].finish_reason is None
+
+        # Check second chunk
+        assert results[1].choices[0].delta.content == " World"
+        assert results[1].choices[0].finish_reason == "stop"
+
+    @pytest.mark.asyncio
+    async def test_denormalize_streaming(self):
+        """
+        Test the denormalize_streaming method
+        Verify conversion from ModelResponse stream to llama.cpp stream
+        """
+        # Mock ModelResponse stream
+        async def mock_model_response_stream():
+            responses = [
+                ModelResponse(
+                    id="test_id1",
+                    model="litellm-model",
+                    object="chat.completion",
+                    created=1234567,
+                    choices=[StreamingChoices(
+                        index=0,
+                        delta=Delta(content="Hello"),
+                        finish_reason=None
+                    )]
+                ),
+                ModelResponse(
+                    id="test_id2",
+                    model="litellm-model",
+                    object="chat.completion",
+                    created=1234568,
+                    choices=[StreamingChoices(
+                        index=0,
+                        delta=Delta(content=" World"),
+                        finish_reason="stop"
+                    )]
+                )
+            ]
+            for resp in responses:
+                yield resp
+
+        # Create normalizer and denormalize stream
+        normalizer = LLamaCppOutputNormalizer()
+        denormalized_stream = normalizer.denormalize_streaming(mock_model_response_stream())
+
+        # Collect results
+        results = []
+        async for response in denormalized_stream:
+            results.append(response)
+
+        # Assertions
+        assert len(results) == 2
+
+        # Check first chunk
+        assert results[0]['choices'][0]['delta']['content'] == "Hello"
+        assert results[0]['choices'][0]['finish_reason'] is None
+
+        # Check second chunk
+        assert results[1]['choices'][0]['delta']['content'] == " World"
+        assert results[1]['choices'][0]['finish_reason'] == "stop"
+
+    @pytest.mark.asyncio
+    async def test_streaming_edge_cases(self):
+        """
+        Test edge cases and error scenarios in streaming
+        """
+        # Empty stream
+        async def empty_stream():
+            return
+            yield
+
+        normalizer = LLamaCppOutputNormalizer()
+
+        # Test empty stream for normalize_streaming
+        normalized_empty = normalizer.normalize_streaming(empty_stream())
+        with pytest.raises(StopAsyncIteration):
+            await normalized_empty.__anext__()
+
+        # Test empty stream for denormalize_streaming
+        async def empty_model_stream():
+            return
+            yield
+
+        denormalized_empty = normalizer.denormalize_streaming(empty_model_stream())
+        with pytest.raises(StopAsyncIteration):
+            await denormalized_empty.__anext__()
