This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Fix linting errors in CI #89

Merged · 3 commits · Nov 25, 2024
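The changes in this PR are consistent with black-style formatting (double quotes, trailing commas, two blank lines between top-level definitions) and isort-style import grouping. As a rough local pre-check mirroring such a CI lint job, one might run something like the following in this Poetry-managed project; the tool names and flags are assumptions, since the actual CI configuration is not shown on this page:

# Hypothetical lint pre-check; black/isort and their flags are assumed, not taken from this repo's CI config.
import subprocess
import sys

checks = [
    ["poetry", "run", "black", "--check", "."],
    ["poetry", "run", "isort", "--check-only", "."],
]

# Exit non-zero if any check reports files that would be reformatted.
sys.exit(1 if any(subprocess.run(cmd).returncode != 0 for cmd in checks) else 0)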
884 changes: 878 additions & 6 deletions poetry.lock

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions scripts/import_packages.py
@@ -1,9 +1,10 @@
import json
from utils.embedding_util import generate_embeddings

import weaviate
from weaviate.classes.config import DataType, Property
from weaviate.embedded import EmbeddedOptions
from weaviate.classes.config import Property, DataType

from utils.embedding_util import generate_embeddings

json_files = [
"data/archived.jsonl",
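The import reordering above follows the conventional standard-library / third-party / local grouping that an import-sorting linter (for example isort; the exact tool is not shown here) enforces. As implied by this hunk, the top of scripts/import_packages.py after the change would read:

import json

import weaviate
from weaviate.classes.config import DataType, Property
from weaviate.embedded import EmbeddedOptions

from utils.embedding_util import generate_embeddings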
52 changes: 28 additions & 24 deletions tests/providers/anthropic/test_adapter.py
@@ -19,38 +19,35 @@
def adapter():
return AnthropicAdapter()


def test_translate_completion_input_params(adapter):
# Test input data
completion_request = {
"model": "claude-3-haiku-20240307",
"max_tokens": 1024,
"stream": True,
"messages": [
{
"role": "user",
"system": "You are an expert code reviewer",
"content": [
{
"type": "text",
"text": "Review this code"
}
]
}
]
{
"role": "user",
"system": "You are an expert code reviewer",
"content": [{"type": "text", "text": "Review this code"}],
}
],
}
expected = {
'max_tokens': 1024,
'messages': [
{'content': [{'text': 'Review this code', 'type': 'text'}], 'role': 'user'}
"max_tokens": 1024,
"messages": [
{"content": [{"text": "Review this code", "type": "text"}], "role": "user"}
],
'model': 'claude-3-haiku-20240307',
'stream': True
"model": "claude-3-haiku-20240307",
"stream": True,
}

# Get translation
result = adapter.translate_completion_input_params(completion_request)
assert result == expected


@pytest.mark.asyncio
async def test_translate_completion_output_params_streaming(adapter):
# Test stream data
@@ -62,33 +59,38 @@ async def mock_stream():
StreamingChoices(
finish_reason=None,
index=0,
delta=Delta(content="Hello", role="assistant")),
delta=Delta(content="Hello", role="assistant"),
),
],
model="claude-3-haiku-20240307",
),
ModelResponse(
id="test_id_2",
choices=[
StreamingChoices(finish_reason=None,
index=0,
delta=Delta(content="world", role="assistant")),
StreamingChoices(
finish_reason=None,
index=0,
delta=Delta(content="world", role="assistant"),
),
],
model="claude-3-haiku-20240307",
),
ModelResponse(
id="test_id_2",
choices=[
StreamingChoices(finish_reason=None,
index=0,
delta=Delta(content="!", role="assistant")),
StreamingChoices(
finish_reason=None,
index=0,
delta=Delta(content="!", role="assistant"),
),
],
model="claude-3-haiku-20240307",
),
]
for msg in messages:
yield msg

expected: List[Union[MessageStartBlock,ContentBlockStart,ContentBlockDelta]] = [
expected: List[Union[MessageStartBlock, ContentBlockStart, ContentBlockDelta]] = [
MessageStartBlock(
type="message_start",
message=MessageChunk(
@@ -142,8 +144,10 @@ async def mock_stream():
def test_stream_generator_initialization(adapter):
# Verify the default stream generator is set
from codegate.providers.litellmshim import anthropic_stream_generator

assert adapter.stream_generator == anthropic_stream_generator


def test_custom_stream_generator():
# Test that we can inject a custom stream generator
async def custom_generator(stream: AsyncIterator[Dict]) -> AsyncIterator[str]:
6 changes: 4 additions & 2 deletions tests/providers/litellmshim/test_generators.py
@@ -14,7 +14,7 @@ async def test_sse_stream_generator():
# Mock stream data
mock_chunks = [
ModelResponse(id="1", choices=[{"text": "Hello"}]),
ModelResponse(id="2", choices=[{"text": "World"}])
ModelResponse(id="2", choices=[{"text": "World"}]),
]

async def mock_stream():
@@ -33,13 +33,14 @@ async def mock_stream():
assert "World" in messages[1]
assert messages[-1] == "data: [DONE]\n\n"


@pytest.mark.asyncio
async def test_anthropic_stream_generator():
# Mock Anthropic-style chunks
mock_chunks = [
{"type": "message_start", "message": {"id": "1"}},
{"type": "content_block_start", "content_block": {"text": "Hello"}},
{"type": "content_block_stop", "content_block": {"text": "World"}}
{"type": "content_block_stop", "content_block": {"text": "World"}},
]

async def mock_stream():
@@ -58,6 +59,7 @@ async def mock_stream():
assert "Hello" in messages[1] # content_block_start message
assert "World" in messages[2] # content_block_stop message


@pytest.mark.asyncio
async def test_generators_error_handling():
async def error_stream() -> AsyncIterator[str]:
18 changes: 13 additions & 5 deletions tests/providers/litellmshim/test_litellmshim.py
@@ -27,22 +27,27 @@ def translate_completion_output_params(self, response: ModelResponse) -> Any:
return response

def translate_completion_output_params_streaming(
self, completion_stream: Any,
self,
completion_stream: Any,
) -> Any:
async def modified_stream():
async for chunk in completion_stream:
chunk.mock_adapter_processed = True
yield chunk

return modified_stream()


@pytest.fixture
def mock_adapter():
return MockAdapter()


@pytest.fixture
def litellm_shim(mock_adapter):
return LiteLLmShim(mock_adapter)


@pytest.mark.asyncio
async def test_complete_non_streaming(litellm_shim, mock_adapter):
# Mock response
@@ -55,7 +60,7 @@ async def test_complete_non_streaming(litellm_shim, mock_adapter):
# Test data
data = {
"messages": [{"role": "user", "content": "Hello"}],
"model": "gpt-3.5-turbo"
"model": "gpt-3.5-turbo",
}
api_key = "test-key"

@@ -71,6 +76,7 @@ async def test_complete_non_streaming(litellm_shim, mock_adapter):
# Verify adapter processed the input
assert called_args["mock_adapter_processed"] is True


@pytest.mark.asyncio
async def test_complete_streaming():
# Mock streaming response with specific test content
@@ -86,7 +92,7 @@ async def mock_stream() -> AsyncIterator[ModelResponse]:
data = {
"messages": [{"role": "user", "content": "Hello"}],
"model": "gpt-3.5-turbo",
"stream": True
"stream": True,
}
api_key = "test-key"

@@ -114,6 +120,7 @@ async def mock_stream() -> AsyncIterator[ModelResponse]:
assert called_args["stream"] is True
assert called_args["api_key"] == api_key


@pytest.mark.asyncio
async def test_create_streaming_response(litellm_shim):
# Create a simple async generator that we know works
@@ -133,6 +140,7 @@ async def mock_stream_gen():
assert response.headers["Connection"] == "keep-alive"
assert response.headers["Transfer-Encoding"] == "chunked"


@pytest.mark.asyncio
async def test_complete_invalid_params():
mock_completion = AsyncMock()
@@ -148,8 +156,8 @@ async def test_complete_invalid_params():

# Execute and verify specific exception is raised
with pytest.raises(
ValueError,
match="Required fields 'messages' and 'model' must be present",
ValueError,
match="Required fields 'messages' and 'model' must be present",
):
await litellm_shim.complete(data, api_key)

10 changes: 9 additions & 1 deletion tests/providers/test_registry.py
@@ -13,42 +13,50 @@ async def complete(self, data: Dict, api_key: str) -> AsyncIterator[Any]:
yield "test"

def create_streaming_response(
self, stream: AsyncIterator[Any],
self,
stream: AsyncIterator[Any],
) -> StreamingResponse:
return StreamingResponse(stream)


class MockProvider(BaseProvider):
def _setup_routes(self) -> None:
@self.router.get("/test")
def test_route():
return {"message": "test"}


@pytest.fixture
def mock_completion_handler():
return MockCompletionHandler()


@pytest.fixture
def app():
return FastAPI()


@pytest.fixture
def registry(app):
return ProviderRegistry(app)


def test_add_provider(registry, mock_completion_handler):
provider = MockProvider(mock_completion_handler)
registry.add_provider("test", provider)

assert "test" in registry.providers
assert registry.providers["test"] == provider


def test_get_provider(registry, mock_completion_handler):
provider = MockProvider(mock_completion_handler)
registry.add_provider("test", provider)

assert registry.get_provider("test") == provider
assert registry.get_provider("nonexistent") is None


def test_provider_routes_added(app, registry, mock_completion_handler):
provider = MockProvider(mock_completion_handler)
registry.add_provider("test", provider)
27 changes: 13 additions & 14 deletions utils/embedding_util.py
@@ -1,40 +1,39 @@
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as f
from torch import Tensor
import os
import warnings

import torch
import torch.nn.functional as ftorch
from torch import Tensor
from transformers import AutoModel, AutoTokenizer

# The transformers library internally is creating this warning, but does not
# impact our app. Safe to ignore.
warnings.filterwarnings(action='ignore', category=ResourceWarning)
warnings.filterwarnings(action="ignore", category=ResourceWarning)


# We won't have competing threads in this example app
os.environ["TOKENIZERS_PARALLELISM"] = "false"


# Initialize tokenizer and model for GTE-base
tokenizer = AutoTokenizer.from_pretrained('thenlper/gte-base')
model = AutoModel.from_pretrained('thenlper/gte-base')
tokenizer = AutoTokenizer.from_pretrained("thenlper/gte-base")
model = AutoModel.from_pretrained("thenlper/gte-base")


def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
last_hidden = last_hidden_states.masked_fill(
~attention_mask[..., None].bool(), 0.0)
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


def generate_embeddings(text):
inputs = tokenizer(text, return_tensors='pt',
max_length=512, truncation=True)
inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
with torch.no_grad():
outputs = model(**inputs)

attention_mask = inputs['attention_mask']
attention_mask = inputs["attention_mask"]
embeddings = average_pool(outputs.last_hidden_state, attention_mask)

# (Optionally) normalize embeddings
embeddings = f.normalize(embeddings, p=2, dim=1)
embeddings = ftorch.normalize(embeddings, p=2, dim=1)

return embeddings.numpy().tolist()[0]
return embeddings.numpy().tolist()[0]
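
For context, a minimal usage sketch of the reformatted helper, based only on what is visible in this diff; the input string is illustrative:

from utils.embedding_util import generate_embeddings

embedding = generate_embeddings("Review this code")
# gte-base produces 768-dimensional sentence embeddings, so this should print 768 for that model.
print(len(embedding))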