Skip to content

Commit a68db74

Browse files
Update the muxing rules to v3
Closes: #1060 Right now the muxing rules are designed to catch globally FIM or Chat requests. This PR extends its functionality to be able to match per file and request, i.e. this PR enables - Chat request of main.py -> model 1 - FIM request of main.py -> model 2 - Any type of v1.py -> model 3
1 parent 9555a03 commit a68db74

File tree

5 files changed

+204
-135
lines changed

5 files changed

+204
-135
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
"""update matcher types
2+
3+
Revision ID: 5e5cd2288147
4+
Revises: 0c3539f66339
5+
Create Date: 2025-02-19 14:52:39.126196+00:00
6+
7+
"""
8+
9+
from typing import Sequence, Union
10+
11+
from alembic import op
12+
13+
# revision identifiers, used by Alembic.
14+
revision: str = "5e5cd2288147"
15+
down_revision: Union[str, None] = "0c3539f66339"
16+
branch_labels: Union[str, Sequence[str], None] = None
17+
depends_on: Union[str, Sequence[str], None] = None
18+
19+
20+
def upgrade() -> None:
21+
# Begin transaction
22+
op.execute("BEGIN TRANSACTION;")
23+
24+
# Update the matcher types. We need to do this every time we change the matcher types.
25+
# in /muxing/models.py
26+
op.execute(
27+
"""
28+
UPDATE muxes
29+
SET matcher_type = 'fim', matcher_blob = ''
30+
WHERE matcher_type = 'request_type_match' AND matcher_blob = 'fim';
31+
"""
32+
)
33+
op.execute(
34+
"""
35+
UPDATE muxes
36+
SET matcher_type = 'chat', matcher_blob = ''
37+
WHERE matcher_type = 'request_type_match' AND matcher_blob = 'chat';
38+
"""
39+
)
40+
op.execute(
41+
"""
42+
UPDATE muxes
43+
SET matcher_type = 'catch_all'
44+
WHERE matcher_type = 'filename_match' AND matcher_blob != '';
45+
"""
46+
)
47+
48+
# Finish transaction
49+
op.execute("COMMIT;")
50+
51+
52+
def downgrade() -> None:
53+
# Begin transaction
54+
op.execute("BEGIN TRANSACTION;")
55+
56+
op.execute(
57+
"""
58+
UPDATE muxes
59+
SET matcher_blob = 'fim', matcher_type = 'request_type_match'
60+
WHERE matcher_type = 'fim';
61+
"""
62+
)
63+
op.execute(
64+
"""
65+
UPDATE muxes
66+
SET matcher_blob = 'chat', matcher_type = 'request_type_match'
67+
WHERE matcher_type = 'chat';
68+
"""
69+
)
70+
op.execute(
71+
"""
72+
UPDATE muxes
73+
SET matcher_type = 'filename_match', matcher_blob = 'catch_all'
74+
WHERE matcher_type = 'catch_all';
75+
"""
76+
)
77+
78+
# Finish transaction
79+
op.execute("COMMIT;")

Diff for: src/codegate/muxing/models.py

+28-7
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,33 @@
11
from enum import Enum
2-
from typing import Optional
2+
from typing import Optional, Self
33

44
import pydantic
55

66
from codegate.clients.clients import ClientType
7+
from codegate.db.models import MuxRule as DBMuxRule
78

89

910
class MuxMatcherType(str, Enum):
1011
"""
1112
Represents the different types of matchers we support.
13+
14+
The 3 rules present match filenames and request types. They're used in conjunction with the
15+
matcher field in the MuxRule model.
16+
E.g.
17+
- catch_all and match: None -> Always match
18+
- fim and match: requests.py -> Match the request if the filename is requests.py and FIM
19+
- chat and match: None -> Match the request if it's a chat request
20+
- chat and match: .js -> Match the request if the filename has a .js extension and is chat
21+
22+
NOTE: Removing or updating fields from this enum will require a migration.
1223
"""
1324

