Commit e141a40

fix: all model types returned from ollama provider (#2744)
2 parents: 0338f54 + af54f21

File tree

letta/constants.py
letta/schemas/providers/ollama.py

2 files changed: +46 −55 lines

letta/constants.py

Lines changed: 4 additions & 1 deletion
@@ -10,6 +10,7 @@
 
 ADMIN_PREFIX = "/v1/admin"
 API_PREFIX = "/v1"
+OLLAMA_API_PREFIX = "/v1"
 OPENAI_API_PREFIX = "/openai"
 
 COMPOSIO_ENTITY_ENV_VAR_KEY = "COMPOSIO_ENTITY"
@@ -50,8 +51,9 @@
 # Max steps for agent loop
 DEFAULT_MAX_STEPS = 50
 
-# minimum context window size
+# context window size
 MIN_CONTEXT_WINDOW = 4096
+DEFAULT_CONTEXT_WINDOW = 32000
 
 # number of concurrent embedding requests to sent
 EMBEDDING_BATCH_SIZE = 200
@@ -63,6 +65,7 @@
 # embeddings
 MAX_EMBEDDING_DIM = 4096  # maximum supported embeding size - do NOT change or else DBs will need to be reset
 DEFAULT_EMBEDDING_CHUNK_SIZE = 300
+DEFAULT_EMBEDDING_DIM = 1024
 
 # tokenizers
 EMBEDDING_TO_TOKENIZER_MAP = {
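
For orientation, a minimal sketch of how the new constants are meant to combine. The base URL is an illustrative assumption (Ollama's default local address), not code from this commit:

```python
# Illustrative sketch only; values mirror the constants added above.
OLLAMA_API_PREFIX = "/v1"          # Ollama's OpenAI-compatible route
DEFAULT_CONTEXT_WINDOW = 32000     # fallback when /api/show reports no context length
DEFAULT_EMBEDDING_DIM = 1024       # fallback when /api/show reports no embedding length

base_url = "http://localhost:11434"  # assumed: Ollama's default address
model_endpoint = f"{base_url}{OLLAMA_API_PREFIX}"
assert model_endpoint == "http://localhost:11434/v1"
```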

letta/schemas/providers/ollama.py

Lines changed: 42 additions & 54 deletions
@@ -3,7 +3,7 @@
 import aiohttp
 from pydantic import Field
 
-from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
+from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE, DEFAULT_CONTEXT_WINDOW, DEFAULT_EMBEDDING_DIM, OLLAMA_API_PREFIX
 from letta.log import get_logger
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.enums import ProviderCategory, ProviderType
@@ -12,8 +12,6 @@
 
 logger = get_logger(__name__)
 
-ollama_prefix = "/v1"
-
 
 class OllamaProvider(OpenAIProvider):
     """Ollama provider that uses the native /api/generate endpoint
@@ -41,19 +39,30 @@ async def list_llm_models_async(self) -> list[LLMConfig]:
                 response_json = await response.json()
 
         configs = []
-        for model in response_json["models"]:
-            context_window = await self._get_model_context_window(model["name"])
+        for model in response_json.get("models", []):
+            model_name = model["name"]
+            model_details = await self._get_model_details_async(model_name)
+            if not model_details or "completion" not in model_details.get("capabilities", []):
+                continue
+
+            context_window = None
+            model_info = model_details.get("model_info", {})
+            if architecture := model_info.get("general.architecture"):
+                if context_length := model_info.get(f"{architecture}.context_length"):
+                    context_window = int(context_length)
+
             if context_window is None:
-                print(f"Ollama model {model['name']} has no context window, using default 32000")
-                context_window = 32000
+                logger.warning(f"Ollama model {model_name} has no context window, using default {DEFAULT_CONTEXT_WINDOW}")
+                context_window = DEFAULT_CONTEXT_WINDOW
+
             configs.append(
                 LLMConfig(
-                    model=model["name"],
+                    model=model_name,
                     model_endpoint_type=ProviderType.ollama,
-                    model_endpoint=f"{self.base_url}{ollama_prefix}",
+                    model_endpoint=f"{self.base_url}{OLLAMA_API_PREFIX}",
                     model_wrapper=self.default_prompt_formatter,
                     context_window=context_window,
-                    handle=self.get_handle(model["name"]),
+                    handle=self.get_handle(model_name),
                     provider_name=self.name,
                     provider_category=self.provider_category,
                 )
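
To make the new lookup concrete, here is a runnable sketch of the same extraction against a hand-written /api/show-style payload. The payload values are invented; only the field names (capabilities, model_info, general.architecture, <arch>.context_length) come from the diff above:

```python
# Invented sample shaped like an Ollama /api/show response; only the fields
# the new code reads are included.
model_details = {
    "capabilities": ["completion"],
    "model_info": {
        "general.architecture": "llama",
        "llama.context_length": 8192,
    },
}

# Same logic as the diff: skip non-completion models, then key the context
# length off the reported architecture.
context_window = None
if "completion" in model_details.get("capabilities", []):
    model_info = model_details.get("model_info", {})
    if architecture := model_info.get("general.architecture"):
        if context_length := model_info.get(f"{architecture}.context_length"):
            context_window = int(context_length)

print(context_window)  # 8192; the real code falls back to DEFAULT_CONTEXT_WINDOW
```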
@@ -73,25 +82,36 @@ async def list_embedding_models_async(self) -> list[EmbeddingConfig]:
                 response_json = await response.json()
 
         configs = []
-        for model in response_json["models"]:
-            embedding_dim = await self._get_model_embedding_dim(model["name"])
+        for model in response_json.get("models", []):
+            model_name = model["name"]
+            model_details = await self._get_model_details_async(model_name)
+            if not model_details or "embedding" not in model_details.get("capabilities", []):
+                continue
+
+            embedding_dim = None
+            model_info = model_details.get("model_info", {})
+            if architecture := model_info.get("general.architecture"):
+                if embedding_length := model_info.get(f"{architecture}.embedding_length"):
+                    embedding_dim = int(embedding_length)
+
             if not embedding_dim:
-                print(f"Ollama model {model['name']} has no embedding dimension, using default 1024")
-                # continue
-                embedding_dim = 1024
+                logger.warning(f"Ollama model {model_name} has no embedding dimension, using default {DEFAULT_EMBEDDING_DIM}")
+                embedding_dim = DEFAULT_EMBEDDING_DIM
+
             configs.append(
                 EmbeddingConfig(
-                    embedding_model=model["name"],
+                    embedding_model=model_name,
                     embedding_endpoint_type=ProviderType.ollama,
-                    embedding_endpoint=f"{self.base_url}{ollama_prefix}",
+                    embedding_endpoint=f"{self.base_url}{OLLAMA_API_PREFIX}",
                     embedding_dim=embedding_dim,
                     embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
-                    handle=self.get_handle(model["name"], is_embedding=True),
+                    handle=self.get_handle(model_name, is_embedding=True),
                 )
             )
         return configs
 
-    async def _get_model_context_window(self, model_name: str) -> int | None:
+    async def _get_model_details_async(self, model_name: str) -> dict | None:
+        """Get detailed information for a specific model from /api/show."""
         endpoint = f"{self.base_url}/api/show"
         payload = {"name": model_name}
 
@@ -102,39 +122,7 @@ async def _get_model_context_window(self, model_name: str) -> int | None:
                     error_text = await response.text()
                     logger.warning(f"Failed to get model info for {model_name}: {response.status} - {error_text}")
                     return None
-
-                    response_json = await response.json()
-                    model_info = response_json.get("model_info", {})
-
-                    if architecture := model_info.get("general.architecture"):
-                        if context_length := model_info.get(f"{architecture}.context_length"):
-                            return int(context_length)
-
+                    return await response.json()
         except Exception as e:
-            logger.warning(f"Failed to get model context window for {model_name} with error: {e}")
-
-        return None
-
-    async def _get_model_embedding_dim(self, model_name: str) -> int | None:
-        endpoint = f"{self.base_url}/api/show"
-        payload = {"name": model_name}
-
-        try:
-            async with aiohttp.ClientSession() as session:
-                async with session.post(endpoint, json=payload) as response:
-                    if response.status != 200:
-                        error_text = await response.text()
-                        logger.warning(f"Failed to get model info for {model_name}: {response.status} - {error_text}")
-                        return None
-
-                    response_json = await response.json()
-                    model_info = response_json.get("model_info", {})
-
-                    if architecture := model_info.get("general.architecture"):
-                        if embedding_length := model_info.get(f"{architecture}.embedding_length"):
-                            return int(embedding_length)
-
-        except Exception as e:
-            logger.warning(f"Failed to get model embedding dimension for {model_name} with error: {e}")
-
-        return None
+            logger.warning(f"Failed to get model details for {model_name} with error: {e}")
+            return None
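
And a self-contained sketch of the consolidated helper's behavior: one POST to /api/show now serves both the context-window and embedding-dimension lookups that previously required two near-duplicate methods. The server URL and model name are assumptions for illustration, and error handling is trimmed relative to the real method:

```python
import asyncio

import aiohttp


async def show_model(base_url: str, model_name: str) -> dict | None:
    """Fetch model details from Ollama's /api/show, as _get_model_details_async does."""
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{base_url}/api/show", json={"name": model_name}) as response:
            if response.status != 200:
                return None
            return await response.json()


async def main() -> None:
    # Assumed local server and model name; any pulled Ollama model works.
    details = await show_model("http://localhost:11434", "llama3")
    if not details:
        return
    model_info = details.get("model_info", {})
    if architecture := model_info.get("general.architecture"):
        # One response answers both questions the two removed helpers asked.
        print("context_length:", model_info.get(f"{architecture}.context_length"))
        print("embedding_length:", model_info.get(f"{architecture}.embedding_length"))


asyncio.run(main())
```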
