This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 1b711ff

Keep the code coverage high (#80)
We still need to add unit tests for OpenAI; we will add them in a separate patch.
1 parent e48b1c5 commit 1b711ff

7 files changed: +457 −3 lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@ bandit = ">=1.7.10"
 build = ">=1.0.0"
 wheel = ">=0.40.0"
 litellm = ">=1.52.11"
+pytest-asyncio = "0.24.0"

 [build-system]
 requires = ["poetry-core"]

src/codegate/providers/litellmshim/generators.py

Lines changed: 5 additions & 1 deletion
@@ -1,6 +1,8 @@
 import json
 from typing import Any, AsyncIterator

+from pydantic import BaseModel
+
 # Since different providers typically use one of these formats for streaming
 # responses, we have a single stream generator for each format that is then plugged
 # into the adapter.
@@ -10,7 +12,9 @@ async def sse_stream_generator(stream: AsyncIterator[Any]) -> AsyncIterator[str]
     """OpenAI-style SSE format"""
     try:
         async for chunk in stream:
-            if hasattr(chunk, "model_dump_json"):
+            if isinstance(chunk, BaseModel):
+                # alternatively we might want to just dump the whole object
+                # this might even allow us to tighten the typing of the stream
                 chunk = chunk.model_dump_json(exclude_none=True, exclude_unset=True)
             try:
                 yield f"data:{chunk}\n\n"

src/codegate/providers/litellmshim/litellmshim.py

Lines changed: 3 additions & 2 deletions
@@ -14,8 +14,9 @@ class LiteLLmShim(BaseCompletionHandler):
     LiteLLM API.
     """

-    def __init__(self, adapter: BaseAdapter):
+    def __init__(self, adapter: BaseAdapter, completion_func=acompletion):
         self._adapter = adapter
+        self._completion_func = completion_func

     async def complete(self, data: Dict, api_key: str) -> AsyncIterator[Any]:
         """
@@ -28,7 +29,7 @@ async def complete(self, data: Dict, api_key: str) -> AsyncIterator[Any]:
         if completion_request is None:
             raise Exception("Couldn't translate the request")

-        response = await acompletion(**completion_request)
+        response = await self._completion_func(**completion_request)

         if isinstance(response, ModelResponse):
             return self._adapter.translate_completion_output_params(response)
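
Accepting completion_func as a constructor argument (defaulting to litellm's acompletion) adds a small dependency-injection seam: unit tests can swap in a stub instead of calling a real provider. A hedged sketch of how a test might use it; the stub is ours, not code from this commit:

from codegate.providers.anthropic.adapter import AnthropicAdapter
from codegate.providers.litellmshim.litellmshim import LiteLLmShim


async def fake_completion(**kwargs):  # hypothetical stub; never hits a provider
    return kwargs  # echo the translated request back for assertions


shim = LiteLLmShim(adapter=AnthropicAdapter(), completion_func=fake_completion)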
Lines changed: 154 additions & 0 deletions
@@ -0,0 +1,154 @@
from typing import AsyncIterator, Dict, List, Union

import pytest
from litellm import ModelResponse
from litellm.adapters.anthropic_adapter import AnthropicStreamWrapper
from litellm.types.llms.anthropic import (
    ContentBlockDelta,
    ContentBlockStart,
    ContentTextBlockDelta,
    MessageChunk,
    MessageStartBlock,
)
from litellm.types.utils import Delta, StreamingChoices

from codegate.providers.anthropic.adapter import AnthropicAdapter


@pytest.fixture
def adapter():
    return AnthropicAdapter()

def test_translate_completion_input_params(adapter):
    # Test input data
    completion_request = {
        "model": "claude-3-haiku-20240307",
        "max_tokens": 1024,
        "stream": True,
        "messages": [
            {
                "role": "user",
                "system": "You are an expert code reviewer",
                "content": [
                    {
                        "type": "text",
                        "text": "Review this code"
                    }
                ]
            }
        ]
    }
    expected = {
        'max_tokens': 1024,
        'messages': [
            {'content': [{'text': 'Review this code', 'type': 'text'}], 'role': 'user'}
        ],
        'model': 'claude-3-haiku-20240307',
        'stream': True
    }

    # Get translation
    result = adapter.translate_completion_input_params(completion_request)
    assert result == expected

@pytest.mark.asyncio
async def test_translate_completion_output_params_streaming(adapter):
    # Test stream data
    async def mock_stream():
        messages = [
            ModelResponse(
                id="test_id_1",
                choices=[
                    StreamingChoices(
                        finish_reason=None,
                        index=0,
                        delta=Delta(content="Hello", role="assistant")),
                ],
                model="claude-3-haiku-20240307",
            ),
            ModelResponse(
                id="test_id_2",
                choices=[
                    StreamingChoices(finish_reason=None,
                                     index=0,
                                     delta=Delta(content="world", role="assistant")),
                ],
                model="claude-3-haiku-20240307",
            ),
            ModelResponse(
                id="test_id_2",
                choices=[
                    StreamingChoices(finish_reason=None,
                                     index=0,
                                     delta=Delta(content="!", role="assistant")),
                ],
                model="claude-3-haiku-20240307",
            ),
        ]
        for msg in messages:
            yield msg

    expected: List[Union[MessageStartBlock, ContentBlockStart, ContentBlockDelta]] = [
        MessageStartBlock(
            type="message_start",
            message=MessageChunk(
                id="msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY",
                type="message",
                role="assistant",
                content=[],
                # litellm makes up a message start block with hardcoded values
                model="claude-3-5-sonnet-20240620",
                stop_reason=None,
                stop_sequence=None,
                usage={"input_tokens": 25, "output_tokens": 1},
            ),
        ),
        ContentBlockStart(
            type="content_block_start",
            index=0,
            content_block={"type": "text", "text": ""},
        ),
        ContentBlockDelta(
            type="content_block_delta",
            index=0,
            delta=ContentTextBlockDelta(type="text_delta", text="Hello"),
        ),
        ContentBlockDelta(
            type="content_block_delta",
            index=0,
            delta=ContentTextBlockDelta(type="text_delta", text="world"),
        ),
        ContentBlockDelta(
            type="content_block_delta",
            index=0,
            delta=ContentTextBlockDelta(type="text_delta", text="!"),
        ),
        # litellm doesn't seem to have a type for message stop
        dict(type="message_stop"),
    ]

    stream = adapter.translate_completion_output_params_streaming(mock_stream())
    assert isinstance(stream, AnthropicStreamWrapper)

    # just so that we can zip over the expected chunks
    stream_list = [chunk async for chunk in stream]
    # Verify we got all chunks
    assert len(stream_list) == 6

    for chunk, expected_chunk in zip(stream_list, expected):
        assert chunk == expected_chunk


def test_stream_generator_initialization(adapter):
    # Verify the default stream generator is set
    from codegate.providers.litellmshim import anthropic_stream_generator
    assert adapter.stream_generator == anthropic_stream_generator

def test_custom_stream_generator():
    # Test that we can inject a custom stream generator
    async def custom_generator(stream: AsyncIterator[Dict]) -> AsyncIterator[str]:
        async for chunk in stream:
            yield "custom: " + str(chunk)

    adapter = AnthropicAdapter(stream_generator=custom_generator)
    assert adapter.stream_generator == custom_generator
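
The last test above only asserts that the injected generator is stored on the adapter. For illustration, here is that same custom_generator driven end to end; the driver below is ours, not part of the commit:

import asyncio
from typing import AsyncIterator, Dict


async def custom_generator(stream: AsyncIterator[Dict]) -> AsyncIterator[str]:
    async for chunk in stream:
        yield "custom: " + str(chunk)


async def main():
    async def chunks():  # two Anthropic-shaped chunks, purely illustrative
        yield {"type": "message_start"}
        yield {"type": "message_stop"}

    async for line in custom_generator(chunks()):
        print(line)  # custom: {'type': 'message_start'}, then message_stop


asyncio.run(main())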
Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
from typing import AsyncIterator

import pytest
from litellm import ModelResponse

from codegate.providers.litellmshim import (
    anthropic_stream_generator,
    sse_stream_generator,
)


@pytest.mark.asyncio
async def test_sse_stream_generator():
    # Mock stream data
    mock_chunks = [
        ModelResponse(id="1", choices=[{"text": "Hello"}]),
        ModelResponse(id="2", choices=[{"text": "World"}])
    ]

    async def mock_stream():
        for chunk in mock_chunks:
            yield chunk

    # Collect generated SSE messages
    messages = []
    async for message in sse_stream_generator(mock_stream()):
        messages.append(message)

    # Verify format and content
    assert len(messages) == len(mock_chunks) + 1  # +1 for the [DONE] message
    assert all(msg.startswith("data:") for msg in messages)
    assert "Hello" in messages[0]
    assert "World" in messages[1]
    assert messages[-1] == "data: [DONE]\n\n"

@pytest.mark.asyncio
async def test_anthropic_stream_generator():
    # Mock Anthropic-style chunks
    mock_chunks = [
        {"type": "message_start", "message": {"id": "1"}},
        {"type": "content_block_start", "content_block": {"text": "Hello"}},
        {"type": "content_block_stop", "content_block": {"text": "World"}}
    ]

    async def mock_stream():
        for chunk in mock_chunks:
            yield chunk

    # Collect generated SSE messages
    messages = []
    async for message in anthropic_stream_generator(mock_stream()):
        messages.append(message)

    # Verify format and content
    assert len(messages) == 3
    for msg, chunk in zip(messages, mock_chunks):
        assert msg.startswith(f"event: {chunk['type']}\ndata:")
    assert "Hello" in messages[1]  # content_block_start message
    assert "World" in messages[2]  # content_block_stop message

@pytest.mark.asyncio
async def test_generators_error_handling():
    async def error_stream() -> AsyncIterator[str]:
        raise Exception("Test error")
        yield  # This will never be reached, but is needed for AsyncIterator typing

    # Test SSE generator error handling
    messages = []
    async for message in sse_stream_generator(error_stream()):
        messages.append(message)
    assert len(messages) == 2
    assert "Test error" in messages[0]
    assert messages[1] == "data: [DONE]\n\n"

    # Test Anthropic generator error handling
    messages = []
    async for message in anthropic_stream_generator(error_stream()):
        messages.append(message)
    assert len(messages) == 1
    assert "Test error" in messages[0]
