diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index 0e5e601e..4077b965 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -128,6 +128,27 @@ async def add_provider_endpoint( return provend +@v1.put( + "/provider-endpoints/{provider_id}/auth-material", + tags=["Providers"], + generate_unique_id_function=uniq_name, + status_code=204, +) +async def configure_auth_material( + provider_id: UUID, + request: v1_models.ConfigureAuthMaterial, +): + """Configure auth material for a provider.""" + try: + await pcrud.configure_auth_material(provider_id, request) + except provendcrud.ProviderNotFoundError: + raise HTTPException(status_code=404, detail="Provider endpoint not found") + except Exception: + raise HTTPException(status_code=500, detail="Internal server error") + + return Response(status_code=204) + + @v1.put( "/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name ) diff --git a/src/codegate/api/v1_models.py b/src/codegate/api/v1_models.py index e68e21f1..83e6a68a 100644 --- a/src/codegate/api/v1_models.py +++ b/src/codegate/api/v1_models.py @@ -223,7 +223,7 @@ class ProviderEndpoint(pydantic.BaseModel): description: str = "" provider_type: ProviderType endpoint: str - auth_type: ProviderAuthType + auth_type: Optional[ProviderAuthType] = ProviderAuthType.none @staticmethod def from_db_model(db_model: db_models.ProviderEndpoint) -> "ProviderEndpoint": @@ -250,6 +250,15 @@ def get_from_registry(self, registry: ProviderRegistry) -> Optional[BaseProvider return registry.get_provider(self.provider_type) +class ConfigureAuthMaterial(pydantic.BaseModel): + """ + Represents a request to configure auth material for a provider. + """ + + auth_type: ProviderAuthType + api_key: Optional[str] = None + + class ModelByProvider(pydantic.BaseModel): """ Represents a model supported by a provider. diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index b821f4d6..caed5276 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -441,8 +441,10 @@ async def push_provider_auth_material(self, auth_material: ProviderAuthMaterial) UPDATE provider_endpoints SET auth_type = :auth_type, auth_blob = :auth_blob WHERE id = :provider_endpoint_id + RETURNING id as provider_endpoint_id, auth_type, auth_blob """ ) + # Here we DONT want to return the result _ = await self._execute_update_pydantic_model(auth_material, sql, should_raise=True) return diff --git a/src/codegate/providers/crud/crud.py b/src/codegate/providers/crud/crud.py index 637375e8..0e485df8 100644 --- a/src/codegate/providers/crud/crud.py +++ b/src/codegate/providers/crud/crud.py @@ -81,6 +81,27 @@ async def update_endpoint( dbendpoint = await self._db_writer.update_provider_endpoint(endpoint.to_db_model()) return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint) + async def configure_auth_material( + self, provider_id: UUID, config: apimodelsv1.ConfigureAuthMaterial + ): + """Add an API key.""" + if config.auth_type == apimodelsv1.ProviderAuthType.api_key and not config.api_key: + raise ValueError("API key must be provided for API auth type") + elif config.auth_type != apimodelsv1.ProviderAuthType.api_key and config.api_key: + raise ValueError("API key provided for non-API auth type") + + dbendpoint = await self._db_reader.get_provider_endpoint_by_id(str(provider_id)) + if dbendpoint is None: + raise ProviderNotFoundError("Provider not found") + + await self._db_writer.push_provider_auth_material( + dbmodels.ProviderAuthMaterial( + provider_endpoint_id=dbendpoint.id, + auth_type=config.auth_type, + auth_blob=config.api_key if config.api_key else "", + ) + ) + async def delete_endpoint(self, provider_id: UUID): """Delete an endpoint."""