Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Add normalizer instead of abusing LiteLLM adapter #106

Merged: 2 commits were merged on Nov 28, 2024.
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions scripts/import_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ class PackageImporter:
def __init__(self):
self.client = weaviate.WeaviateClient(
embedded_options=EmbeddedOptions(
persistence_data_path="./weaviate_data",
grpc_port=50052
persistence_data_path="./weaviate_data", grpc_port=50052
)
)
self.json_files = [
Expand Down Expand Up @@ -46,13 +45,13 @@ def generate_vector_string(self, package):
"npm": "JavaScript package available on NPM",
"go": "Go package",
"crates": "Rust package available on Crates",
"java": "Java package"
"java": "Java package",
}
status_messages = {
"archived": "However, this package is found to be archived and no longer maintained.",
"deprecated": "However, this package is found to be deprecated and no longer "
"recommended for use.",
"malicious": "However, this package is found to be malicious."
"malicious": "However, this package is found to be malicious.",
}
vector_str += f" is a {type_map.get(package['type'], 'unknown type')} "
package_url = f"https://trustypkg.dev/{package['type']}/{package['name']}"
Expand All @@ -75,8 +74,9 @@ async def add_data(self):
packages_dict = {
f"{package.properties['name']}/{package.properties['type']}": {
"status": package.properties["status"],
"description": package.properties["description"]
} for package in existing_packages
"description": package.properties["description"],
}
for package in existing_packages
}

for json_file in self.json_files:
Expand All @@ -85,12 +85,12 @@ async def add_data(self):
packages_to_insert = []
for line in f:
package = json.loads(line)
package["status"] = json_file.split('/')[-1].split('.')[0]
package["status"] = json_file.split("/")[-1].split(".")[0]
key = f"{package['name']}/{package['type']}"

if key in packages_dict and packages_dict[key] == {
"status": package["status"],
"description": package["description"]
"description": package["description"],
}:
print("Package already exists", key)
continue
Expand All @@ -102,8 +102,9 @@ async def add_data(self):
# Synchronous batch insert after preparing all data
with collection.batch.dynamic() as batch:
for package, vector in packages_to_insert:
batch.add_object(properties=package, vector=vector,
uuid=generate_uuid5(package))
batch.add_object(
properties=package, vector=vector, uuid=generate_uuid5(package)
)

async def run_import(self):
self.setup_schema()
Expand Down
4 changes: 2 additions & 2 deletions src/codegate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Codegate - A Generative AI security gateway."""

from importlib import metadata
import logging as python_logging
from importlib import metadata

from codegate.codegate_logging import LogFormat, LogLevel, setup_logging
from codegate.config import Config
from codegate.codegate_logging import setup_logging, LogFormat, LogLevel
from codegate.exceptions import ConfigurationError

try:
Expand Down
2 changes: 1 addition & 1 deletion src/codegate/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

import click

from codegate.codegate_logging import LogFormat, LogLevel, setup_logging
from codegate.config import Config, ConfigurationError
from codegate.codegate_logging import setup_logging, LogFormat, LogLevel
from codegate.server import init_app


Expand Down
2 changes: 1 addition & 1 deletion src/codegate/codegate_logging.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import datetime
from enum import Enum
import json
import logging
import sys
from enum import Enum
from typing import Any, Optional


Expand Down
6 changes: 2 additions & 4 deletions src/codegate/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import yaml

from codegate.codegate_logging import setup_logging, LogFormat, LogLevel
from codegate.codegate_logging import LogFormat, LogLevel, setup_logging
from codegate.exceptions import ConfigurationError
from codegate.prompts import PromptConfig

Expand Down Expand Up @@ -52,9 +52,7 @@ def __post_init__(self) -> None:
@staticmethod
def _load_default_prompts() -> PromptConfig:
"""Load default prompts from prompts/default.yaml."""
default_prompts_path = (
Path(__file__).parent.parent.parent / "prompts" / "default.yaml"
)
default_prompts_path = Path(__file__).parent.parent.parent / "prompts" / "default.yaml"
try:
return PromptConfig.from_file(default_prompts_path)
except Exception as e:
Expand Down
18 changes: 11 additions & 7 deletions src/codegate/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class CodeSnippet:
language: The programming language identifier (e.g., 'python', 'javascript')
code: The actual code content
"""

language: str
code: str

Expand All @@ -24,6 +25,7 @@ def __post_init__(self):
raise ValueError("Code must not be empty")
self.language = self.language.strip().lower()


@dataclass
class PipelineContext:
code_snippets: List[CodeSnippet] = field(default_factory=list)
Expand All @@ -35,20 +37,24 @@ def add_code_snippet(self, snippet: CodeSnippet):
def get_snippets_by_language(self, language: str) -> List[CodeSnippet]:
return [s for s in self.code_snippets if s.language.lower() == language.lower()]


@dataclass
class PipelineResponse:
"""Response generated by a pipeline step"""

content: str
step_name: str # The name of the pipeline step that generated this response
model: str # Taken from the original request's model field


@dataclass
class PipelineResult:
"""
Represents the result of a pipeline operation.
Either contains a modified request to continue processing,
or a response to return to the client.
"""