1425
# Always match this prompt
1526
catch_all = "catch_all"
16-
# Match based on the filename. It will match if there is a filename
17-
# in the request that matches the matcher either extension or full name (*.py or main.py)
18-
filename_match = "filename_match"
19-
# Match based on the request type. It will match if the request type
20-
# matches the matcher (e.g. FIM or chat)
21-
request_type_match = "request_type_match"
27+
# Match based on fim request type. It will match if the request type is fim
28+
fim = "fim"
29+
# Match based on chat request type. It will match if the request type is chat
30+
chat = "chat"
2231

2332

2433
class MuxRule(pydantic.BaseModel):
@@ -36,6 +45,18 @@ class MuxRule(pydantic.BaseModel):
3645
# this depends on the matcher type.
3746
matcher: Optional[str] = None
3847

48+
@classmethod
49+
def from_db_mux_rule(cls, db_mux_rule: DBMuxRule) -> Self:
50+
"""
51+
Convert a DBMuxRule to a MuxRule.
52+
"""
53+
return MuxRule(
54+
provider_id=db_mux_rule.id,
55+
model=db_mux_rule.provider_model_name,
56+
matcher_type=db_mux_rule.matcher_type,
57+
matcher=db_mux_rule.matcher_blob,
58+
)
59+
3960

4061
class ThingToMatchMux(pydantic.BaseModel):
4162
"""

Diff for: src/codegate/muxing/router.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,11 @@ async def _get_model_route(
5050
# Try to get a model route for the active workspace
5151
model_route = await mux_registry.get_match_for_active_workspace(thing_to_match)
5252
return model_route
53+
except rulematcher.MuxMatchingError as e:
54+
logger.exception(f"Error matching rule and getting model route: {e}")
55+
raise HTTPException(detail=str(e), status_code=404)
5356
except Exception as e:
54-
logger.error(f"Error getting active workspace muxes: {e}")
57+
logger.exception(f"Error getting active workspace muxes: {e}")
5558
raise HTTPException(detail=str(e), status_code=404)
5659

5760
def _setup_routes(self):

Diff for: src/codegate/muxing/rulematcher.py

+48-46
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@
1818
_singleton_lock = Lock()
1919

2020

21+
class MuxMatchingError(Exception):
22+
"""An exception for muxing matching errors."""
23+
24+
pass
25+
26+
2127
async def get_muxing_rules_registry():
2228
"""Returns a singleton instance of the muxing rules registry."""
2329

@@ -48,9 +54,9 @@ def __init__(
4854
class MuxingRuleMatcher(ABC):
4955
"""Base class for matching muxing rules."""
5056

51-
def __init__(self, route: ModelRoute, matcher_blob: str):
57+
def __init__(self, route: ModelRoute, mux_rule: mux_models.MuxRule):
5258
self._route = route
53-
self._matcher_blob = matcher_blob
59+
self._mux_rule = mux_rule
5460

5561
@abstractmethod
5662
def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
@@ -67,32 +73,24 @@ class MuxingMatcherFactory:
6773
"""Factory for creating muxing matchers."""
6874

6975
@staticmethod
70-
def create(mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatcher:
76+
def create(db_mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatcher:
7177
"""Create a muxing matcher for the given endpoint and model."""
7278

7379
factory: Dict[mux_models.MuxMatcherType, MuxingRuleMatcher] = {
74-
mux_models.MuxMatcherType.catch_all: CatchAllMuxingRuleMatcher,
75-
mux_models.MuxMatcherType.filename_match: FileMuxingRuleMatcher,
76-
mux_models.MuxMatcherType.request_type_match: RequestTypeMuxingRuleMatcher,
80+
mux_models.MuxMatcherType.catch_all: RequestTypeAndFileMuxingRuleMatcher,
81+
mux_models.MuxMatcherType.fim: RequestTypeAndFileMuxingRuleMatcher,
82+
mux_models.MuxMatcherType.chat: RequestTypeAndFileMuxingRuleMatcher,
7783
}
7884

7985
try:
8086
# Initialize the MuxingRuleMatcher
81-
return factory[mux_rule.matcher_type](route, mux_rule.matcher_blob)
87+
mux_rule = mux_models.MuxRule.from_db_mux_rule(db_mux_rule)
88+
return factory[mux_rule.matcher_type](route, mux_rule)
8289
except KeyError:
8390
raise ValueError(f"Unknown matcher type: {mux_rule.matcher_type}")
8491

8592

86-
class CatchAllMuxingRuleMatcher(MuxingRuleMatcher):
87-
"""A catch all muxing rule matcher."""
88-
89-
def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
90-
logger.info("Catch all rule matched")
91-
return True
92-
93-
94-
class FileMuxingRuleMatcher(MuxingRuleMatcher):
95-
"""A file muxing rule matcher."""
93+
class RequestTypeAndFileMuxingRuleMatcher(MuxingRuleMatcher):
9694

9795
def _extract_request_filenames(self, detected_client: ClientType, data: dict) -> set[str]:
9896
"""
@@ -103,47 +101,51 @@ def _extract_request_filenames(self, detected_client: ClientType, data: dict) ->
103101
return body_extractor.extract_unique_filenames(data)
104102
except BodyCodeSnippetExtractorError as e:
105103
logger.error(f"Error extracting filenames from request: {e}")
106-
return set()
104+
raise MuxMatchingError("Error extracting filenames from request")
107105

