Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Add integration tests with muxing #1035

Merged
merged 1 commit into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 50 additions & 38 deletions src/codegate/muxing/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ class BodyAdapter:

def _get_provider_formatted_url(self, model_route: rulematcher.ModelRoute) -> str:
"""Get the provider formatted URL to use in base_url. Note this value comes from DB"""
if model_route.endpoint.provider_type == db_models.ProviderType.openai:
if model_route.endpoint.provider_type in [
db_models.ProviderType.openai,
db_models.ProviderType.vllm,
]:
return urljoin(model_route.endpoint.endpoint, "/v1")
if model_route.endpoint.provider_type == db_models.ProviderType.openrouter:
return urljoin(model_route.endpoint.endpoint, "/api/v1")
Expand Down Expand Up @@ -90,6 +93,47 @@ def _format_openai(self, chunk: str) -> str:
cleaned_chunk = chunk.split("data:")[1].strip()
return cleaned_chunk

def _format_antropic(self, chunk: str) -> str:
    """
    Format the Anthropic chunk to OpenAI format.

    This function is used by both chat and FIM formatters

    Args:
        chunk: Raw SSE line from Anthropic, expected to contain a
            "data:" prefix followed by a JSON event payload.

    Returns:
        A JSON string shaped like an OpenAI streaming chunk, "" when the
        event carries no content block, or the cleaned raw payload if
        formatting fails.
    """
    # NOTE(review): raises IndexError if "data:" is absent from the chunk;
    # this happens outside the try block — confirm upstream only passes SSE
    # data lines here.
    cleaned_chunk = chunk.split("data:")[1].strip()
    try:
        chunk_dict = json.loads(cleaned_chunk)
        msg_type = chunk_dict.get("type", "")

        # Anthropic signals end-of-message with type == "message_stop";
        # map it to OpenAI's finish_reason "stop".
        finish_reason = None
        if msg_type == "message_stop":
            finish_reason = "stop"

        # In type == "content_block_start" the content comes in "content_block"
        # In type == "content_block_delta" the content comes in "delta"
        msg_content_dict = chunk_dict.get("delta", {}) or chunk_dict.get("content_block", {})
        # We couldn't obtain the content from the chunk. Skip it.
        if not msg_content_dict:
            return ""

        # Wrap the extracted text in an OpenAI-style streaming chunk so the
        # client sees a uniform format regardless of the muxed provider.
        msg_content = msg_content_dict.get("text", "")
        open_ai_chunk = ModelResponse(
            id=f"anthropic-chat-{str(uuid.uuid4())}",
            model="anthropic-muxed-model",
            object="chat.completion.chunk",
            choices=[
                StreamingChoices(
                    finish_reason=finish_reason,
                    index=0,
                    delta=Delta(content=msg_content, role="assistant"),
                    logprobs=None,
                )
            ],
        )
        return open_ai_chunk.model_dump_json(exclude_none=True, exclude_unset=True)
    except Exception as e:
        # Best-effort fallback: log the failure and pass the cleaned payload
        # through unchanged rather than dropping the chunk.
        logger.warning(f"Error formatting Anthropic chunk: {chunk}. Error: {e}")
        return cleaned_chunk.strip()

def _format_as_openai_chunk(self, formatted_chunk: str) -> str:
"""Format the chunk as OpenAI chunk. This is the format how the clients expect the data."""
chunk_to_send = f"data:{formatted_chunk}\n\n"
Expand Down Expand Up @@ -148,6 +192,8 @@ def provider_format_funcs(self) -> Dict[str, Callable]:
db_models.ProviderType.llamacpp: self._format_openai,
# OpenRouter is a dialect of OpenAI
db_models.ProviderType.openrouter: self._format_openai,
# VLLM is a dialect of OpenAI
db_models.ProviderType.vllm: self._format_openai,
}

def _format_ollama(self, chunk: str) -> str:
Expand All @@ -165,43 +211,6 @@ def _format_ollama(self, chunk: str) -> str:
logger.warning(f"Error formatting Ollama chunk: {chunk}. Error: {e}")
return chunk

def _format_antropic(self, chunk: str) -> str:
    """Format the Anthropic chunk to OpenAI format."""
    payload = chunk.split("data:")[1].strip()
    try:
        event = json.loads(payload)

        # Anthropic ends a message with a "message_stop" event; translate it
        # to OpenAI's finish_reason.
        stop_reason = "stop" if event.get("type", "") == "message_stop" else None

        # "content_block_start" events carry text under "content_block",
        # "content_block_delta" events under "delta".
        content_holder = event.get("delta", {}) or event.get("content_block", {})
        if not content_holder:
            # Nothing usable in this event; emit an empty chunk.
            return ""

        openai_chunk = ModelResponse(
            id=f"anthropic-chat-{str(uuid.uuid4())}",
            model="anthropic-muxed-model",
            object="chat.completion.chunk",
            choices=[
                StreamingChoices(
                    finish_reason=stop_reason,
                    index=0,
                    delta=Delta(content=content_holder.get("text", ""), role="assistant"),
                    logprobs=None,
                )
            ],
        )
        return openai_chunk.model_dump_json(exclude_none=True, exclude_unset=True)
    except Exception as e:
        logger.warning(f"Error formatting Anthropic chunk: {chunk}. Error: {e}")
        return payload.strip()


