Skip to content

Update the muxing rules to v3 #1112

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""update matcher types

Revision ID: 5e5cd2288147
Revises: 0c3539f66339
Create Date: 2025-02-19 14:52:39.126196+00:00

"""

from typing import Sequence, Union

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "5e5cd2288147"
down_revision: Union[str, None] = "0c3539f66339"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# Begin transaction
op.execute("BEGIN TRANSACTION;")

# Update the matcher types. We need to do this every time we change the matcher types.
# in /muxing/models.py
op.execute(
"""
UPDATE muxes
SET matcher_type = 'fim_filename', matcher_blob = ''
WHERE matcher_type = 'request_type_match' AND matcher_blob = 'fim';
"""
)
op.execute(
"""
UPDATE muxes
SET matcher_type = 'chat_filename', matcher_blob = ''
WHERE matcher_type = 'request_type_match' AND matcher_blob = 'chat';
"""
)

# Finish transaction
op.execute("COMMIT;")


def downgrade() -> None:
# Begin transaction
op.execute("BEGIN TRANSACTION;")

op.execute(
"""
UPDATE muxes
SET matcher_blob = 'fim', matcher_type = 'request_type_match'
WHERE matcher_type = 'fim';
"""
)
op.execute(
"""
UPDATE muxes
SET matcher_blob = 'chat', matcher_type = 'request_type_match'
WHERE matcher_type = 'chat';
"""
)

# Finish transaction
op.execute("COMMIT;")
33 changes: 29 additions & 4 deletions src/codegate/muxing/models.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,37 @@
from enum import Enum
from typing import Optional
from typing import Optional, Self

import pydantic

from codegate.clients.clients import ClientType
from codegate.db.models import MuxRule as DBMuxRule


class MuxMatcherType(str, Enum):
"""
Represents the different types of matchers we support.

The 3 rules present match filenames and request types. They're used in conjunction with the
matcher field in the MuxRule model.
E.g.
- catch_all-> Always match
- filename_match and match: requests.py -> Match the request if the filename is requests.py
- fim_filename and match: main.py -> Match the request if the request type is fim
and the filename is main.py