108-
def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
106+
def _is_matcher_in_filenames(self, detected_client: ClientType, data: dict) -> bool:
109107
"""
110-
Retun True if there is a filename in the request that matches the matcher_blob.
111-
The matcher_blob is either an extension (e.g. .py) or a filename (e.g. main.py).
108+
Check if the matcher is in the request filenames.
112109
"""
113-
# If there is no matcher_blob, we don't match
114-
if not self._matcher_blob:
115-
return False
116-
filenames_to_match = self._extract_request_filenames(
117-
thing_to_match.client_type, thing_to_match.body
110+
# Empty matcher_blob means we match everything
111+
if not self._mux_rule.matcher:
112+
return True
113+
filenames_to_match = self._extract_request_filenames(detected_client, data)
114+
# _mux_rule.matcher can be a filename or a file extension. We match if any of the filenames
115+
# match the rule.
116+
is_filename_match = any(
117+
self._mux_rule.matcher == filename or filename.endswith(self._mux_rule.matcher)
118+
for filename in filenames_to_match
118119
)
119-
is_filename_match = any(self._matcher_blob in filename for filename in filenames_to_match)
120-
if is_filename_match:
121-
logger.info(
122-
"Filename rule matched", filenames=filenames_to_match, matcher=self._matcher_blob
123-
)
124120
return is_filename_match
125121

126-
127-
class RequestTypeMuxingRuleMatcher(MuxingRuleMatcher):
128-
"""A catch all muxing rule matcher."""
122+
def _is_request_type_match(self, is_fim_request: bool) -> bool:
123+
"""
124+
Check if the request type matches the MuxMatcherType.
125+
"""
126+
# Catch all rule matches both chat and FIM requests
127+
if self._mux_rule.matcher_type == mux_models.MuxMatcherType.catch_all:
128+
return True
129+
incoming_request_type = "fim" if is_fim_request else "chat"
130+
if incoming_request_type == self._mux_rule.matcher_type:
131+
return True
132+
return False
129133

130134
def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
131135
"""
132-
Return True if the request type matches the matcher_blob.
133-
The matcher_blob is either "fim" or "chat".
136+
Return True if the matcher is in one of the request filenames and
137+
if the request type matches the MuxMatcherType.
134138
"""
135-
# If there is no matcher_blob, we don't match
136-
if not self._matcher_blob:
137-
return False
138-
incoming_request_type = "fim" if thing_to_match.is_fim_request else "chat"
139-
is_request_type_match = self._matcher_blob == incoming_request_type
140-
if is_request_type_match:
139+
is_rule_matched = self._is_matcher_in_filenames(
140+
thing_to_match.client_type, thing_to_match.body
141+
) and self._is_request_type_match(thing_to_match.is_fim_request)
142+
if is_rule_matched:
141143
logger.info(
142-
"Request type rule matched",
143-
matcher=self._matcher_blob,
144-
request_type=incoming_request_type,
144+
"Request type and rule matched",
145+
matcher=self._mux_rule.matcher,
146+
is_fim_request=thing_to_match.is_fim_request,
145147
)
146-
return is_request_type_match
148+
return is_rule_matched
147149

148150

149151
class MuxingRulesinWorkspaces:

0 commit comments

Comments
 (0)