class FimStreamChunkFormatter(StreamChunkFormatter):

Expand All @@ -218,6 +227,9 @@ def provider_format_funcs(self) -> Dict[str, Callable]:
db_models.ProviderType.llamacpp: self._format_openai,
# OpenRouter is a dialect of OpenAI
db_models.ProviderType.openrouter: self._format_openai,
# VLLM is a dialect of OpenAI
db_models.ProviderType.vllm: self._format_openai,
db_models.ProviderType.anthropic: self._format_antropic,
}

def _format_ollama(self, chunk: str) -> str:
Expand Down
25 changes: 25 additions & 0 deletions tests/integration/anthropic/testcases.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,31 @@ headers:
anthropic:
x-api-key: ENV_ANTHROPIC_KEY

muxing:
mux_url: http://127.0.0.1:8989/v1/mux/
trimm_from_testcase_url: http://127.0.0.1:8989/anthropic/
provider_endpoint:
url: http://127.0.0.1:8989/api/v1/provider-endpoints
headers:
Content-Type: application/json
data: |
{
"name": "anthropic_muxing",
"description": "Muxing testing endpoint",
"provider_type": "anthropic",
"endpoint": "https://api.anthropic.com/",
"auth_type": "api_key",
"api_key": "ENV_ANTHROPIC_KEY"
}
muxes:
url: http://127.0.0.1:8989/api/v1/workspaces/default/muxes
headers:
Content-Type: application/json
rules:
- model: claude-3-5-haiku-20241022
matcher_type: catch_all
matcher: ""

testcases:
anthropic_chat:
name: Anthropic Chat
Expand Down
99 changes: 89 additions & 10 deletions tests/integration/integration_tests.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import asyncio
import copy
import json
import os
import re
import sys
from typing import Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple

import requests
import structlog
Expand All @@ -21,7 +22,7 @@ def __init__(self):
self.failed_tests = [] # Track failed tests

def call_codegate(
self, url: str, headers: dict, data: dict, provider: str
self, url: str, headers: dict, data: dict, provider: str, method: str = "POST"
) -> Optional[requests.Response]:
logger.debug(f"Creating requester for provider: {provider}")
requester = self.requester_factory.create_requester(provider)
Expand All @@ -31,12 +32,12 @@ def call_codegate(
logger.debug(f"Headers: {headers}")
logger.debug(f"Data: {data}")

response = requester.make_request(url, headers, data)
response = requester.make_request(url, headers, data, method=method)

# Enhanced response logging
if response is not None:

if response.status_code != 200:
if response.status_code not in [200, 201, 204]:
logger.debug(f"Response error status: {response.status_code}")
logger.debug(f"Response error headers: {dict(response.headers)}")
try:
Expand Down Expand Up @@ -174,7 +175,7 @@ async def run_test(self, test: dict, test_headers: dict) -> bool:

async def _get_testcases(
self, testcases_dict: Dict, test_names: Optional[list[str]] = None
) -> Dict:
) -> Dict[str, Dict[str, str]]:
testcases: Dict[str, Dict[str, str]] = testcases_dict["testcases"]

