This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit b975b57

Merge pull request #89 from stacklok/fix-ci

Fix linting errors in CI

Author: Luke Hinds
2 parents: fb68dc1 + 8e10b50

7 files changed: +948, -54 lines

poetry.lock

Lines changed: 878 additions & 6 deletions (generated lockfile; diff not rendered)

scripts/import_packages.py

Lines changed: 3 additions & 2 deletions

@@ -1,9 +1,10 @@
 import json
-from utils.embedding_util import generate_embeddings
+
 import weaviate
+from weaviate.classes.config import DataType, Property
 from weaviate.embedded import EmbeddedOptions
-from weaviate.classes.config import Property, DataType
 
+from utils.embedding_util import generate_embeddings
 
 json_files = [
     "data/archived.jsonl",

tests/providers/anthropic/test_adapter.py

Lines changed: 28 additions & 24 deletions

@@ -19,38 +19,35 @@
 def adapter():
     return AnthropicAdapter()
 
+
 def test_translate_completion_input_params(adapter):
     # Test input data
     completion_request = {
         "model": "claude-3-haiku-20240307",
         "max_tokens": 1024,
         "stream": True,
         "messages": [
-            {
-                "role": "user",
-                "system": "You are an expert code reviewer",
-                "content": [
-                    {
-                        "type": "text",
-                        "text": "Review this code"
-                    }
-                ]
-            }
-        ]
+            {
+                "role": "user",
+                "system": "You are an expert code reviewer",
+                "content": [{"type": "text", "text": "Review this code"}],
+            }
+        ],
     }
     expected = {
-        'max_tokens': 1024,
-        'messages': [
-            {'content': [{'text': 'Review this code', 'type': 'text'}], 'role': 'user'}
+        "max_tokens": 1024,
+        "messages": [
+            {"content": [{"text": "Review this code", "type": "text"}], "role": "user"}
         ],
-        'model': 'claude-3-haiku-20240307',
-        'stream': True
+        "model": "claude-3-haiku-20240307",
+        "stream": True,
     }
 
     # Get translation
     result = adapter.translate_completion_input_params(completion_request)
     assert result == expected
 
+
 @pytest.mark.asyncio
 async def test_translate_completion_output_params_streaming(adapter):
     # Test stream data

@@ -62,33 +59,38 @@ async def mock_stream():
                     StreamingChoices(
                         finish_reason=None,
                         index=0,
-                        delta=Delta(content="Hello", role="assistant")),
+                        delta=Delta(content="Hello", role="assistant"),
+                    ),
                 ],
                 model="claude-3-haiku-20240307",
             ),
             ModelResponse(
                 id="test_id_2",
                 choices=[
-                    StreamingChoices(finish_reason=None,
-                                     index=0,
-                                     delta=Delta(content="world", role="assistant")),
+                    StreamingChoices(
+                        finish_reason=None,
+                        index=0,
+                        delta=Delta(content="world", role="assistant"),
+                    ),
                 ],
                 model="claude-3-haiku-20240307",
             ),
             ModelResponse(
                 id="test_id_2",
                 choices=[
-                    StreamingChoices(finish_reason=None,
-                                     index=0,
-                                     delta=Delta(content="!", role="assistant")),
+                    StreamingChoices(
+                        finish_reason=None,
+                        index=0,
+                        delta=Delta(content="!", role="assistant"),
+                    ),
                 ],
                 model="claude-3-haiku-20240307",
             ),
         ]
         for msg in messages:
             yield msg
 
-    expected: List[Union[MessageStartBlock,ContentBlockStart,ContentBlockDelta]] = [
+    expected: List[Union[MessageStartBlock, ContentBlockStart, ContentBlockDelta]] = [
         MessageStartBlock(
             type="message_start",
             message=MessageChunk(

@@ -142,8 +144,10 @@ async def mock_stream():
 def test_stream_generator_initialization(adapter):
     # Verify the default stream generator is set
     from codegate.providers.litellmshim import anthropic_stream_generator
+
     assert adapter.stream_generator == anthropic_stream_generator
 
+
 def test_custom_stream_generator():
     # Test that we can inject a custom stream generator
     async def custom_generator(stream: AsyncIterator[Dict]) -> AsyncIterator[str]:
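
Note: only the signature of custom_generator appears in this hunk. A minimal sketch of an injected generator with that signature, assuming dict chunks and the SSE framing used elsewhere in these tests (illustrative only, not the repository's implementation):

import json
from typing import AsyncIterator, Dict


async def custom_generator(stream: AsyncIterator[Dict]) -> AsyncIterator[str]:
    # Hypothetical body: emit each chunk as an SSE "data:" line and finish
    # with the [DONE] sentinel that the other streaming tests expect.
    async for chunk in stream:
        yield f"data: {json.dumps(chunk)}\n\n"
    yield "data: [DONE]\n\n"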

tests/providers/litellmshim/test_generators.py

Lines changed: 4 additions & 2 deletions

@@ -14,7 +14,7 @@ async def test_sse_stream_generator():
     # Mock stream data
     mock_chunks = [
         ModelResponse(id="1", choices=[{"text": "Hello"}]),
-        ModelResponse(id="2", choices=[{"text": "World"}])
+        ModelResponse(id="2", choices=[{"text": "World"}]),
     ]
 
     async def mock_stream():

@@ -33,13 +33,14 @@ async def mock_stream():
     assert "World" in messages[1]
     assert messages[-1] == "data: [DONE]\n\n"
 
+
 @pytest.mark.asyncio
 async def test_anthropic_stream_generator():
     # Mock Anthropic-style chunks
     mock_chunks = [
         {"type": "message_start", "message": {"id": "1"}},
         {"type": "content_block_start", "content_block": {"text": "Hello"}},
-        {"type": "content_block_stop", "content_block": {"text": "World"}}
+        {"type": "content_block_stop", "content_block": {"text": "World"}},
     ]
 
     async def mock_stream():

@@ -58,6 +59,7 @@ async def mock_stream():
     assert "Hello" in messages[1]  # content_block_start message
     assert "World" in messages[2]  # content_block_stop message
 
+
 @pytest.mark.asyncio
 async def test_generators_error_handling():
     async def error_stream() -> AsyncIterator[str]:
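
Note: the visible assertions pin down part of the generator contract: each chunk is emitted as an SSE "data:" line and the stream terminates with "data: [DONE]\n\n"; test_generators_error_handling (body not shown in this hunk) exercises the error path. A sketch consistent with that contract, assuming dict or string chunks; the actual codegate generators may differ:

import json
from typing import Any, AsyncIterator


async def sketch_stream_generator(stream: AsyncIterator[Any]) -> AsyncIterator[str]:
    # Hypothetical generator illustrating the SSE framing the tests check.
    try:
        async for chunk in stream:
            payload = chunk if isinstance(chunk, str) else json.dumps(chunk)
            yield f"data: {payload}\n\n"
    except Exception as exc:
        # Surface upstream errors to the client instead of breaking the stream.
        yield f"data: {str(exc)}\n\n"
    finally:
        yield "data: [DONE]\n\n"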

tests/providers/litellmshim/test_litellmshim.py

Lines changed: 13 additions & 5 deletions

@@ -27,22 +27,27 @@ def translate_completion_output_params(self, response: ModelResponse) -> Any:
         return response
 
     def translate_completion_output_params_streaming(
-        self, completion_stream: Any,
+        self,
+        completion_stream: Any,
     ) -> Any:
         async def modified_stream():
             async for chunk in completion_stream:
                 chunk.mock_adapter_processed = True
                 yield chunk
+
         return modified_stream()
 
+
 @pytest.fixture
 def mock_adapter():
     return MockAdapter()
 
+
 @pytest.fixture
 def litellm_shim(mock_adapter):
     return LiteLLmShim(mock_adapter)
 
+
 @pytest.mark.asyncio
 async def test_complete_non_streaming(litellm_shim, mock_adapter):
     # Mock response

@@ -55,7 +60,7 @@ async def test_complete_non_streaming(litellm_shim, mock_adapter):
     # Test data
     data = {
         "messages": [{"role": "user", "content": "Hello"}],
-        "model": "gpt-3.5-turbo"
+        "model": "gpt-3.5-turbo",
     }
     api_key = "test-key"
 

@@ -71,6 +76,7 @@ async def test_complete_non_streaming(litellm_shim, mock_adapter):
     # Verify adapter processed the input
     assert called_args["mock_adapter_processed"] is True
 
+
 @pytest.mark.asyncio
 async def test_complete_streaming():
     # Mock streaming response with specific test content

@@ -86,7 +92,7 @@ async def mock_stream() -> AsyncIterator[ModelResponse]:
     data = {
         "messages": [{"role": "user", "content": "Hello"}],
         "model": "gpt-3.5-turbo",
-        "stream": True
+        "stream": True,
     }
     api_key = "test-key"
 

@@ -114,6 +120,7 @@ async def mock_stream() -> AsyncIterator[ModelResponse]:
     assert called_args["stream"] is True
     assert called_args["api_key"] == api_key
 
+
 @pytest.mark.asyncio
 async def test_create_streaming_response(litellm_shim):
     # Create a simple async generator that we know works

@@ -133,6 +140,7 @@ async def mock_stream_gen():
     assert response.headers["Connection"] == "keep-alive"
     assert response.headers["Transfer-Encoding"] == "chunked"
 
+
 @pytest.mark.asyncio
 async def test_complete_invalid_params():
     mock_completion = AsyncMock()

@@ -148,8 +156,8 @@ async def test_complete_invalid_params():
 
     # Execute and verify specific exception is raised
     with pytest.raises(
-            ValueError,
-            match="Required fields 'messages' and 'model' must be present",
+        ValueError,
+        match="Required fields 'messages' and 'model' must be present",
     ):
         await litellm_shim.complete(data, api_key)
 
tests/providers/test_registry.py

Lines changed: 9 additions & 1 deletion

@@ -13,42 +13,50 @@ async def complete(self, data: Dict, api_key: str) -> AsyncIterator[Any]:
         yield "test"
 
     def create_streaming_response(
-        self, stream: AsyncIterator[Any],
+        self,
+        stream: AsyncIterator[Any],
     ) -> StreamingResponse:
         return StreamingResponse(stream)
 
+
 class MockProvider(BaseProvider):
     def _setup_routes(self) -> None:
         @self.router.get("/test")
         def test_route():
             return {"message": "test"}
 
+
 @pytest.fixture
 def mock_completion_handler():
     return MockCompletionHandler()
 
+
 @pytest.fixture
 def app():
     return FastAPI()
 
+
 @pytest.fixture
 def registry(app):
     return ProviderRegistry(app)
 
+
 def test_add_provider(registry, mock_completion_handler):
     provider = MockProvider(mock_completion_handler)
     registry.add_provider("test", provider)
 
     assert "test" in registry.providers
     assert registry.providers["test"] == provider
 
+
 def test_get_provider(registry, mock_completion_handler):
     provider = MockProvider(mock_completion_handler)
     registry.add_provider("test", provider)
 
     assert registry.get_provider("test") == provider
     assert registry.get_provider("nonexistent") is None
 
+
 def test_provider_routes_added(app, registry, mock_completion_handler):
     provider = MockProvider(mock_completion_handler)
     registry.add_provider("test", provider)
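
Note: these tests only assert observable behaviour: an added provider is stored in registry.providers under its name, get_provider returns None for unknown names, and adding a provider mounts its routes on the FastAPI app. A minimal sketch consistent with those assertions (hypothetical; the actual ProviderRegistry may differ, for example in route prefixes):

from fastapi import FastAPI


class ProviderRegistrySketch:
    """Hypothetical stand-in showing only the behaviour the tests assert."""

    def __init__(self, app: FastAPI) -> None:
        self._app = app
        self.providers = {}

    def add_provider(self, name: str, provider) -> None:
        # Keep the provider addressable by name and mount its routes.
        self.providers[name] = provider
        self._app.include_router(provider.router)

    def get_provider(self, name: str):
        # Unknown names return None rather than raising.
        return self.providers.get(name)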

utils/embedding_util.py

Lines changed: 13 additions & 14 deletions

@@ -1,40 +1,39 @@
-from transformers import AutoTokenizer, AutoModel
-import torch
-import torch.nn.functional as f
-from torch import Tensor
 import os
 import warnings
 
+import torch
+import torch.nn.functional as ftorch
+from torch import Tensor
+from transformers import AutoModel, AutoTokenizer
+
 # The transformers library internally is creating this warning, but does not
 # impact our app. Safe to ignore.
-warnings.filterwarnings(action='ignore', category=ResourceWarning)
+warnings.filterwarnings(action="ignore", category=ResourceWarning)
 
 
 # We won't have competing threads in this example app
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 
 # Initialize tokenizer and model for GTE-base
-tokenizer = AutoTokenizer.from_pretrained('thenlper/gte-base')
-model = AutoModel.from_pretrained('thenlper/gte-base')
+tokenizer = AutoTokenizer.from_pretrained("thenlper/gte-base")
+model = AutoModel.from_pretrained("thenlper/gte-base")
 
 
 def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
-    last_hidden = last_hidden_states.masked_fill(
-        ~attention_mask[..., None].bool(), 0.0)
+    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
 
 
 def generate_embeddings(text):
-    inputs = tokenizer(text, return_tensors='pt',
-                       max_length=512, truncation=True)
+    inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
     with torch.no_grad():
         outputs = model(**inputs)
 
-    attention_mask = inputs['attention_mask']
+    attention_mask = inputs["attention_mask"]
     embeddings = average_pool(outputs.last_hidden_state, attention_mask)
 
     # (Optionally) normalize embeddings
-    embeddings = f.normalize(embeddings, p=2, dim=1)
+    embeddings = ftorch.normalize(embeddings, p=2, dim=1)
 
-    return embeddings.numpy().tolist()[0]
+    return embeddings.numpy().tolist()[0]
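
Note: generate_embeddings takes a string and returns a plain Python list of floats, the mean-pooled and L2-normalized sentence embedding (768-dimensional for thenlper/gte-base). A small usage sketch, assuming the model weights are available locally or downloadable:

from utils.embedding_util import generate_embeddings

# Inputs longer than 512 tokens are truncated by the tokenizer call above.
vector = generate_embeddings("archived package with known malicious releases")
print(len(vector))   # 768 for gte-base
print(vector[:3])    # first few components (values are illustrative)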
