Skip to content

Commit 3f03b4f

Browse files
committed
Add scaffolding of API for provider endpoints and mux rules
This providers a scaffolding to start discussions and start drafting the provider endpoints implementation and mux rules. Closes: #753 Signed-off-by: Juan Antonio Osorio <[email protected]>
1 parent 6999136 commit 3f03b4f

File tree

2 files changed

+201
-0
lines changed

2 files changed

+201
-0
lines changed

src/codegate/api/v1.py

+121
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,87 @@ def uniq_name(route: APIRoute):
2525
return f"v1_{route.name}"
2626

2727

28+
@v1.get("/provider-endpoints", tags=["Providers"], generate_unique_id_function=uniq_name)
29+
async def list_provider_endpoints(name: Optional[str] = None) -> List[v1_models.ProviderEndpoint]:
30+
"""List all provider endpoints."""
31+
# NOTE: This is a dummy implementation. In the future, we should have a proper
32+
# implementation that fetches the provider endpoints from the database.
33+
return [
34+
v1_models.ProviderEndpoint(
35+
id=1,
36+
name="dummy",
37+
description="Dummy provider endpoint",
38+
endpoint="http://example.com",
39+
provider_type=v1_models.ProviderType.openai,
40+
auth_type=v1_models.ProviderAuthType.none,
41+
)
42+
]
43+
44+
45+
@v1.get(
46+
"/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name
47+
)
48+
async def get_provider_endpoint(provider_id: int) -> v1_models.ProviderEndpoint:
49+
"""Get a provider endpoint by ID."""
50+
# NOTE: This is a dummy implementation. In the future, we should have a proper
51+
# implementation that fetches the provider endpoint from the database.
52+
return v1_models.ProviderEndpoint(
53+
id=provider_id,
54+
name="dummy",
55+
description="Dummy provider endpoint",
56+
endpoint="http://example.com",
57+
provider_type=v1_models.ProviderType.openai,
58+
auth_type=v1_models.ProviderAuthType.none,
59+
)
60+
61+
62+
@v1.post(
63+
"/provider-endpoints",
64+
tags=["Providers"],
65+
generate_unique_id_function=uniq_name,
66+
status_code=201,
67+
)
68+
async def add_provider_endpoint(request: v1_models.ProviderEndpoint) -> v1_models.ProviderEndpoint:
69+
"""Add a provider endpoint."""
70+
# NOTE: This is a dummy implementation. In the future, we should have a proper
71+
# implementation that adds the provider endpoint to the database.
72+
return request
73+
74+
75+
@v1.put(
76+
"/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name
77+
)
78+
async def update_provider_endpoint(
79+
provider_id: int, request: v1_models.ProviderEndpoint
80+
) -> v1_models.ProviderEndpoint:
81+
"""Update a provider endpoint by ID."""
82+
# NOTE: This is a dummy implementation. In the future, we should have a proper
83+
# implementation that updates the provider endpoint in the database.
84+
return request
85+
86+
87+
@v1.delete(
88+
"/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name
89+
)
90+
async def delete_provider_endpoint(provider_id: int):
91+
"""Delete a provider endpoint by id."""
92+
# NOTE: This is a dummy implementation. In the future, we should have a proper
93+
# implementation that deletes the provider endpoint from the database.
94+
return Response(status_code=204)
95+
96+
97+
@v1.get(
98+
"/provider-endpoints/{provider_name}/models",
99+
tags=["Providers"],
100+
generate_unique_id_function=uniq_name,
101+
)
102+
async def list_models_by_provider(provider_name: str) -> List[v1_models.ModelByProvider]:
103+
"""List models by provider."""
104+
# NOTE: This is a dummy implementation. In the future, we should have a proper
105+
# implementation that fetches the models by provider from the database.
106+
return list(v1_models.ModelByProvider(name="dummy", provider="dummy"))
107+
108+
28109
@v1.get("/workspaces", tags=["Workspaces"], generate_unique_id_function=uniq_name)
29110
async def list_workspaces() -> v1_models.ListWorkspacesResponse:
30111
"""List all workspaces."""
@@ -296,6 +377,46 @@ async def delete_workspace_custom_instructions(workspace_name: str):
296377
return Response(status_code=204)
297378

