Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 3e53a20

Browse files
authored
Merge branch 'main' into issue-875
2 parents 3d4fd0e + 9d73e50 commit 3e53a20

File tree

27 files changed

+705
-219
lines changed

27 files changed

+705
-219
lines changed

api/openapi.json

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1581,7 +1581,9 @@
15811581
"MuxMatcherType": {
15821582
"type": "string",
15831583
"enum": [
1584-
"catch_all"
1584+
"catch_all",
1585+
"filename_match",
1586+
"request_type_match"
15851587
],
15861588
"title": "MuxMatcherType",
15871589
"description": "Represents the different types of matchers we support."

src/codegate/clients/clients.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@ class ClientType(Enum):
1212
COPILOT = "copilot" # Copilot client
1313
OPEN_INTERPRETER = "open_interpreter" # Open Interpreter client
1414
AIDER = "aider" # Aider client
15+
CONTINUE = "continue" # Continue client

src/codegate/clients/detector.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,24 @@ def client_name(self) -> ClientType:
160160
return ClientType.OPEN_INTERPRETER
161161

162162

163+
class ContinueDetector(BaseClientDetector):
164+
"""
165+
Detector for Continue client based on message content
166+
"""
167+
168+
def __init__(self):
169+
super().__init__()
170+
# This is a hack that really only detects Continue with DeepSeek
171+
# we should get a header or user agent for this (upstream PR pending)
172+
self.content_detector = ContentDetector(
173+
"You are an AI programming assistant, utilizing the DeepSeek Coder model"
174+
)
175+
176+
@property
177+
def client_name(self) -> ClientType:
178+
return ClientType.CONTINUE
179+
180+
163181
class CopilotDetector(BaseClientDetector):
164182
"""
165183
Detector for Copilot client based on user agent
@@ -191,6 +209,7 @@ def __init__(self):
191209
KoduDetector(),
192210
OpenInterpreter(),
193211
CopilotDetector(),
212+
ContinueDetector(),
194213
]
195214

196215
def __call__(self, func):

src/codegate/muxing/models.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
import pydantic
55

6+
from codegate.clients.clients import ClientType
7+
68

79
class MuxMatcherType(str, Enum):
810
"""
@@ -11,6 +13,12 @@ class MuxMatcherType(str, Enum):
1113

1214
# Always match this prompt
1315
catch_all = "catch_all"
16+
# Match based on the filename. It will match if there is a filename
17+
# in the request that matches the matcher, either by extension or by full name (e.g. .py or main.py)
18+
filename_match = "filename_match"
19+
# Match based on the request type. It will match if the request type
20+
# matches the matcher (e.g. FIM or chat)
21+
request_type_match = "request_type_match"
1422

1523

1624
class MuxRule(pydantic.BaseModel):
@@ -25,3 +33,14 @@ class MuxRule(pydantic.BaseModel):
2533
# The actual matcher to use. Note that
2634
# this depends on the matcher type.
2735
matcher: Optional[str] = None
36+
37+
38+
class ThingToMatchMux(pydantic.BaseModel):
39+
"""
40+
Represents the fields we can use to match a mux rule.
41+
"""
42+
43+
body: dict
44+
url_request_path: str
45+
is_fim_request: bool
46+
client_type: ClientType

src/codegate/muxing/router.py

Lines changed: 37 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import json
2+
from typing import Optional
23

34
import structlog
45
from fastapi import APIRouter, HTTPException, Request
56

6-
from codegate.clients.clients import ClientType
77
from codegate.clients.detector import DetectClient
8-
from codegate.extract_snippets.body_extractor import BodyCodeSnippetExtractorError
9-
from codegate.extract_snippets.factory import BodyCodeExtractorFactory
8+
from codegate.muxing import models as mux_models
109
from codegate.muxing import rulematcher
1110
from codegate.muxing.adapter import BodyAdapter, ResponseAdapter
11+
from codegate.providers.fim_analyzer import FIMAnalyzer
1212
from codegate.providers.registry import ProviderRegistry
1313
from codegate.workspaces.crud import WorkspaceCrud
1414

@@ -39,40 +39,20 @@ def get_routes(self) -> APIRouter:
3939
def _ensure_path_starts_with_slash(self, path: str) -> str:
4040
return path if path.startswith("/") else f"/{path}"
4141

42-
def _extract_request_filenames(self, detected_client: ClientType, data: dict) -> set[str]:
42+
async def _get_model_route(
43+
self, thing_to_match: mux_models.ThingToMatchMux
44+
) -> Optional[rulematcher.ModelRoute]:
4345
"""
44-
Extract filenames from the request data.
46+
Get the model route for the given thing_to_match.
4547
"""
46-
try:
47-
body_extractor = BodyCodeExtractorFactory.create_snippet_extractor(detected_client)
48-
return body_extractor.extract_unique_filenames(data)
49-
except BodyCodeSnippetExtractorError as e:
50-
logger.error(f"Error extracting filenames from request: {e}")
51-
return set()
52-
53-
async def _get_model_routes(self, filenames: set[str]) -> list[rulematcher.ModelRoute]:
54-
"""
55-
Get the model routes for the given filenames.
56-
"""
57-
model_routes = []
5848
mux_registry = await rulematcher.get_muxing_rules_registry()
5949
try:
60-
# Try to get a catch_all route
61-
single_model_route = await mux_registry.get_match_for_active_workspace(
62-
thing_to_match=None
63-
)
64-
model_routes.append(single_model_route)
65-
66-
# Get the model routes for each filename
67-
for filename in filenames:
68-
model_route = await mux_registry.get_match_for_active_workspace(
69-
thing_to_match=filename
70-
)
71-
model_routes.append(model_route)
50+
# Try to get a model route for the active workspace
51+
model_route = await mux_registry.get_match_for_active_workspace(thing_to_match)
52+
return model_route
7253
except Exception as e:
7354
logger.error(f"Error getting active workspace muxes: {e}")
7455
raise HTTPException(str(e), status_code=404)
75-
return model_routes
7656

7757
def _setup_routes(self):
7858

@@ -88,34 +68,45 @@ async def route_to_dest_provider(
8868
1. Get destination provider from DB and active workspace.
8969
2. Map the request body to the destination provider format.
9070
3. Run pipeline. Selecting the correct destination provider.
91-
4. Transmit the response back to the client in the correct format.
71+
4. Transmit the response back to the client in OpenAI format.
9272
"""
9373

9474
body = await request.body()
9575
data = json.loads(body)
76+
is_fim_request = FIMAnalyzer.is_fim_request(rest_of_path, data)
77+
78+
# 1. Get destination provider from DB and active workspace.
79+
thing_to_match = mux_models.ThingToMatchMux(
80+
body=data,
81+
url_request_path=rest_of_path,
82+
is_fim_request=is_fim_request,
83+
client_type=request.state.detected_client,
84+
)
85+
model_route = await self._get_model_route(thing_to_match)
86+
if not model_route:
87+
raise HTTPException(
88+
"No matching rule found for the active workspace", status_code=404
89+
)
9690

97-
filenames_in_data = self._extract_request_filenames(request.state.detected_client, data)
98-
logger.info(f"Extracted filenames from request: {filenames_in_data}")
99-
100-
model_routes = await self._get_model_routes(filenames_in_data)
101-
if not model_routes:
102-
raise HTTPException("No rule found for the active workspace", status_code=404)
103-
104-
# We still need some logic here to handle the case where we have multiple model routes.
105-
# For the moment since we match all only pick the first.
106-
model_route = model_routes[0]
91+
logger.info(
92+
"Muxing request routed to destination provider",
93+
model=model_route.model.name,
94+
provider_type=model_route.endpoint.provider_type,
95+
provider_name=model_route.endpoint.name,
96+
)
10797

108-
# Parse the input data and map it to the destination provider format
98+
# 2. Map the request body to the destination provider format.
10999
rest_of_path = self._ensure_path_starts_with_slash(rest_of_path)
110100
new_data = self._body_adapter.map_body_to_dest(model_route, data)
101+
102+
# 3. Run pipeline. Selecting the correct destination provider.
111103
provider = self._provider_registry.get_provider(model_route.endpoint.provider_type)
112104
api_key = model_route.auth_material.auth_blob
113-
114-
# Send the request to the destination provider. It will run the pipeline
115105
response = await provider.process_request(
116-
new_data, api_key, rest_of_path, request.state.detected_client
106+
new_data, api_key, is_fim_request, request.state.detected_client
117107
)
118-
# Format the response to the client always using the OpenAI format
108+
109+
# 4. Transmit the response back to the client in OpenAI format.
119110
return self._response_adapter.format_response_to_client(
120111
response, model_route.endpoint.provider_type
121112
)

src/codegate/muxing/rulematcher.py

Lines changed: 78 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,17 @@
11
import copy
22
from abc import ABC, abstractmethod
33
from asyncio import Lock
4-
from typing import List, Optional
4+
from typing import Dict, List, Optional
55

6+
import structlog
7+
8+
from codegate.clients.clients import ClientType
69
from codegate.db import models as db_models
10+
from codegate.extract_snippets.body_extractor import BodyCodeSnippetExtractorError
11+
from codegate.extract_snippets.factory import BodyCodeExtractorFactory
12+
from codegate.muxing import models as mux_models
13+
14+
logger = structlog.get_logger("codegate")
715

816
_muxrules_sgtn = None
917

@@ -40,11 +48,12 @@ def __init__(
4048
class MuxingRuleMatcher(ABC):
4149
"""Base class for matching muxing rules."""
4250

43-
def __init__(self, route: ModelRoute):
51+
def __init__(self, route: ModelRoute, matcher_blob: str):
4452
self._route = route
53+
self._matcher_blob = matcher_blob
4554

4655
@abstractmethod
47-
def match(self, thing_to_match) -> bool:
56+
def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
4857
"""Return True if the rule matches the thing_to_match."""
4958
pass
5059

@@ -61,23 +70,82 @@ class MuxingMatcherFactory:
6170
def create(mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatcher:
6271
"""Create a muxing matcher for the given endpoint and model."""
6372

64-
factory = {
65-
"catch_all": CatchAllMuxingRuleMatcher,
73+
factory: Dict[mux_models.MuxMatcherType, MuxingRuleMatcher] = {
74+
mux_models.MuxMatcherType.catch_all: CatchAllMuxingRuleMatcher,
75+
mux_models.MuxMatcherType.filename_match: FileMuxingRuleMatcher,
76+
mux_models.MuxMatcherType.request_type_match: RequestTypeMuxingRuleMatcher,
6677
}
6778

6879
try:
69-
return factory[mux_rule.matcher_type](route)
80+
# Initialize the MuxingRuleMatcher
81+
return factory[mux_rule.matcher_type](route, mux_rule.matcher_blob)
7082
except KeyError:
7183
raise ValueError(f"Unknown matcher type: {mux_rule.matcher_type}")
7284

7385

7486
class CatchAllMuxingRuleMatcher(MuxingRuleMatcher):
7587
"""A catch all muxing rule matcher."""
7688

77-
def match(self, thing_to_match) -> bool:
89+
def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
90+
logger.info("Catch all rule matched")
7891
return True
7992

8093

94+
class FileMuxingRuleMatcher(MuxingRuleMatcher):
95+
"""A file muxing rule matcher."""
96+
97+
def _extract_request_filenames(self, detected_client: ClientType, data: dict) -> set[str]:
98+
"""
99+
Extract filenames from the request data.
100+
"""
101+
try:
102+
body_extractor = BodyCodeExtractorFactory.create_snippet_extractor(detected_client)
103+
return body_extractor.extract_unique_filenames(data)
104+
except BodyCodeSnippetExtractorError as e:
105+
logger.error(f"Error extracting filenames from request: {e}")
106+
return set()
107+
108+
def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
109+
"""
110+
Return True if there is a filename in the request that matches the matcher_blob.
111+
The matcher_blob is either an extension (e.g. .py) or a filename (e.g. main.py).
112+
"""
113+
# If there is no matcher_blob, we don't match
114+
if not self._matcher_blob:
115+
return False
116+
filenames_to_match = self._extract_request_filenames(
117+
thing_to_match.client_type, thing_to_match.body
118+
)
119+
is_filename_match = any(self._matcher_blob in filename for filename in filenames_to_match)
120+
if is_filename_match:
121+
logger.info(
122+
"Filename rule matched", filenames=filenames_to_match, matcher=self._matcher_blob
123+
)
124+
return is_filename_match
125+
126+
127+
class RequestTypeMuxingRuleMatcher(MuxingRuleMatcher):
128+
"""A request type muxing rule matcher."""
129+
130+
def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
131+
"""
132+
Return True if the request type matches the matcher_blob.
133+
The matcher_blob is either "fim" or "chat".
134+
"""
135+
# If there is no matcher_blob, we don't match
136+
if not self._matcher_blob:
137+
return False
138+
incoming_request_type = "fim" if thing_to_match.is_fim_request else "chat"
139+
is_request_type_match = self._matcher_blob == incoming_request_type
140+
if is_request_type_match:
141+
logger.info(
142+
"Request type rule matched",
143+
matcher=self._matcher_blob,
144+
request_type=incoming_request_type,
145+
)
146+
return is_request_type_match
147+
148+
81149
class MuxingRulesinWorkspaces:
82150
"""A thread safe dictionary to store the muxing rules in workspaces."""
83151

@@ -111,7 +179,9 @@ async def get_registries(self) -> List[str]:
111179
async with self._lock:
112180
return list(self._ws_rules.keys())
113181

114-
async def get_match_for_active_workspace(self, thing_to_match) -> Optional[ModelRoute]:
182+
async def get_match_for_active_workspace(
183+
self, thing_to_match: mux_models.ThingToMatchMux
184+
) -> Optional[ModelRoute]:
115185
"""Get the first match for the given thing_to_match."""
116186

117187
# We iterate over all the rules and return the first match

src/codegate/pipeline/cli/cli.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
)
1414
from codegate.pipeline.cli.commands import CustomInstructions, Version, Workspace
1515

16+
codegate_regex = re.compile(r"^codegate(?:\s+(.*))?", re.IGNORECASE)
17+
1618
HELP_TEXT = """
1719
## CodeGate CLI\n
1820
**Usage**: `codegate [-h] <command> [args]`\n
@@ -77,6 +79,22 @@ def _get_cli_from_open_interpreter(last_user_message_str: str) -> Optional[re.Ma
7779
return re.match(r"^codegate\s*(.*?)\s*$", last_user_block, re.IGNORECASE)
7880

7981

82+
def _get_cli_from_continue(last_user_message_str: str) -> Optional[re.Match[str]]:
83+
"""
84+
Continue sends a differently formatted message to the CLI if DeepSeek is used
85+
"""
86+
deepseek_match = re.search(
87+
r"utilizing the DeepSeek Coder model.*?### Instruction:\s*codegate\s+(.*?)\s*### Response:",
88+
last_user_message_str,
89+
re.DOTALL | re.IGNORECASE,
90+
)
91+
if deepseek_match:
92+
command = deepseek_match.group(1).strip()
93+
return re.match(r"^(.*?)$", command) # This creates a match object with the command
94+
95+
return codegate_regex.match(last_user_message_str)
96+
97+
8098
class CodegateCli(PipelineStep):
8199
"""Pipeline step that handles codegate cli."""
82100

@@ -110,12 +128,14 @@ async def process(
110128
if last_user_message is not None:
111129
last_user_message_str, _ = last_user_message
112130
last_user_message_str = last_user_message_str.strip()
113-
codegate_regex = re.compile(r"^codegate(?:\s+(.*))?", re.IGNORECASE)
114131

132+
# Check client-specific matchers first
115133
if context.client in [ClientType.CLINE, ClientType.KODU]:
116134
match = _get_cli_from_cline(codegate_regex, last_user_message_str)
117135
elif context.client in [ClientType.OPEN_INTERPRETER]:
118136
match = _get_cli_from_open_interpreter(last_user_message_str)
137+
elif context.client in [ClientType.CONTINUE]:
138+
match = _get_cli_from_continue(last_user_message_str)
119139
else:
120140
# Check if "codegate" is the first word in the message
121141
match = codegate_regex.match(last_user_message_str)

0 commit comments

Comments
 (0)