request: Optional[ChatCompletionRequest] = None
response: Optional[PipelineResponse] = None
error_message: Optional[str] = None
Expand Down Expand Up @@ -79,8 +85,8 @@ def name(self) -> str:

@staticmethod
def get_last_user_message(
request: ChatCompletionRequest,
) -> Optional[tuple[str, int]]:
request: ChatCompletionRequest,
) -> Optional[tuple[str, int]]:
"""
Get the last user message and its index from the request.

Expand Down Expand Up @@ -122,9 +128,7 @@ def get_last_user_message(

@abstractmethod
async def process(
self,
request: ChatCompletionRequest,
context: PipelineContext
self, request: ChatCompletionRequest, context: PipelineContext
) -> PipelineResult:
"""Process a request and return either modified request or response stream"""
pass
Expand All @@ -135,8 +139,8 @@ def __init__(self, pipeline_steps: List[PipelineStep]):
self.pipeline_steps = pipeline_steps

async def process_request(
self,
request: ChatCompletionRequest,
self,
request: ChatCompletionRequest,
) -> PipelineResult:
"""
Process a request through all pipeline steps
Expand Down
4 changes: 1 addition & 3 deletions src/codegate/pipeline/version/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ def name(self) -> str:
return "codegate-version"

async def process(
self,
request: ChatCompletionRequest,
context: PipelineContext
self, request: ChatCompletionRequest, context: PipelineContext
) -> PipelineResult:
"""
Checks if the last user message contains "codegate-version" and
Expand Down
4 changes: 1 addition & 3 deletions src/codegate/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ def from_file(cls, prompt_path: Union[str, Path]) -> "PromptConfig":
# Validate all values are strings
for key, value in prompt_data.items():
if not isinstance(value, str):
raise ConfigurationError(
f"Prompt '{key}' must be a string, got {type(value)}"
)
raise ConfigurationError(f"Prompt '{key}' must be a string, got {type(value)}")

return cls(prompts=prompt_data)
except yaml.YAMLError as e:
Expand Down
46 changes: 15 additions & 31 deletions src/codegate/providers/anthropic/adapter.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,30 @@
from typing import Any, Dict, Optional

from litellm import AdapterCompletionStreamWrapper, ChatCompletionRequest, ModelResponse
from litellm.adapters.anthropic_adapter import (
AnthropicAdapter as LitellmAnthropicAdapter,
)
from litellm.types.llms.anthropic import AnthropicResponse

from codegate.providers.base import StreamGenerator
from codegate.providers.litellmshim import anthropic_stream_generator, BaseAdapter
from codegate.providers.litellmshim.adapter import (
LiteLLMAdapterInputNormalizer,
LiteLLMAdapterOutputNormalizer,
)


class AnthropicAdapter(BaseAdapter):
class AnthropicInputNormalizer(LiteLLMAdapterInputNormalizer):
"""
LiteLLM's adapter class interface is used to translate between the Anthropic data
format and the underlying model. The AnthropicAdapter class contains the actual
implementation of the interface methods, we just forward the calls to it.
"""

def __init__(self, stream_generator: StreamGenerator = anthropic_stream_generator):
self.litellm_anthropic_adapter = LitellmAnthropicAdapter()
super().__init__(stream_generator)
def __init__(self):
super().__init__(LitellmAnthropicAdapter())

def translate_completion_input_params(
self,
completion_request: Dict,
) -> Optional[ChatCompletionRequest]:
return self.litellm_anthropic_adapter.translate_completion_input_params(
completion_request
)

def translate_completion_output_params(
self, response: ModelResponse
) -> Optional[AnthropicResponse]:
return self.litellm_anthropic_adapter.translate_completion_output_params(
response
)
class AnthropicOutputNormalizer(LiteLLMAdapterOutputNormalizer):
"""
LiteLLM's adapter class interface is used to translate between the Anthropic data
format and the underlying model. The AnthropicAdapter class contains the actual
implementation of the interface methods, we just forward the calls to it.
"""

def translate_completion_output_params_streaming(
self, completion_stream: Any
) -> AdapterCompletionStreamWrapper | None:
return (
self.litellm_anthropic_adapter.translate_completion_output_params_streaming(
completion_stream
)
)
def __init__(self):
super().__init__(LitellmAnthropicAdapter())
14 changes: 9 additions & 5 deletions src/codegate/providers/anthropic/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,20 @@

from fastapi import Header, HTTPException, Request

from codegate.providers.anthropic.adapter import AnthropicInputNormalizer, AnthropicOutputNormalizer
from codegate.providers.base import BaseProvider
from codegate.providers.litellmshim import LiteLLmShim
from codegate.providers.anthropic.adapter import AnthropicAdapter
from codegate.providers.litellmshim import LiteLLmShim, anthropic_stream_generator


class AnthropicProvider(BaseProvider):
def __init__(self, pipeline_processor=None):
adapter = AnthropicAdapter()
completion_handler = LiteLLmShim(adapter)
super().__init__(completion_handler, pipeline_processor)
completion_handler = LiteLLmShim(stream_generator=anthropic_stream_generator)
super().__init__(
AnthropicInputNormalizer(),
AnthropicOutputNormalizer(),
completion_handler,
pipeline_processor,
)

@property
def provider_route_name(self) -> str:
Expand Down
Loading