This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Normalize and denormalize llamacpp streaming reply #121

Merged: 1 commit, merged Nov 29, 2024
5 changes: 2 additions & 3 deletions src/codegate/pipeline/codegate_system_prompt/codegate.py
@@ -17,8 +17,7 @@ class CodegateSystemPrompt(PipelineStep):

     def __init__(self, system_prompt_message: Optional[str] = None):
         self._system_message = ChatCompletionSystemMessage(
-            content=system_prompt_message,
-            role="system"
+            content=system_prompt_message, role="system"
         )

     @property
@@ -29,7 +28,7 @@ def name(self) -> str:
         return "codegate-system-prompt"

     async def process(
-        self, request: ChatCompletionRequest, context: PipelineContext
+        self, request: ChatCompletionRequest, context: PipelineContext
     ) -> PipelineResult:
         """
         Process the completion request and add a system prompt if the user message contains
23 changes: 21 additions & 2 deletions src/codegate/providers/base.py
@@ -49,6 +49,20 @@ def _setup_routes(self) -> None:
     def provider_route_name(self) -> str:
         pass

+    async def _run_output_stream_pipeline(
+        self,
+        normalized_stream: AsyncIterator[ModelResponse],
+    ) -> AsyncIterator[ModelResponse]:
+        # we don't have a pipeline for output stream yet
+        return normalized_stream
+
+    def _run_output_pipeline(
+        self,
+        normalized_response: ModelResponse,
+    ) -> ModelResponse:
+        # we don't have a pipeline for output yet
+        return normalized_response
+
     async def _run_input_pipeline(
         self, normalized_request: ChatCompletionRequest, is_fim_request: bool
     ) -> PipelineResult:
@@ -149,8 +163,13 @@ async def complete(
             provider_request, api_key=api_key, stream=streaming
         )
         if not streaming:
-            return self._output_normalizer.denormalize(model_response)
-        return self._output_normalizer.denormalize_streaming(model_response)
+            normalized_response = self._output_normalizer.normalize(model_response)
+            pipeline_output = self._run_output_pipeline(normalized_response)
+            return self._output_normalizer.denormalize(pipeline_output)
+
+        normalized_stream = self._output_normalizer.normalize_streaming(model_response)
+        pipeline_output_stream = await self._run_output_stream_pipeline(normalized_stream)
+        return self._output_normalizer.denormalize_streaming(pipeline_output_stream)

     def get_routes(self) -> APIRouter:
         return self.router
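
The two new hooks are deliberate no-ops for now: complete() already routes both branches through normalize, then the output pipeline, then denormalize, so a real output pipeline can be dropped in later without touching the providers again. Below is a minimal sketch of the kind of step that could eventually back _run_output_pipeline, assuming a simple in-place rewrite of the normalized ModelResponse (redact_secrets and run_output_pipeline are hypothetical names, not part of this change):

from litellm import ModelResponse


def redact_secrets(normalized_response: ModelResponse) -> ModelResponse:
    # Works on the normalized OpenAI-style shape (choices -> message -> content),
    # so the same step would apply to any provider, not just llama.cpp.
    for choice in normalized_response.choices:
        message = getattr(choice, "message", None)
        if message is not None and message.content:
            message.content = message.content.replace("hunter2", "[REDACTED]")
    return normalized_response


def run_output_pipeline(normalized_response: ModelResponse) -> ModelResponse:
    # Same contract as the new BaseProvider._run_output_pipeline hook:
    # a normalized ModelResponse goes in, a normalized ModelResponse comes out.
    for step in (redact_secrets,):
        normalized_response = step(normalized_response)
    return normalized_response

Because the pipeline runs on the normalized response, the denormalize call in complete() still receives the shape it expects.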
1 change: 0 additions & 1 deletion src/codegate/providers/litellmshim/generators.py
@@ -1,4 +1,3 @@
-import asyncio
 import json
 from typing import Any, AsyncIterator

32 changes: 24 additions & 8 deletions src/codegate/providers/llamacpp/completion_handler.py
@@ -4,21 +4,24 @@

 from fastapi.responses import StreamingResponse
 from litellm import ChatCompletionRequest, ModelResponse
+from llama_cpp.llama_types import (
+    CreateChatCompletionStreamResponse,
+)

 from codegate.config import Config
 from codegate.inference.inference_engine import LlamaCppInferenceEngine
 from codegate.providers.base import BaseCompletionHandler


-async def llamacpp_stream_generator(stream: Iterator[Any]) -> AsyncIterator[str]:
+async def llamacpp_stream_generator(
+    stream: AsyncIterator[CreateChatCompletionStreamResponse],
+) -> AsyncIterator[str]:
     """OpenAI-style SSE format"""
     try:
-        for chunk in stream:
-            if hasattr(chunk, "model_dump_json"):
-                chunk = chunk.model_dump_json(exclude_none=True, exclude_unset=True)
+        async for chunk in stream:
+            chunk = json.dumps(chunk)
             try:
-                yield f"data:{json.dumps(chunk)}\n\n"
-                await asyncio.sleep(0)
+                yield f"data:{chunk}\n\n"
             except Exception as e:
                 yield f"data:{str(e)}\n\n"
     except Exception as e:
@@ -27,6 +30,18 @@ async def llamacpp_stream_generator(stream: Iterator[Any]) -> AsyncIterator[str]
         yield "data: [DONE]\n\n"


+async def convert_to_async_iterator(
+    sync_iterator: Iterator[CreateChatCompletionStreamResponse],
+) -> AsyncIterator[CreateChatCompletionStreamResponse]:
+    """
+    Convert a synchronous iterator to an asynchronous iterator. This makes the logic easier
+    because both the pipeline and the completion handler can use async iterators.
+    """
+    for item in sync_iterator:
+        yield item
+        await asyncio.sleep(0)
+
+
 class LlamaCppCompletionHandler(BaseCompletionHandler):
     def __init__(self):
         self.inference_engine = LlamaCppInferenceEngine()
@@ -53,9 +68,10 @@ async def execute_completion(
             Config.get_config().chat_model_n_gpu_layers,
             **request,
         )
-        return response
-
-    def create_streaming_response(self, stream: Iterator[Any]) -> StreamingResponse:
+        return convert_to_async_iterator(response) if stream else response
+
+    def create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse:
         """
         Create a streaming response from a stream generator. The StreamingResponse
         is the format that FastAPI expects for streaming responses.
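
With these changes execute_completion hands back an async iterator when streaming, and llamacpp_stream_generator only has to json.dumps each chunk dict into an SSE frame. A rough usage sketch, assuming a llama.cpp-style chunk dict (the chunk contents and model name below are invented for illustration):

import asyncio

from codegate.providers.llamacpp.completion_handler import (
    convert_to_async_iterator,
    llamacpp_stream_generator,
)

# CreateChatCompletionStreamResponse is a TypedDict, so plain dicts stand in here.
fake_chunks = iter(
    [
        {
            "id": "chunk-1",
            "model": "example-model",
            "object": "chat.completion.chunk",
            "created": 1732838400,
            "choices": [
                {
                    "index": 0,
                    "delta": {"role": "assistant", "content": "Hi"},
                    "finish_reason": None,
                }
            ],
        }
    ]
)


async def main() -> None:
    async_stream = convert_to_async_iterator(fake_chunks)
    async for sse_line in llamacpp_stream_generator(async_stream):
        # Each item is an OpenAI-style SSE frame ("data:{...}\n\n"),
        # with a final "data: [DONE]\n\n" once the stream is exhausted.
        print(sse_line, end="")


asyncio.run(main())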
102 changes: 95 additions & 7 deletions src/codegate/providers/llamacpp/normalizer.py
@@ -1,6 +1,13 @@
-from typing import Any, AsyncIterable, AsyncIterator, Dict, Iterable, Iterator, Union
+from typing import Any, AsyncIterable, AsyncIterator, Dict, Union

 from litellm import ChatCompletionRequest, ModelResponse
+from litellm.types.utils import Delta, StreamingChoices
+from llama_cpp.llama_types import (
+    ChatCompletionStreamResponseChoice,
+    ChatCompletionStreamResponseDelta,
+    ChatCompletionStreamResponseDeltaEmpty,
+    CreateChatCompletionStreamResponse,
+)

 from codegate.providers.normalizer import ModelInputNormalizer, ModelOutputNormalizer

@@ -32,16 +39,97 @@ def denormalize(self, data: ChatCompletionRequest) -> Dict:
         return data


+class ModelToLlamaCpp(AsyncIterator[CreateChatCompletionStreamResponse]):
+    def __init__(self, normalized_reply: AsyncIterable[ModelResponse]):
+        self.normalized_reply = normalized_reply
+        self._aiter = normalized_reply.__aiter__()
+
+    def __aiter__(self):
+        return self
+
+    @staticmethod
+    def _create_delta(
+        choice_delta: Delta,
+    ) -> Union[ChatCompletionStreamResponseDelta, ChatCompletionStreamResponseDeltaEmpty]:
+        if not choice_delta:
+            return ChatCompletionStreamResponseDeltaEmpty()
+        return ChatCompletionStreamResponseDelta(
+            content=choice_delta.content,
+            role=choice_delta.role,
+        )
+
+    async def __anext__(self) -> CreateChatCompletionStreamResponse:
+        try:
+            chunk = await self._aiter.__anext__()
+            return CreateChatCompletionStreamResponse(
+                id=chunk["id"],
+                model=chunk["model"],
+                object="chat.completion.chunk",
+                created=chunk["created"],
+                choices=[
+                    ChatCompletionStreamResponseChoice(
+                        index=choice.index,
+                        delta=self._create_delta(choice.delta),
+                        finish_reason=choice.finish_reason,
+                        logprobs=None,
+                    )
+                    for choice in chunk["choices"]
+                ],
+            )
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+
+
+class LlamaCppToModel(AsyncIterator[ModelResponse]):
+    def __init__(self, normalized_reply: AsyncIterable[CreateChatCompletionStreamResponse]):
+        self.normalized_reply = normalized_reply
+        self._aiter = normalized_reply.__aiter__()
+
+    def __aiter__(self):
+        return self
+
+    @staticmethod
+    def _create_delta(
+        choice_delta: Union[
+            ChatCompletionStreamResponseDelta, ChatCompletionStreamResponseDeltaEmpty
+        ]
+    ) -> Delta:
+        if not choice_delta:  # Handles empty dict case
+            return Delta(content=None, role=None)
+        return Delta(content=choice_delta.get("content"), role=choice_delta.get("role"))
+
+    async def __anext__(self) -> ModelResponse:
+        try:
+            chunk = await self._aiter.__anext__()
+            return ModelResponse(
+                id=chunk["id"],
+                choices=[
+                    StreamingChoices(
+                        finish_reason=choice.get("finish_reason", None),
+                        index=choice["index"],
+                        delta=self._create_delta(choice.get("delta")),
+                        logprobs=None,
+                    )
+                    for choice in chunk["choices"]
+                ],
+                created=chunk["created"],
+                model=chunk["model"],
+                object=chunk["object"],
+            )
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+
+
 class LLamaCppOutputNormalizer(ModelOutputNormalizer):
     def normalize_streaming(
         self,
-        model_reply: Union[AsyncIterable[Any], Iterable[Any]],
-    ) -> Union[AsyncIterator[ModelResponse], Iterator[ModelResponse]]:
+        llamacpp_stream: AsyncIterable[CreateChatCompletionStreamResponse],
+    ) -> AsyncIterator[ModelResponse]:
         """
         Normalize the output stream. This is a pass-through for liteLLM output normalizer
         as the liteLLM output is already in the normalized format.
         """
-        return model_reply
+        return LlamaCppToModel(llamacpp_stream)

     def normalize(self, model_reply: Any) -> ModelResponse:
         """
@@ -59,10 +147,10 @@ def denormalize(self, normalized_reply: ModelResponse) -> Any:

     def denormalize_streaming(
         self,
-        normalized_reply: Union[AsyncIterable[ModelResponse], Iterable[ModelResponse]],
-    ) -> Union[AsyncIterator[Any], Iterator[Any]]:
+        model_stream: AsyncIterable[ModelResponse],
+    ) -> AsyncIterator[CreateChatCompletionStreamResponse]:
         """
         Denormalize the output stream from the completion function to the format
         expected by the client
         """
-        return normalized_reply
+        return ModelToLlamaCpp(model_stream)
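
The output normalizer now does real work in both directions: LlamaCppToModel lifts llama.cpp chunk dicts into litellm ModelResponse objects for the (future) output pipeline, and ModelToLlamaCpp lowers them back so the completion handler can serialize them as SSE. An illustrative round trip, assuming the same litellm and llama-cpp-python dependencies this repo already uses (the chunk contents are invented):

import asyncio

from codegate.providers.llamacpp.normalizer import LLamaCppOutputNormalizer


async def fake_llamacpp_stream():
    # llama.cpp stream chunks are TypedDicts, so a plain dict literal works here.
    yield {
        "id": "chunk-1",
        "model": "example-model",
        "object": "chat.completion.chunk",
        "created": 1732838400,
        "choices": [
            {
                "index": 0,
                "delta": {"role": "assistant", "content": "Hello"},
                "finish_reason": None,
            }
        ],
    }


async def main() -> None:
    normalizer = LLamaCppOutputNormalizer()

    # llama.cpp dict -> litellm ModelResponse (the normalized form used by the pipeline)
    normalized = normalizer.normalize_streaming(fake_llamacpp_stream())
    # ModelResponse -> llama.cpp dict again (what llamacpp_stream_generator serializes)
    denormalized = normalizer.denormalize_streaming(normalized)

    async for chunk in denormalized:
        print(chunk["choices"][0]["delta"]["content"])  # -> Hello


asyncio.run(main())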
2 changes: 0 additions & 2 deletions src/codegate/server.py
@@ -6,8 +6,6 @@
 from codegate.config import Config
 from codegate.pipeline.base import PipelineStep, SequentialPipelineProcessor
 from codegate.pipeline.codegate_system_prompt.codegate import CodegateSystemPrompt
-from codegate.pipeline.secrets.secrets import CodegateSecrets
-from codegate.pipeline.secrets.signatures import CodegateSignatures
 from codegate.pipeline.version.version import CodegateVersion
 from codegate.providers.anthropic.provider import AnthropicProvider
 from codegate.providers.llamacpp.provider import LlamaCppProvider
@@ -24,61 +24,50 @@ def test_init_with_system_message(self):
         step = CodegateSystemPrompt(system_prompt_message=test_message)
         assert step._system_message["content"] == test_message

-    @pytest.mark.parametrize("user_message,expected_modification", [
-        # Test cases with different scenarios
-        ("Hello CodeGate", True),
-        ("CODEGATE in uppercase", True),
-        ("No matching message", False),
-        ("codegate with lowercase", True)
-    ])
-    async def test_process_system_prompt_insertion(
-        self,
-        user_message,
-        expected_modification
-    ):
+    @pytest.mark.parametrize(
+        "user_message,expected_modification",
+        [
+            # Test cases with different scenarios
+            ("Hello CodeGate", True),
+            ("CODEGATE in uppercase", True),
+            ("No matching message", False),
+            ("codegate with lowercase", True),
+        ],
+    )
+    async def test_process_system_prompt_insertion(self, user_message, expected_modification):
         """
         Test system prompt insertion based on message content
         """
         # Prepare mock request with user message
-        mock_request = {
-            "messages": [
-                {"role": "user", "content": user_message}
-            ]
-        }
+        mock_request = {"messages": [{"role": "user", "content": user_message}]}
         mock_context = Mock(spec=PipelineContext)

         # Create system prompt step
         system_prompt = "Security analysis system prompt"
         step = CodegateSystemPrompt(system_prompt_message=system_prompt)

         # Mock the get_last_user_message method
-        step.get_last_user_message = Mock(
-            return_value=(user_message, 0)
-        )
+        step.get_last_user_message = Mock(return_value=(user_message, 0))

         # Process the request
         result = await step.process(ChatCompletionRequest(**mock_request), mock_context)

         if expected_modification:
             # Check that system message was inserted
-            assert len(result.request['messages']) == 2
-            assert result.request['messages'][0]['role'] == 'system'
-            assert result.request['messages'][0]['content'] == system_prompt
-            assert result.request['messages'][1]['role'] == 'user'
-            assert result.request['messages'][1]['content'] == user_message
+            assert len(result.request["messages"]) == 2
+            assert result.request["messages"][0]["role"] == "system"
+            assert result.request["messages"][0]["content"] == system_prompt
+            assert result.request["messages"][1]["role"] == "user"
+            assert result.request["messages"][1]["content"] == user_message
         else:
             # Ensure no modification occurred
-            assert len(result.request['messages']) == 1
+            assert len(result.request["messages"]) == 1

     async def test_no_system_message_configured(self):
         """
         Test behavior when no system message is configured
         """
-        mock_request = {
-            "messages": [
-                {"role": "user", "content": "CodeGate test"}
-            ]
-        }
+        mock_request = {"messages": [{"role": "user", "content": "CodeGate test"}]}
         mock_context = Mock(spec=PipelineContext)

         # Create step without system message
@@ -90,10 +79,13 @@ async def test_no_system_message_configured(self):
         # Verify request remains unchanged
         assert result.request == mock_request

-    @pytest.mark.parametrize("edge_case", [
-        None,  # No messages
-        [],  # Empty messages list
-    ])
+    @pytest.mark.parametrize(
+        "edge_case",
+        [
+            None,  # No messages
+            [],  # Empty messages list
+        ],
+    )
     async def test_edge_cases(self, edge_case):
         """
         Test edge cases with None or empty message list