18
18
_singleton_lock = Lock ()
19
19
20
20
21
+ class MuxMatchingError (Exception ):
22
+ """An exception for muxing matching errors."""
23
+
24
+ pass
25
+
26
+
21
27
async def get_muxing_rules_registry ():
22
28
"""Returns a singleton instance of the muxing rules registry."""
23
29
@@ -48,9 +54,9 @@ def __init__(
48
54
class MuxingRuleMatcher (ABC ):
49
55
"""Base class for matching muxing rules."""
50
56
51
- def __init__ (self , route : ModelRoute , matcher_blob : str ):
57
+ def __init__ (self , route : ModelRoute , mux_rule : mux_models . MuxRule ):
52
58
self ._route = route
53
- self ._matcher_blob = matcher_blob
59
+ self ._mux_rule = mux_rule
54
60
55
61
@abstractmethod
56
62
def match (self , thing_to_match : mux_models .ThingToMatchMux ) -> bool :
@@ -67,32 +73,24 @@ class MuxingMatcherFactory:
67
73
"""Factory for creating muxing matchers."""
68
74
69
75
@staticmethod
70
- def create (mux_rule : db_models .MuxRule , route : ModelRoute ) -> MuxingRuleMatcher :
76
+ def create (db_mux_rule : db_models .MuxRule , route : ModelRoute ) -> MuxingRuleMatcher :
71
77
"""Create a muxing matcher for the given endpoint and model."""
72
78
73
79
factory : Dict [mux_models .MuxMatcherType , MuxingRuleMatcher ] = {
74
- mux_models .MuxMatcherType .catch_all : CatchAllMuxingRuleMatcher ,
75
- mux_models .MuxMatcherType .filename_match : FileMuxingRuleMatcher ,
76
- mux_models .MuxMatcherType .request_type_match : RequestTypeMuxingRuleMatcher ,
80
+ mux_models .MuxMatcherType .catch_all : RequestTypeAndFileMuxingRuleMatcher ,
81
+ mux_models .MuxMatcherType .fim : RequestTypeAndFileMuxingRuleMatcher ,
82
+ mux_models .MuxMatcherType .chat : RequestTypeAndFileMuxingRuleMatcher ,
77
83
}
78
84
79
85
try :
80
86
# Initialize the MuxingRuleMatcher
81
- return factory [mux_rule .matcher_type ](route , mux_rule .matcher_blob )
87
+ mux_rule = mux_models .MuxRule .from_db_mux_rule (db_mux_rule )
88
+ return factory [mux_rule .matcher_type ](route , mux_rule )
82
89
except KeyError :
83
90
raise ValueError (f"Unknown matcher type: { mux_rule .matcher_type } " )
84
91
85
92
86
- class CatchAllMuxingRuleMatcher (MuxingRuleMatcher ):
87
- """A catch all muxing rule matcher."""
88
-
89
- def match (self , thing_to_match : mux_models .ThingToMatchMux ) -> bool :
90
- logger .info ("Catch all rule matched" )
91
- return True
92
-
93
-
94
- class FileMuxingRuleMatcher (MuxingRuleMatcher ):
95
- """A file muxing rule matcher."""
93
+ class RequestTypeAndFileMuxingRuleMatcher (MuxingRuleMatcher ):
96
94
97
95
def _extract_request_filenames (self , detected_client : ClientType , data : dict ) -> set [str ]:
98
96
"""
@@ -103,47 +101,51 @@ def _extract_request_filenames(self, detected_client: ClientType, data: dict) ->
103
101
return body_extractor .extract_unique_filenames (data )
104
102
except BodyCodeSnippetExtractorError as e :
105
103
logger .error (f"Error extracting filenames from request: { e } " )
106
- return set ( )
104
+ raise MuxMatchingError ( "Error extracting filenames from request" )
107
105
108
- def match (self , thing_to_match : mux_models . ThingToMatchMux ) -> bool :
106
+ def _is_matcher_in_filenames (self , detected_client : ClientType , data : dict ) -> bool :
109
107
"""
110
- Retun True if there is a filename in the request that matches the matcher_blob.
111
- The matcher_blob is either an extension (e.g. .py) or a filename (e.g. main.py).
108
+ Check if the matcher is in the request filenames.
112
109
"""
113
- # If there is no matcher_blob, we don't match
114
- if not self ._matcher_blob :
115
- return False
116
- filenames_to_match = self ._extract_request_filenames (
117
- thing_to_match .client_type , thing_to_match .body
110
+ # Empty matcher_blob means we match everything
111
+ if not self ._mux_rule .matcher :
112
+ return True
113
+ filenames_to_match = self ._extract_request_filenames (detected_client , data )
114
+ # _mux_rule.matcher can be a filename or a file extension. We match if any of the filenames
115
+ # match the rule.
116
+ is_filename_match = any (
117
+ self ._mux_rule .matcher == filename or filename .endswith (self ._mux_rule .matcher )
118
+ for filename in filenames_to_match
118
119
)
119
- is_filename_match = any (self ._matcher_blob in filename for filename in filenames_to_match )
120
- if is_filename_match :
121
- logger .info (
122
- "Filename rule matched" , filenames = filenames_to_match , matcher = self ._matcher_blob
123
- )
124
120
return is_filename_match
125
121
126
-
127
- class RequestTypeMuxingRuleMatcher (MuxingRuleMatcher ):
128
- """A catch all muxing rule matcher."""
122
+ def _is_request_type_match (self , is_fim_request : bool ) -> bool :
123
+ """
124
+ Check if the request type matches the MuxMatcherType.
125
+ """
126
+ # Catch all rule matches both chat and FIM requests
127
+ if self ._mux_rule .matcher_type == mux_models .MuxMatcherType .catch_all :
128
+ return True
129
+ incoming_request_type = "fim" if is_fim_request else "chat"
130
+ if incoming_request_type == self ._mux_rule .matcher_type :
131
+ return True
132
+ return False
129
133
130
134
def match (self , thing_to_match : mux_models .ThingToMatchMux ) -> bool :
131
135
"""
132
- Return True if the request type matches the matcher_blob.
133
- The matcher_blob is either "fim" or "chat" .
136
+ Return True if the matcher is in one of the request filenames and
137
+ if the request type matches the MuxMatcherType .
134
138
"""
135
- # If there is no matcher_blob, we don't match
136
- if not self ._matcher_blob :
137
- return False
138
- incoming_request_type = "fim" if thing_to_match .is_fim_request else "chat"
139
- is_request_type_match = self ._matcher_blob == incoming_request_type
140
- if is_request_type_match :
139
+ is_rule_matched = self ._is_matcher_in_filenames (
140
+ thing_to_match .client_type , thing_to_match .body
141
+ ) and self ._is_request_type_match (thing_to_match .is_fim_request )
142
+ if is_rule_matched :
141
143
logger .info (
142
- "Request type rule matched" ,
143
- matcher = self ._matcher_blob ,
144
- request_type = incoming_request_type ,
144
+ "Request type and rule matched" ,
145
+ matcher = self ._mux_rule . matcher ,
146
+ is_fim_request = thing_to_match . is_fim_request ,
145
147
)
146
- return is_request_type_match
148
+ return is_rule_matched
147
149
148
150
149
151
class MuxingRulesinWorkspaces :
0 commit comments