NOTE: Removing or updating fields from this enum will require a migration.
Adding new fields is safe.
"""

# Always match this prompt
catch_all = "catch_all"
# Match based on the filename. It will match if there is a filename
# in the request that matches the matcher either extension or full name (*.py or main.py)
filename_match = "filename_match"
# Match based on the request type. It will match if the request type
# matches the matcher (e.g. FIM or chat)
request_type_match = "request_type_match"
# Match based on fim request type. It will match if the request type is fim
fim_filename = "fim_filename"
# Match based on chat request type. It will match if the request type is chat
chat_filename = "chat_filename"


class MuxRule(pydantic.BaseModel):
Expand All @@ -36,6 +49,18 @@ class MuxRule(pydantic.BaseModel):
# this depends on the matcher type.
matcher: Optional[str] = None

@classmethod
def from_db_mux_rule(cls, db_mux_rule: DBMuxRule) -> Self:
"""
Convert a DBMuxRule to a MuxRule.
"""
return MuxRule(
provider_id=db_mux_rule.id,
model=db_mux_rule.provider_model_name,
matcher_type=db_mux_rule.matcher_type,
matcher=db_mux_rule.matcher_blob,
)


class ThingToMatchMux(pydantic.BaseModel):
"""
Expand Down
5 changes: 4 additions & 1 deletion src/codegate/muxing/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,11 @@ async def _get_model_route(
# Try to get a model route for the active workspace
model_route = await mux_registry.get_match_for_active_workspace(thing_to_match)
return model_route
except rulematcher.MuxMatchingError as e:
logger.exception(f"Error matching rule and getting model route: {e}")
raise HTTPException(detail=str(e), status_code=404)
except Exception as e:
logger.error(f"Error getting active workspace muxes: {e}")
logger.exception(f"Error getting active workspace muxes: {e}")
raise HTTPException(detail=str(e), status_code=404)

def _setup_routes(self):
Expand Down
88 changes: 56 additions & 32 deletions src/codegate/muxing/rulematcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@
_singleton_lock = Lock()


class MuxMatchingError(Exception):
"""An exception for muxing matching errors."""

pass


async def get_muxing_rules_registry():
"""Returns a singleton instance of the muxing rules registry."""

Expand Down Expand Up @@ -48,9 +54,9 @@ def __init__(
class MuxingRuleMatcher(ABC):
"""Base class for matching muxing rules."""

def __init__(self, route: ModelRoute, matcher_blob: str):
def __init__(self, route: ModelRoute, mux_rule: mux_models.MuxRule):
self._route = route
self._matcher_blob = matcher_blob
self._mux_rule = mux_rule

@abstractmethod
def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
Expand All @@ -67,18 +73,20 @@ class MuxingMatcherFactory:
"""Factory for creating muxing matchers."""

@staticmethod
def create(mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatcher:
def create(db_mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatcher:
"""Create a muxing matcher for the given endpoint and model."""

factory: Dict[mux_models.MuxMatcherType, MuxingRuleMatcher] = {
mux_models.MuxMatcherType.catch_all: CatchAllMuxingRuleMatcher,
mux_models.MuxMatcherType.filename_match: FileMuxingRuleMatcher,
mux_models.MuxMatcherType.request_type_match: RequestTypeMuxingRuleMatcher,
mux_models.MuxMatcherType.fim_filename: RequestTypeAndFileMuxingRuleMatcher,
mux_models.MuxMatcherType.chat_filename: RequestTypeAndFileMuxingRuleMatcher,
}

try:
# Initialize the MuxingRuleMatcher
return factory[mux_rule.matcher_type](route, mux_rule.matcher_blob)
mux_rule = mux_models.MuxRule.from_db_mux_rule(db_mux_rule)
return factory[mux_rule.matcher_type](route, mux_rule)
except KeyError:
raise ValueError(f"Unknown matcher type: {mux_rule.matcher_type}")

Expand All @@ -103,47 +111,63 @@ def _extract_request_filenames(self, detected_client: ClientType, data: dict) ->
return body_extractor.extract_unique_filenames(data)
except BodyCodeSnippetExtractorError as e:
logger.error(f"Error extracting filenames from request: {e}")
return set()
raise MuxMatchingError("Error extracting filenames from request")

def _is_matcher_in_filenames(self, detected_client: ClientType, data: dict) -> bool:
"""
Check if the matcher is in the request filenames.
"""
# Empty matcher_blob means we match everything
if not self._mux_rule.matcher:
return True
filenames_to_match = self._extract_request_filenames(detected_client, data)
# _mux_rule.matcher can be a filename or a file extension. We match if any of the filenames
# match the rule.
is_filename_match = any(
self._mux_rule.matcher == filename or filename.endswith(self._mux_rule.matcher)
for filename in filenames_to_match
)
return is_filename_match

def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
"""
Retun True if there is a filename in the request that matches the matcher_blob.
The matcher_blob is either an extension (e.g. .py) or a filename (e.g. main.py).
Return True if the matcher is in one of the request filenames.
"""
# If there is no matcher_blob, we don't match
if not self._matcher_blob:
return False
filenames_to_match = self._extract_request_filenames(
is_rule_matched = self._is_matcher_in_filenames(
thing_to_match.client_type, thing_to_match.body
)
is_filename_match = any(self._matcher_blob in filename for filename in filenames_to_match)
if is_filename_match:
logger.info(
"Filename rule matched", filenames=filenames_to_match, matcher=self._matcher_blob
)
return is_filename_match
if is_rule_matched:
logger.info("Filename rule matched", matcher=self._mux_rule.matcher)
return is_rule_matched


class RequestTypeMuxingRuleMatcher(MuxingRuleMatcher):
"""A catch all muxing rule matcher."""
class RequestTypeAndFileMuxingRuleMatcher(FileMuxingRuleMatcher):
"""A request type and file muxing rule matcher."""

def _is_request_type_match(self, is_fim_request: bool) -> bool:
"""
Check if the request type matches the MuxMatcherType.
"""
incoming_request_type = "fim_filename" if is_fim_request else "chat_filename"
if incoming_request_type == self._mux_rule.matcher_type:
return True
return False

def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
"""
Return True if the request type matches the matcher_blob.
The matcher_blob is either "fim" or "chat".
Return True if the matcher is in one of the request filenames and
if the request type matches the MuxMatcherType.
"""
# If there is no matcher_blob, we don't match
if not self._matcher_blob:
return False
incoming_request_type = "fim" if thing_to_match.is_fim_request else "chat"
is_request_type_match = self._matcher_blob == incoming_request_type
if is_request_type_match:
is_rule_matched = self._is_matcher_in_filenames(
thing_to_match.client_type, thing_to_match.body
) and self._is_request_type_match(thing_to_match.is_fim_request)
if is_rule_matched:
logger.info(
"Request type rule matched",
matcher=self._matcher_blob,
request_type=incoming_request_type,
"Request type and rule matched",
matcher=self._mux_rule.matcher,
is_fim_request=thing_to_match.is_fim_request,
)
return is_request_type_match
return is_rule_matched


class MuxingRulesinWorkspaces:
Expand Down
Loading
Loading