@@ -154,10 +154,10 @@ async def update_endpoint(
154
154
authm = await self ._db_reader .get_auth_material_by_provider_id (str (endpoint .id ))
155
155
156
156
models = await self ._find_models_for_provider (
157
- endpoint , authm .auth_type , authm .auth_blob , prov
157
+ endpoint . endpoint , authm .auth_type , authm .auth_blob , prov
158
158
)
159
159
160
- await self ._update_models_for_provider (dbendpoint , endpoint , prov , models )
160
+ await self ._update_models_for_provider (dbendpoint , models )
161
161
162
162
# a model might have been deleted, let's repopulate the cache
163
163
await self ._ws_crud .repopulate_mux_cache ()
@@ -191,7 +191,7 @@ async def configure_auth_material(
191
191
prov = endpoint .get_from_registry (provider_registry )
192
192
193
193
models = await self ._find_models_for_provider (
194
- endpoint , config .auth_type , config .api_key , prov
194
+ endpoint . endpoint , config .auth_type , config .api_key , prov
195
195
)
196
196
197
197
await self ._db_writer .push_provider_auth_material (
@@ -202,35 +202,34 @@ async def configure_auth_material(
202
202
)
203
203
)
204
204
205
- await self ._update_models_for_provider (dbendpoint , endpoint , models )
205
+ await self ._update_models_for_provider (dbendpoint , models )
206
206
207
207
# a model might have been deleted, let's repopulate the cache
208
208
await self ._ws_crud .repopulate_mux_cache ()
209
209
210
210
async def _find_models_for_provider (
211
211
self ,
212
- endpoint : apimodelsv1 . ProviderEndpoint ,
212
+ endpoint : str ,
213
213
auth_type : apimodelsv1 .ProviderAuthType ,
214
214
api_key : str ,
215
215
prov : BaseProvider ,
216
216
) -> List [str ]:
217
217
if auth_type != apimodelsv1 .ProviderAuthType .passthrough :
218
218
try :
219
- return prov .models (endpoint = endpoint . endpoint , api_key = api_key )
219
+ return prov .models (endpoint = endpoint , api_key = api_key )
220
220
except Exception as err :
221
221
raise ProviderModelsNotFoundError (f"Unable to get models from provider: { err } " )
222
222
return []
223
223
224
224
async def _update_models_for_provider (
225
225
self ,
226
226
dbendpoint : dbmodels .ProviderEndpoint ,
227
- endpoint : apimodelsv1 .ProviderEndpoint ,
228
227
found_models : List [str ],
229
228
) -> None :
230
229
models_set = set (found_models )
231
230
232
231
# Get the models from the provider
233
- models_in_db = await self ._db_reader .get_provider_models_by_provider_id (str (endpoint .id ))
232
+ models_in_db = await self ._db_reader .get_provider_models_by_provider_id (str (dbendpoint .id ))
234
233
235
234
models_in_db_set = set (model .name for model in models_in_db )
236
235
@@ -318,7 +317,7 @@ async def initialize_provider_endpoints(preg: ProviderRegistry):
318
317
dbprovend = await db_reader .get_provider_endpoint_by_name (provend .name )
319
318
if dbprovend is not None :
320
319
logger .debug (
321
- "Provider already in DB. Not re-adding. " ,
320
+ "Provider already in DB. skipping " ,
322
321
provider = provend .name ,
323
322
endpoint = provend .endpoint ,
324
323
)
@@ -334,6 +333,21 @@ async def initialize_provider_endpoints(preg: ProviderRegistry):
334
333
continue
335
334
await try_initialize_provider_endpoints (provend , pimpl , db_writer )
336
335
336
+ provcrud = ProviderCrud ()
337
+
338
+ endpoints = await provcrud .list_endpoints ()
339
+ for endpoint in endpoints :
340
+ dbprovend = await db_reader .get_provider_endpoint_by_name (endpoint .name )
341
+ pimpl = endpoint .get_from_registry (preg )
342
+ if pimpl is None :
343
+ logger .warning (
344
+ "Provider not found in registry" ,
345
+ provider = endpoint .name ,
346
+ endpoint = endpoint .endpoint ,
347
+ )
348
+ continue
349
+ await try_update_to_provider (provcrud , pimpl , dbprovend )
350
+
337
351
338
352
async def try_initialize_provider_endpoints (
339
353
provend : apimodelsv1 .ProviderEndpoint ,
@@ -376,6 +390,30 @@ async def try_initialize_provider_endpoints(
376
390
await asyncio .gather (* tasks )
377
391
378
392
393
+ async def try_update_to_provider (
394
+ provcrud : ProviderCrud , prov : BaseProvider , dbprovend : dbmodels .ProviderEndpoint
395
+ ):
396
+
397
+ authm = await provcrud ._db_reader .get_auth_material_by_provider_id (str (dbprovend .id ))
398
+
399
+ try :
400
+ models = await provcrud ._find_models_for_provider (
401
+ dbprovend .endpoint , authm .auth_type , authm .auth_blob , prov
402
+ )
403
+ except Exception as err :
404
+ logger .error (
405
+ "Unable to get models from provider. Skipping" ,
406
+ provider = dbprovend .name ,
407
+ err = str (err ),
408
+ )
409
+ return
410
+
411
+ await provcrud ._update_models_for_provider (dbprovend , models )
412
+
413
+ # a model might have been deleted, let's repopulate the cache
414
+ await provcrud ._ws_crud .repopulate_mux_cache ()
415
+
416
+
379
417
def __provider_endpoint_from_cfg (
380
418
provider_name : str , provider_url : str
381
419
) -> Optional [apimodelsv1 .ProviderEndpoint ]:
0 commit comments