Skip to content

Commit 1fff150

Browse files
committed
Add scaffolding of API for provider endpoints and mux rules
This providers a scaffolding to start discussions and start drafting the provider endpoints implementation and mux rules. Closes: #753 Signed-off-by: Juan Antonio Osorio <[email protected]>
1 parent 6999136 commit 1fff150

File tree

2 files changed

+213
-0
lines changed

2 files changed

+213
-0
lines changed

src/codegate/api/v1.py

+133
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,99 @@ def uniq_name(route: APIRoute):
2525
return f"v1_{route.name}"
2626

2727

28+
@v1.get("/provider-endpoints", tags=["Providers"], generate_unique_id_function=uniq_name)
29+
async def list_provider_endpoints(name: Optional[str] = None) -> List[v1_models.ProviderEndpoint]:
30+
"""List all provider endpoints."""
31+
# NOTE: This is a dummy implementation. In the future, we should have a proper
32+
# implementation that fetches the provider endpoints from the database.
33+
return [
34+
v1_models.ProviderEndpoint(
35+
id=1,
36+
name="dummy",
37+
description="Dummy provider endpoint",
38+
endpoint="http://example.com",
39+
provider_type=v1_models.ProviderType.openai,
40+
auth_type=v1_models.ProviderAuthType.none,
41+
)
42+
]
43+
44+
45+
@v1.get(
46+
"/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name
47+
)
48+
async def get_provider_endpoint(provider_id: int) -> v1_models.ProviderEndpoint:
49+
"""Get a provider endpoint by ID."""
50+
# NOTE: This is a dummy implementation. In the future, we should have a proper
51+
# implementation that fetches the provider endpoint from the database.
52+
return v1_models.ProviderEndpoint(
53+
id=provider_id,
54+
name="dummy",
55+
description="Dummy provider endpoint",
56+
endpoint="http://example.com",
57+
provider_type=v1_models.ProviderType.openai,
58+
auth_type=v1_models.ProviderAuthType.none,
59+
)
60+
61+
62+
@v1.post(
63+
"/provider-endpoints",
64+
tags=["Providers"],
65+
generate_unique_id_function=uniq_name,
66+
status_code=201,
67+
)
68+
async def add_provider_endpoint(request: v1_models.ProviderEndpoint) -> v1_models.ProviderEndpoint:
69+
"""Add a provider endpoint."""
70+
# NOTE: This is a dummy implementation. In the future, we should have a proper
71+
# implementation that adds the provider endpoint to the database.
72+
return request
73+
74+
75+
@v1.put(
76+
"/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name
77+
)
78+
async def update_provider_endpoint(
79+
provider_id: int, request: v1_models.ProviderEndpoint
80+
) -> v1_models.ProviderEndpoint:
81+
"""Update a provider endpoint by ID."""
82+
# NOTE: This is a dummy implementation. In the future, we should have a proper
83+
# implementation that updates the provider endpoint in the database.
84+
return request
85+
86+
87+
@v1.delete(
88+
"/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name
89+
)
90+
async def delete_provider_endpoint(provider_id: int):
91+
"""Delete a provider endpoint by id."""
92+
# NOTE: This is a dummy implementation. In the future, we should have a proper
93+
# implementation that deletes the provider endpoint from the database.
94+
return Response(status_code=204)
95+
96+
97+
@v1.get(
98+
"/provider-endpoints/{provider_name}/models",
99+
tags=["Providers"],
100+
generate_unique_id_function=uniq_name,
101+
)
102+
async def list_models_by_provider(provider_name: str) -> List[v1_models.ModelByProvider]:
103+
"""List models by provider."""
104+
# NOTE: This is a dummy implementation. In the future, we should have a proper
105+
# implementation that fetches the models by provider from the database.
106+
return [v1_models.ModelByProvider(name="dummy", provider="dummy")]
107+
108+
109+
@v1.get(
110+
"/provider-endpoints/models",
111+
tags=["Providers"],
112+
generate_unique_id_function=uniq_name,
113+
)
114+
async def list_all_models_for_all_providers() -> List[v1_models.ModelByProvider]:
115+
"""List all models for all providers."""
116+
# NOTE: This is a dummy implementation. In the future, we should have a proper
117+
# implementation that fetches all the models for all providers from the database.
118+
return [v1_models.ModelByProvider(name="dummy", provider="dummy")]
119+
120+
28121
@v1.get("/workspaces", tags=["Workspaces"], generate_unique_id_function=uniq_name)
29122
async def list_workspaces() -> v1_models.ListWorkspacesResponse:
30123
"""List all workspaces."""
@@ -296,6 +389,46 @@ async def delete_workspace_custom_instructions(workspace_name: str):
296389
return Response(status_code=204)
297390

