From 770e67df00015063f6b891191c768c074bb585d6 Mon Sep 17 00:00:00 2001 From: Juan Antonio Osorio Date: Tue, 11 Feb 2025 14:37:04 +0200 Subject: [PATCH] Enable provider model updating when updating provider itself This now calls the model updating logic when updating the provider itself. Thus allowing us to have a way to update the model list. Signed-off-by: Juan Antonio Osorio --- src/codegate/providers/crud/crud.py | 65 +++++++++++++++++++---- src/codegate/providers/ollama/provider.py | 2 +- 2 files changed, 56 insertions(+), 11 deletions(-) diff --git a/src/codegate/providers/crud/crud.py b/src/codegate/providers/crud/crud.py index 1b96fdbb..453af16f 100644 --- a/src/codegate/providers/crud/crud.py +++ b/src/codegate/providers/crud/crud.py @@ -144,6 +144,32 @@ async def update_endpoint( dbendpoint = await self._db_writer.update_provider_endpoint(endpoint.to_db_model()) + # If the auth type has not changed or no authentication is needed, + # we can update the models + if ( + founddbe.auth_type == endpoint.auth_type + or endpoint.auth_type == apimodelsv1.ProviderAuthType.none + ): + try: + authm = await self._db_reader.get_auth_material_by_provider_id(str(endpoint.id)) + + models = await self._find_models_for_provider( + endpoint, authm.auth_type, authm.auth_blob, prov + ) + + await self._update_models_for_provider(dbendpoint, endpoint, prov, models) + + # a model might have been deleted, let's repopulate the cache + await self._ws_crud.repopulate_mux_cache() + except Exception as err: + # This is a non-fatal error. The endpoint might have changed + # And the user will need to push a new API key anyway. + logger.error( + "Unable to update models for provider", + provider=endpoint.name, + err=str(err), + ) + return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint) async def configure_auth_material( @@ -164,12 +190,9 @@ async def configure_auth_material( provider_registry = get_provider_registry() prov = endpoint.get_from_registry(provider_registry) - models = [] - if config.auth_type != apimodelsv1.ProviderAuthType.passthrough: - try: - models = prov.models(endpoint=endpoint.endpoint, api_key=config.api_key) - except Exception as err: - raise ProviderModelsNotFoundError(f"Unable to get models from provider: {err}") + models = await self._find_models_for_provider( + endpoint, config.auth_type, config.api_key, prov + ) await self._db_writer.push_provider_auth_material( dbmodels.ProviderAuthMaterial( @@ -179,7 +202,32 @@ async def configure_auth_material( ) ) - models_set = set(models) + await self._update_models_for_provider(dbendpoint, endpoint, models) + + # a model might have been deleted, let's repopulate the cache + await self._ws_crud.repopulate_mux_cache() + + async def _find_models_for_provider( + self, + endpoint: apimodelsv1.ProviderEndpoint, + auth_type: apimodelsv1.ProviderAuthType, + api_key: str, + prov: BaseProvider, + ) -> List[str]: + if auth_type != apimodelsv1.ProviderAuthType.passthrough: + try: + return prov.models(endpoint=endpoint.endpoint, api_key=api_key) + except Exception as err: + raise ProviderModelsNotFoundError(f"Unable to get models from provider: {err}") + return [] + + async def _update_models_for_provider( + self, + dbendpoint: dbmodels.ProviderEndpoint, + endpoint: apimodelsv1.ProviderEndpoint, + found_models: List[str], + ) -> None: + models_set = set(found_models) # Get the models from the provider models_in_db = await self._db_reader.get_provider_models_by_provider_id(str(endpoint.id)) @@ -202,9 +250,6 @@ async def configure_auth_material( model, ) - # a model might have been deleted, let's repopulate the cache - await self._ws_crud.repopulate_mux_cache() - async def delete_endpoint(self, provider_id: UUID): """Delete an endpoint.""" diff --git a/src/codegate/providers/ollama/provider.py b/src/codegate/providers/ollama/provider.py index c1e30909..c2963233 100644 --- a/src/codegate/providers/ollama/provider.py +++ b/src/codegate/providers/ollama/provider.py @@ -1,5 +1,5 @@ import json -from typing import List, Optional +from typing import List import httpx import structlog