33import aiohttp
44from pydantic import Field
55
6- from letta .constants import DEFAULT_EMBEDDING_CHUNK_SIZE
6+ from letta .constants import DEFAULT_EMBEDDING_CHUNK_SIZE , DEFAULT_CONTEXT_WINDOW , DEFAULT_EMBEDDING_DIM , OLLAMA_API_PREFIX
77from letta .log import get_logger
88from letta .schemas .embedding_config import EmbeddingConfig
99from letta .schemas .enums import ProviderCategory , ProviderType
1212
1313logger = get_logger (__name__ )
1414
15- ollama_prefix = "/v1"
16-
1715
1816class OllamaProvider (OpenAIProvider ):
1917 """Ollama provider that uses the native /api/generate endpoint
@@ -41,19 +39,30 @@ async def list_llm_models_async(self) -> list[LLMConfig]:
4139 response_json = await response .json ()
4240
4341 configs = []
44- for model in response_json ["models" ]:
45- context_window = await self ._get_model_context_window (model ["name" ])
42+ for model in response_json .get ("models" , []):
43+ model_name = model ["name" ]
44+ model_details = await self ._get_model_details_async (model_name )
45+ if not model_details or "completion" not in model_details .get ("capabilities" , []):
46+ continue
47+
48+ context_window = None
49+ model_info = model_details .get ("model_info" , {})
50+ if architecture := model_info .get ("general.architecture" ):
51+ if context_length := model_info .get (f"{ architecture } .context_length" ):
52+ context_window = int (context_length )
53+
4654 if context_window is None :
47- print (f"Ollama model { model ['name' ]} has no context window, using default 32000" )
48- context_window = 32000
55+ logger .warning (f"Ollama model { model_name } has no context window, using default { DEFAULT_CONTEXT_WINDOW } " )
56+ context_window = DEFAULT_CONTEXT_WINDOW
57+
4958 configs .append (
5059 LLMConfig (
51- model = model [ "name" ] ,
60+ model = model_name ,
5261 model_endpoint_type = ProviderType .ollama ,
53- model_endpoint = f"{ self .base_url } { ollama_prefix } " ,
62+ model_endpoint = f"{ self .base_url } { OLLAMA_API_PREFIX } " ,
5463 model_wrapper = self .default_prompt_formatter ,
5564 context_window = context_window ,
56- handle = self .get_handle (model [ "name" ] ),
65+ handle = self .get_handle (model_name ),
5766 provider_name = self .name ,
5867 provider_category = self .provider_category ,
5968 )
@@ -73,25 +82,36 @@ async def list_embedding_models_async(self) -> list[EmbeddingConfig]:
7382 response_json = await response .json ()
7483
7584 configs = []
76- for model in response_json ["models" ]:
77- embedding_dim = await self ._get_model_embedding_dim (model ["name" ])
85+ for model in response_json .get ("models" , []):
86+ model_name = model ["name" ]
87+ model_details = await self ._get_model_details_async (model_name )
88+ if not model_details or "embedding" not in model_details .get ("capabilities" , []):
89+ continue
90+
91+ embedding_dim = None
92+ model_info = model_details .get ("model_info" , {})
93+ if architecture := model_info .get ("general.architecture" ):
94+ if embedding_length := model_info .get (f"{ architecture } .embedding_length" ):
95+ embedding_dim = int (embedding_length )
96+
7897 if not embedding_dim :
79- print (f"Ollama model { model [ 'name' ] } has no embedding dimension, using default 1024 " )
80- # continue
81- embedding_dim = 1024
98+ logger . warning (f"Ollama model { model_name } has no embedding dimension, using default { DEFAULT_EMBEDDING_DIM } " )
99+ embedding_dim = DEFAULT_EMBEDDING_DIM
100+
82101 configs .append (
83102 EmbeddingConfig (
84- embedding_model = model [ "name" ] ,
103+ embedding_model = model_name ,
85104 embedding_endpoint_type = ProviderType .ollama ,
86- embedding_endpoint = f"{ self .base_url } { ollama_prefix } " ,
105+ embedding_endpoint = f"{ self .base_url } { OLLAMA_API_PREFIX } " ,
87106 embedding_dim = embedding_dim ,
88107 embedding_chunk_size = DEFAULT_EMBEDDING_CHUNK_SIZE ,
89- handle = self .get_handle (model [ "name" ] , is_embedding = True ),
108+ handle = self .get_handle (model_name , is_embedding = True ),
90109 )
91110 )
92111 return configs
93112
94- async def _get_model_context_window (self , model_name : str ) -> int | None :
113+ async def _get_model_details_async (self , model_name : str ) -> dict | None :
114+ """Get detailed information for a specific model from /api/show."""
95115 endpoint = f"{ self .base_url } /api/show"
96116 payload = {"name" : model_name }
97117
@@ -102,39 +122,7 @@ async def _get_model_context_window(self, model_name: str) -> int | None:
102122 error_text = await response .text ()
103123 logger .warning (f"Failed to get model info for { model_name } : { response .status } - { error_text } " )
104124 return None
105-
106- response_json = await response .json ()
107- model_info = response_json .get ("model_info" , {})
108-
109- if architecture := model_info .get ("general.architecture" ):
110- if context_length := model_info .get (f"{ architecture } .context_length" ):
111- return int (context_length )
112-
125+ return await response .json ()
113126 except Exception as e :
114- logger .warning (f"Failed to get model context window for { model_name } with error: { e } " )
115-
116- return None
117-
118- async def _get_model_embedding_dim (self , model_name : str ) -> int | None :
119- endpoint = f"{ self .base_url } /api/show"
120- payload = {"name" : model_name }
121-
122- try :
123- async with aiohttp .ClientSession () as session :
124- async with session .post (endpoint , json = payload ) as response :
125- if response .status != 200 :
126- error_text = await response .text ()
127- logger .warning (f"Failed to get model info for { model_name } : { response .status } - { error_text } " )
128- return None
129-
130- response_json = await response .json ()
131- model_info = response_json .get ("model_info" , {})
132-
133- if architecture := model_info .get ("general.architecture" ):
134- if embedding_length := model_info .get (f"{ architecture } .embedding_length" ):
135- return int (embedding_length )
136-
137- except Exception as e :
138- logger .warning (f"Failed to get model embedding dimension for { model_name } with error: { e } " )
139-
140- return None
127+ logger .warning (f"Failed to get model details for { model_name } with error: { e } " )
128+ return None
0 commit comments