Skip to content

Commit 8b1cfdb

Browse files
authored
feat: support overriding embedding handle [LET-4021] (#4224)
1 parent 9c6f25e commit 8b1cfdb

File tree

3 files changed

+57
-2
lines changed

3 files changed

+57
-2
lines changed

letta/server/rest_api/routers/v1/agents.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@ async def import_agent(
252252
project_id: str | None = None,
253253
strip_messages: bool = False,
254254
env_vars: Optional[dict[str, Any]] = None,
255+
override_embedding_handle: Optional[str] = None,
255256
) -> List[str]:
256257
"""
257258
Import an agent using the new AgentFileSchema format.
@@ -262,12 +263,18 @@ async def import_agent(
262263
raise HTTPException(status_code=422, detail=f"Invalid agent file schema: {e!s}")
263264

264265
try:
266+
if override_embedding_handle:
267+
embedding_config_override = server.get_cached_embedding_config_async(actor=actor, handle=override_embedding_handle)
268+
else:
269+
embedding_config_override = None
270+
265271
import_result = await server.agent_serialization_manager.import_file(
266272
schema=agent_schema,
267273
actor=actor,
268274
append_copy_suffix=append_copy_suffix,
269275
override_existing_tools=override_existing_tools,
270276
env_vars=env_vars,
277+
override_embedding_config=embedding_config_override,
271278
)
272279

273280
if not import_result.success:
@@ -301,6 +308,10 @@ async def import_agent_serialized(
301308
True,
302309
description="If set to True, existing tools can get their source code overwritten by the uploaded tool definitions. Note that Letta core tools can never be updated externally.",
303310
),
311+
override_embedding_handle: Optional[str] = Form(
312+
None,
313+
description="Override import with specific embedding handle.",
314+
),
304315
project_id: str | None = Form(None, description="The project ID to associate the uploaded agent with."),
305316
strip_messages: bool = Form(
306317
False,

letta/services/agent_serialization_manager.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
ToolSchema,
2929
)
3030
from letta.schemas.block import Block
31+
from letta.schemas.embedding_config import EmbeddingConfig
3132
from letta.schemas.enums import FileProcessingStatus
3233
from letta.schemas.file import FileMetadata
3334
from letta.schemas.group import Group, GroupCreate
@@ -432,6 +433,7 @@ async def import_file(
432433
override_existing_tools: bool = True,
433434
dry_run: bool = False,
434435
env_vars: Optional[Dict[str, Any]] = None,
436+
override_embedding_config: Optional[EmbeddingConfig] = None,
435437
) -> ImportResult:
436438
"""
437439
Import AgentFileSchema into the database.
@@ -530,6 +532,12 @@ async def import_file(
530532
source_names_to_check = [s.name for s in schema.sources]
531533
existing_source_names = await self.source_manager.get_existing_source_names(source_names_to_check, actor)
532534

535+
# override embedding_config
536+
if override_embedding_config:
537+
for source_schema in schema.sources:
538+
source_schema.embedding_config = override_embedding_config
539+
source_schema.embedding = override_embedding_config.handle
540+
533541
for source_schema in schema.sources:
534542
source_data = source_schema.model_dump(exclude={"id", "embedding", "embedding_chunk_size"})
535543

@@ -577,10 +585,12 @@ async def import_file(
577585
# Start background tasks for file processing
578586
background_tasks = []
579587
if schema.files and any(f.content for f in schema.files):
588+
# Use override embedding config if provided, otherwise use agent's config
589+
embedder_config = override_embedding_config if override_embedding_config else schema.agents[0].embedding_config
580590
if should_use_pinecone():
581-
embedder = PineconeEmbedder(embedding_config=schema.agents[0].embedding_config)
591+
embedder = PineconeEmbedder(embedding_config=embedder_config)
582592
else:
583-
embedder = OpenAIEmbedder(embedding_config=schema.agents[0].embedding_config)
593+
embedder = OpenAIEmbedder(embedding_config=embedder_config)
584594
file_processor = FileProcessor(
585595
file_parser=self.file_parser,
586596
embedder=embedder,
@@ -613,6 +623,11 @@ async def import_file(
613623

614624
# 6. Create agents with empty message history
615625
for agent_schema in schema.agents:
626+
# Override embedding_config if provided
627+
if override_embedding_config:
628+
agent_schema.embedding_config = override_embedding_config
629+
agent_schema.embedding = override_embedding_config.handle
630+
616631
# Convert AgentSchema back to CreateAgent, remapping tool/block IDs
617632
agent_data = agent_schema.model_dump(exclude={"id", "in_context_message_ids", "messages"})
618633
if append_copy_suffix:

tests/test_agent_serialization_v2.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,13 @@ def test_agent(server: SyncServer, default_user, default_organization, test_bloc
205205
yield agent_state
206206

207207

208+
@pytest.fixture(scope="function")
209+
def embedding_handle_override():
210+
current_handle = EmbeddingConfig.default_config(provider="openai").handle
211+
assert current_handle != "letta/letta-free" # make sure its different
212+
return "letta/letta-free"
213+
214+
208215
@pytest.fixture(scope="function")
209216
async def test_source(server: SyncServer, default_user):
210217
"""Fixture to create and return a test source."""
@@ -1063,6 +1070,28 @@ async def test_basic_import(self, agent_serialization_manager, test_agent, defau
10631070
if file_id.startswith("agent-"):
10641071
assert db_id != test_agent.id # New agent should have different ID
10651072

1073+
async def test_basic_import_with_embedding_override(
1074+
self, server, agent_serialization_manager, test_agent, default_user, other_user, embedding_handle_override
1075+
):
1076+
"""Test basic agent import functionality with embedding override."""
1077+
agent_file = await agent_serialization_manager.export([test_agent.id], default_user)
1078+
1079+
embedding_config_override = await server.get_cached_embedding_config_async(actor=other_user, handle=embedding_handle_override)
1080+
result = await agent_serialization_manager.import_file(agent_file, other_user, override_embedding_config=embedding_config_override)
1081+
1082+
assert result.success
1083+
assert result.imported_count > 0
1084+
assert len(result.id_mappings) > 0
1085+
1086+
for file_id, db_id in result.id_mappings.items():
1087+
if file_id.startswith("agent-"):
1088+
assert db_id != test_agent.id # New agent should have different ID
1089+
1090+
# check embedding handle
1091+
imported_agent_id = next(db_id for file_id, db_id in result.id_mappings.items() if file_id == "agent-0")
1092+
imported_agent = server.agent_manager.get_agent_by_id(imported_agent_id, other_user)
1093+
assert imported_agent.embedding_config.handle == embedding_handle_override
1094+
10661095
async def test_import_preserves_data(self, server, agent_serialization_manager, test_agent, default_user, other_user):
10671096
"""Test that import preserves all important data."""
10681097
agent_file = await agent_serialization_manager.export([test_agent.id], default_user)

0 commit comments

Comments
 (0)