298391

392+
@v1.get(
393+
"/workspaces/{workspace_name}/muxes",
394+
tags=["Workspaces", "Muxes"],
395+
generate_unique_id_function=uniq_name,
396+
)
397+
async def get_workspace_muxes(workspace_name: str) -> List[v1_models.MuxRule]:
398+
"""Get the mux rules of a workspace.
399+
400+
The list is ordered in order of priority. That is, the first rule in the list
401+
has the highest priority."""
402+
# TODO: This is a dummy implementation. In the future, we should have a proper
403+
# implementation that fetches the mux rules from the database.
404+
return [
405+
v1_models.MuxRule(
406+
provider="openai",
407+
model="gpt-3.5-turbo",
408+
matcher_type=v1_models.MuxMatcherType.file_regex,
409+
matcher=".*\\.txt",
410+
),
411+
v1_models.MuxRule(
412+
provider="anthropic",
413+
model="davinci",
414+
matcher_type=v1_models.MuxMatcherType.catch_all,
415+
),
416+
]
417+
418+
419+
@v1.put(
420+
"/workspaces/{workspace_name}/muxes",
421+
tags=["Workspaces", "Muxes"],
422+
generate_unique_id_function=uniq_name,
423+
status_code=204,
424+
)
425+
async def set_workspace_muxes(workspace_name: str, request: List[v1_models.MuxRule]):
426+
"""Set the mux rules of a workspace."""
427+
# TODO: This is a dummy implementation. In the future, we should have a proper
428+
# implementation that sets the mux rules in the database.
429+
return Response(status_code=204)
430+
431+
299432
@v1.get("/alerts_notification", tags=["Dashboard"], generate_unique_id_function=uniq_name)
300433
async def stream_sse():
301434
"""

src/codegate/api/v1_models.py

+80
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,83 @@ class AlertConversation(pydantic.BaseModel):
138138
trigger_type: str
139139
trigger_category: Optional[str]
140140
timestamp: datetime.datetime
141+
142+
143+
class ProviderType(str, Enum):
144+
"""
145+
Represents the different types of providers we support.
146+
"""
147+
148+
openai = "openai"
149+
anthropic = "anthropic"
150+
vllm = "vllm"
151+
152+
153+
class ProviderAuthType(str, Enum):
154+
"""
155+
Represents the different types of auth we support for providers.
156+
"""
157+
158+
# No auth required
159+
none = "none"
160+
# Whatever the user provides is passed through
161+
passthrough = "passthrough"
162+
# API key is required
163+
api_key = "api_key"
164+
165+
166+
class ProviderEndpoint(pydantic.BaseModel):
167+
"""
168+
Represents a provider's endpoint configuration. This
169+
allows us to persist the configuration for each provider,
170+
so we can use this for muxing messages.
171+
"""
172+
173+
id: int
174+
name: str
175+
description: str = ""
176+
provider_type: ProviderType
177+
endpoint: str
178+
auth_type: ProviderAuthType
179+
180+
181+
class ModelByProvider(pydantic.BaseModel):
182+
"""
183+
Represents a model supported by a provider.
184+
185+
Note that these are auto-discovered by the provider.
186+
"""
187+
188+
name: str
189+
provider: str
190+
191+
def __str__(self):
192+
return f"{self.provider}/{self.name}"
193+
194+
195+
class MuxMatcherType(str, Enum):
196+
"""
197+
Represents the different types of matchers we support.
198+
"""
199+
200+
# Match a regular expression for a file path
201+
# in the prompt. Note that if no file is found,
202+
# the prompt will be passed through.
203+
file_regex = "file_regex"
204+
205+
# Always match this prompt
206+
catch_all = "catch_all"
207+
208+
209+
class MuxRule(pydantic.BaseModel):
210+
"""
211+
Represents a mux rule for a provider.
212+
"""
213+
214+
provider: str
215+
model: str
216+
# The type of matcher to use
217+
matcher_type: MuxMatcherType
218+
# The actual matcher to use. Note that
219+
# this depends on the matcher type.
220+
matcher: Optional[str]

0 commit comments

Comments
 (0)