298379

380+
@v1.get(
381+
"/workspaces/{workspace_name}/muxes",
382+
tags=["Workspaces", "Muxes"],
383+
generate_unique_id_function=uniq_name,
384+
)
385+
async def get_workspace_muxes(workspace_name: str) -> List[v1_models.MuxRule]:
386+
"""Get the mux rules of a workspace.
387+
388+
The list is ordered in order of priority. That is, the first rule in the list
389+
has the highest priority."""
390+
# TODO: This is a dummy implementation. In the future, we should have a proper
391+
# implementation that fetches the mux rules from the database.
392+
return [
393+
v1_models.MuxRule(
394+
provider="openai",
395+
model="gpt-3.5-turbo",
396+
matcher_type=v1_models.MuxMatcherType.file_regex,
397+
matcher=".*\\.txt",
398+
),
399+
v1_models.MuxRule(
400+
provider="anthropic",
401+
model="davinci",
402+
matcher_type=v1_models.MuxMatcherType.catch_all,
403+
),
404+
]
405+
406+
407+
@v1.put(
408+
"/workspaces/{workspace_name}/muxes",
409+
tags=["Workspaces", "Muxes"],
410+
generate_unique_id_function=uniq_name,
411+
status_code=204,
412+
)
413+
async def set_workspace_muxes(workspace_name: str, request: List[v1_models.MuxRule]):
414+
"""Set the mux rules of a workspace."""
415+
# TODO: This is a dummy implementation. In the future, we should have a proper
416+
# implementation that sets the mux rules in the database.
417+
return Response(status_code=204)
418+
419+
299420
@v1.get("/alerts_notification", tags=["Dashboard"], generate_unique_id_function=uniq_name)
300421
async def stream_sse():
301422
"""

src/codegate/api/v1_models.py

+80
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,83 @@ class AlertConversation(pydantic.BaseModel):
138138
trigger_type: str
139139
trigger_category: Optional[str]
140140
timestamp: datetime.datetime
141+
142+
143+
class ProviderType(str, Enum):
144+
"""
145+
Represents the different types of providers we support.
146+
"""
147+
148+
openai = "openai"
149+
anthropic = "anthropic"
150+
vllm = "vllm"
151+
152+
153+
class ProviderAuthType(str, Enum):
154+
"""
155+
Represents the different types of auth we support for providers.
156+
"""
157+
158+
# No auth required
159+
none = "none"
160+
# Whatever the user provides is passed through
161+
passthrough = "passthrough"
162+
# API key is required
163+
api_key = "api_key"
164+
165+
166+
class ProviderEndpoint(pydantic.BaseModel):
167+
"""
168+
Represents a provider's endpoint configuration. This
169+
allows us to persist the configuration for each provider,
170+
so we can use this for muxing messages.
171+
"""
172+
173+
id: int
174+
name: str
175+
description: str = ""
176+
provider_type: ProviderType
177+
endpoint: str
178+
auth_type: ProviderAuthType
179+
180+
181+
class ModelByProvider(pydantic.BaseModel):
182+
"""
183+
Represents a model supported by a provider.
184+
185+
Note that these are auto-discovered by the provider.
186+
"""
187+
188+
name: str
189+
provider: str
190+
191+
def __str__(self):
192+
return f"{self.provider}/{self.name}"
193+
194+
195+
class MuxMatcherType(str, Enum):
196+
"""
197+
Represents the different types of matchers we support.
198+
"""
199+
200+
# Match a regular expression for a file path
201+
# in the prompt. Note that if no file is found,
202+
# the prompt will be passed through.
203+
file_regex = "file_regex"
204+
205+
# Always match this prompt
206+
catch_all = "catch_all"
207+
208+
209+
class MuxRule(pydantic.BaseModel):
210+
"""
211+
Represents a mux rule for a provider.
212+
"""
213+
214+
provider: str
215+
model: str
216+
# The type of matcher to use
217+
matcher_type: MuxMatcherType
218+
# The actual matcher to use. Note that
219+
# this depends on the matcher type.
220+
matcher: Optional[str]

0 commit comments

Comments
 (0)