# Filter testcases by provider and test names
Expand All @@ -192,24 +193,102 @@ async def _get_testcases(
testcases = filtered_testcases
return testcases

async def _setup_muxing(
self, provider: str, muxing_config: Optional[Dict]
) -> Optional[Tuple[str, str]]:
"""
Muxing setup. Create the provider endpoints and the muxing rules

Return
"""
# The muxing section was not found in the testcases.yaml file. Nothing to do.
if not muxing_config:
return

# Create the provider endpoint
provider_endpoint = muxing_config.get("provider_endpoint")
try:
data_with_api_keys = self.replace_env_variables(provider_endpoint["data"], os.environ)
response_create_provider = self.call_codegate(
provider=provider,
url=provider_endpoint["url"],
headers=provider_endpoint["headers"],
data=json.loads(data_with_api_keys),
)
created_provider_endpoint = response_create_provider.json()
except Exception as e:
logger.warning(f"Could not setup provider endpoint for muxing: {e}")
return
logger.info("Created provider endpoint for muixing")

muxes_rules: Dict[str, Any] = muxing_config.get("muxes", {})
try:
# We need to first update all the muxes with the provider_id
for mux in muxes_rules.get("rules", []):
mux["provider_id"] = created_provider_endpoint["id"]

# The endpoint actually takes a list
self.call_codegate(
provider=provider,
url=muxes_rules["url"],
headers=muxes_rules["headers"],
data=muxes_rules.get("rules", []),
method="PUT",
)
except Exception as e:
logger.warning(f"Could not setup muxing rules: {e}")
return
logger.info("Created muxing rules")

return muxing_config["mux_url"], muxing_config["trimm_from_testcase_url"]

async def _augment_testcases_with_muxing(
    self, testcases: Dict, mux_url: str, trimm_from_testcase_url: str
) -> Dict:
    """
    Duplicate every testcase so it also runs through the muxing endpoint.

    Each original testcase is kept untouched; a "<id>_muxed" copy is added
    whose URL points at the mux endpoint instead of the provider endpoint.
    """
    augmented = copy.deepcopy(testcases)
    for case_id, case_data in testcases.items():
        muxed_case = copy.deepcopy(case_data)
        # Swap the provider prefix of the URL for the mux endpoint.
        trimmed_path = case_data["url"].replace(trimm_from_testcase_url, "")
        muxed_case["url"] = f"{mux_url}{trimmed_path}"
        augmented[f"{case_id}_muxed"] = muxed_case

    logger.info("Augmented testcases with muxing")
    return augmented

async def _setup(
self, testcases_file: str, test_names: Optional[list[str]] = None
self, testcases_file: str, provider: str, test_names: Optional[list[str]] = None
) -> Tuple[Dict, Dict]:
with open(testcases_file, "r") as f:
testcases_dict = yaml.safe_load(f)
testcases_dict: Dict = yaml.safe_load(f)

headers = testcases_dict["headers"]
testcases = await self._get_testcases(testcases_dict, test_names)
return headers, testcases
muxing_result = await self._setup_muxing(provider, testcases_dict.get("muxing", {}))
# We don't have any muxing setup, return the headers and testcases
if not muxing_result:
return headers, testcases

mux_url, trimm_from_testcase_url = muxing_result
test_cases_with_muxing = await self._augment_testcases_with_muxing(
testcases, mux_url, trimm_from_testcase_url
)

return headers, test_cases_with_muxing

async def run_tests(
self,
testcases_file: str,
provider: str,
test_names: Optional[list[str]] = None,
) -> bool:
headers, testcases = await self._setup(testcases_file, test_names)

headers, testcases = await self._setup(testcases_file, provider, test_names)
if not testcases:
logger.warning(
f"No tests found for provider {provider} in file: {testcases_file} "
Expand Down
24 changes: 24 additions & 0 deletions tests/integration/llamacpp/testcases.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,30 @@ headers:
llamacpp:
Content-Type: application/json

muxing:
mux_url: http://127.0.0.1:8989/v1/mux/
trimm_from_testcase_url: http://127.0.0.1:8989/llamacpp/
provider_endpoint:
url: http://127.0.0.1:8989/api/v1/provider-endpoints
headers:
Content-Type: application/json
data: |
{
"name": "llamacpp_muxing",
"description": "Muxing testing endpoint",
"provider_type": "llamacpp",
"endpoint": "./codegate_volume/models",
"auth_type": "none"
}
muxes:
url: http://127.0.0.1:8989/api/v1/workspaces/default/muxes
headers:
Content-Type: application/json
rules:
- model: qwen2.5-coder-0.5b-instruct-q5_k_m
matcher_type: catch_all
matcher: ""

testcases:
llamacpp_chat:
name: LlamaCPP Chat
Expand Down
24 changes: 24 additions & 0 deletions tests/integration/ollama/testcases.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,30 @@ headers:
ollama:
Content-Type: application/json

muxing:
mux_url: http://127.0.0.1:8989/v1/mux/
trimm_from_testcase_url: http://127.0.0.1:8989/ollama/
provider_endpoint:
url: http://127.0.0.1:8989/api/v1/provider-endpoints
headers:
Content-Type: application/json
data: |
{
"name": "ollama_muxing",
"description": "Muxing testing endpoint",
"provider_type": "ollama",
"endpoint": "http://127.0.0.1:11434",
"auth_type": "none"
}
muxes:
url: http://127.0.0.1:8989/api/v1/workspaces/default/muxes
headers:
Content-Type: application/json
rules:
- model: qwen2.5-coder:1.5b
matcher_type: catch_all
matcher: ""

testcases:
ollama_chat:
name: Ollama Chat
Expand Down
Loading