Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit b46c5e3

Browse files
authored
Update models on codegate initialization (#1027)
As part of codegate's boot process, this now refreshes the provider's models. Signed-off-by: Juan Antonio Osorio <[email protected]>
1 parent e6600f6 commit b46c5e3

File tree

1 file changed

+47
-9
lines changed

1 file changed

+47
-9
lines changed

src/codegate/providers/crud/crud.py

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -154,10 +154,10 @@ async def update_endpoint(
154154
authm = await self._db_reader.get_auth_material_by_provider_id(str(endpoint.id))
155155

156156
models = await self._find_models_for_provider(
157-
endpoint, authm.auth_type, authm.auth_blob, prov
157+
endpoint.endpoint, authm.auth_type, authm.auth_blob, prov
158158
)
159159

160-
await self._update_models_for_provider(dbendpoint, endpoint, prov, models)
160+
await self._update_models_for_provider(dbendpoint, models)
161161

162162
# a model might have been deleted, let's repopulate the cache
163163
await self._ws_crud.repopulate_mux_cache()
@@ -191,7 +191,7 @@ async def configure_auth_material(
191191
prov = endpoint.get_from_registry(provider_registry)
192192

193193
models = await self._find_models_for_provider(
194-
endpoint, config.auth_type, config.api_key, prov
194+
endpoint.endpoint, config.auth_type, config.api_key, prov
195195
)
196196

197197
await self._db_writer.push_provider_auth_material(
@@ -202,35 +202,34 @@ async def configure_auth_material(
202202
)
203203
)
204204

205-
await self._update_models_for_provider(dbendpoint, endpoint, models)
205+
await self._update_models_for_provider(dbendpoint, models)
206206

207207
# a model might have been deleted, let's repopulate the cache
208208
await self._ws_crud.repopulate_mux_cache()
209209

210210
async def _find_models_for_provider(
211211
self,
212-
endpoint: apimodelsv1.ProviderEndpoint,
212+
endpoint: str,
213213
auth_type: apimodelsv1.ProviderAuthType,
214214
api_key: str,
215215
prov: BaseProvider,
216216
) -> List[str]:
217217
if auth_type != apimodelsv1.ProviderAuthType.passthrough:
218218
try:
219-
return prov.models(endpoint=endpoint.endpoint, api_key=api_key)
219+
return prov.models(endpoint=endpoint, api_key=api_key)
220220
except Exception as err:
221221
raise ProviderModelsNotFoundError(f"Unable to get models from provider: {err}")
222222
return []
223223

224224
async def _update_models_for_provider(
225225
self,
226226
dbendpoint: dbmodels.ProviderEndpoint,
227-
endpoint: apimodelsv1.ProviderEndpoint,
228227
found_models: List[str],
229228
) -> None:
230229
models_set = set(found_models)
231230

232231
# Get the models from the provider
233-
models_in_db = await self._db_reader.get_provider_models_by_provider_id(str(endpoint.id))
232+
models_in_db = await self._db_reader.get_provider_models_by_provider_id(str(dbendpoint.id))
234233

235234
models_in_db_set = set(model.name for model in models_in_db)
236235

@@ -318,7 +317,7 @@ async def initialize_provider_endpoints(preg: ProviderRegistry):
318317
dbprovend = await db_reader.get_provider_endpoint_by_name(provend.name)
319318
if dbprovend is not None:
320319
logger.debug(
321-
"Provider already in DB. Not re-adding.",
320+
"Provider already in DB. skipping",
322321
provider=provend.name,
323322
endpoint=provend.endpoint,
324323
)
@@ -334,6 +333,21 @@ async def initialize_provider_endpoints(preg: ProviderRegistry):
334333
continue
335334
await try_initialize_provider_endpoints(provend, pimpl, db_writer)
336335

336+
provcrud = ProviderCrud()
337+
338+
endpoints = await provcrud.list_endpoints()
339+
for endpoint in endpoints:
340+
dbprovend = await db_reader.get_provider_endpoint_by_name(endpoint.name)
341+
pimpl = endpoint.get_from_registry(preg)
342+
if pimpl is None:
343+
logger.warning(
344+
"Provider not found in registry",
345+
provider=endpoint.name,
346+
endpoint=endpoint.endpoint,
347+
)
348+
continue
349+
await try_update_to_provider(provcrud, pimpl, dbprovend)
350+
337351

338352
async def try_initialize_provider_endpoints(
339353
provend: apimodelsv1.ProviderEndpoint,
@@ -376,6 +390,30 @@ async def try_initialize_provider_endpoints(
376390
await asyncio.gather(*tasks)
377391

378392

393+
async def try_update_to_provider(
394+
provcrud: ProviderCrud, prov: BaseProvider, dbprovend: dbmodels.ProviderEndpoint
395+
):
396+
397+
authm = await provcrud._db_reader.get_auth_material_by_provider_id(str(dbprovend.id))
398+
399+
try:
400+
models = await provcrud._find_models_for_provider(
401+
dbprovend.endpoint, authm.auth_type, authm.auth_blob, prov
402+
)
403+
except Exception as err:
404+
logger.error(
405+
"Unable to get models from provider. Skipping",
406+
provider=dbprovend.name,
407+
err=str(err),
408+
)
409+
return
410+
411+
await provcrud._update_models_for_provider(dbprovend, models)
412+
413+
# a model might have been deleted, let's repopulate the cache
414+
await provcrud._ws_crud.repopulate_mux_cache()
415+
416+
379417
def __provider_endpoint_from_cfg(
380418
provider_name: str, provider_url: str
381419
) -> Optional[apimodelsv1.ProviderEndpoint]:

0 commit comments

Comments
 (0)