
Commit 35bef8f

feat: convert compile system prompt to async (#3685)
1 parent 6977dfd commit 35bef8f

6 files changed: 201 additions & 8 deletions

letta/agents/base_agent.py

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@
 from letta.schemas.usage import LettaUsageStatistics
 from letta.schemas.user import User
 from letta.services.agent_manager import AgentManager
-from letta.services.helpers.agent_manager_helper import compile_system_message
+from letta.services.helpers.agent_manager_helper import compile_system_message_async
 from letta.services.message_manager import MessageManager
 from letta.services.passage_manager import PassageManager
 from letta.utils import united_diff
@@ -142,7 +142,7 @@ def extract_dynamic_section(text):
         if num_archival_memories is None:
             num_archival_memories = await self.passage_manager.agent_passage_size_async(actor=self.actor, agent_id=agent_state.id)
 
-        new_system_message_str = compile_system_message(
+        new_system_message_str = await compile_system_message_async(
             system_prompt=agent_state.system,
             in_context_memory=agent_state.memory,
             in_context_memory_last_edit=memory_edit_timestamp,
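
The hunk truncates the call after in_context_memory_last_edit. Based on the signature of compile_system_message_async added to letta/services/helpers/agent_manager_helper.py further down, the completed call plausibly reads as follows; any local name not visible in the hunk is an assumption for illustration, not verbatim from the commit.

    new_system_message_str = await compile_system_message_async(
        system_prompt=agent_state.system,
        in_context_memory=agent_state.memory,
        in_context_memory_last_edit=memory_edit_timestamp,
        timezone=agent_state.timezone,
        previous_message_count=num_messages,         # assumed local variable name
        archival_memory_size=num_archival_memories,  # computed just above in the hunk
        tool_rules_solver=tool_rules_solver,         # assumed local name; Optional in the signature
        sources=agent_state.sources,                 # assumed, mirrors initialize_message_sequence_async
        max_files_open=agent_state.max_files_open,   # assumed, mirrors initialize_message_sequence_async
    )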

letta/agents/voice_agent.py

Lines changed: 2 additions & 2 deletions
@@ -36,7 +36,7 @@
 )
 from letta.services.agent_manager import AgentManager
 from letta.services.block_manager import BlockManager
-from letta.services.helpers.agent_manager_helper import compile_system_message
+from letta.services.helpers.agent_manager_helper import compile_system_message_async
 from letta.services.job_manager import JobManager
 from letta.services.message_manager import MessageManager
 from letta.services.passage_manager import PassageManager
@@ -145,7 +145,7 @@ async def step_stream(self, input_messages: List[MessageCreate], max_steps: int
 
         in_context_messages = await self.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=self.actor)
         memory_edit_timestamp = get_utc_time()
-        in_context_messages[0].content[0].text = compile_system_message(
+        in_context_messages[0].content[0].text = await compile_system_message_async(
             system_prompt=agent_state.system,
             in_context_memory=agent_state.memory,
             in_context_memory_last_edit=memory_edit_timestamp,
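
Both call sites changed here and in base_agent.py sit inside async methods, which is the point of the conversion: a synchronous Jinja2 render holds the event loop for its full duration and stalls every other coroutine. A self-contained, illustrative timing sketch of that difference (not Letta code):

import asyncio
import time


async def blocking_render():
    time.sleep(0.5)  # stand-in for a synchronous template render: freezes the event loop


async def yielding_render():
    await asyncio.sleep(0.5)  # stand-in for an awaited async render: yields control while waiting


async def timed(*coros):
    t0 = time.monotonic()
    await asyncio.gather(*coros)
    return time.monotonic() - t0


async def main():
    print(await timed(blocking_render(), yielding_render()))  # ~1.0 s: the sync render serializes both
    print(await timed(yielding_render(), yielding_render()))  # ~0.5 s: the async renders overlap


asyncio.run(main())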

letta/schemas/memory.py

Lines changed: 5 additions & 0 deletions
@@ -11,6 +11,7 @@
 from openai.types.beta.function_tool import FunctionTool as OpenAITool
 
 from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT
+from letta.otel.tracing import trace_method
 from letta.schemas.block import Block, FileBlock
 from letta.schemas.message import Message
 
@@ -114,6 +115,7 @@ def get_prompt_template(self) -> str:
         """Return the current Jinja2 template string."""
         return str(self.prompt_template)
 
+    @trace_method
     def set_prompt_template(self, prompt_template: str):
         """
         Set a new Jinja2 template string.
@@ -133,6 +135,7 @@ def set_prompt_template(self, prompt_template: str):
         except Exception as e:
             raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}")
 
+    @trace_method
     async def set_prompt_template_async(self, prompt_template: str):
         """
         Async version of set_prompt_template that doesn't block the event loop.
@@ -152,6 +155,7 @@ async def set_prompt_template_async(self, prompt_template: str):
         except Exception as e:
             raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}")
 
+    @trace_method
     def compile(self, tool_usage_rules=None, sources=None, max_files_open=None) -> str:
         """Generate a string representation of the memory in-context using the Jinja2 template"""
         try:
@@ -168,6 +172,7 @@ def compile(self, tool_usage_rules=None, sources=None, max_files_open=None) -> str:
         except Exception as e:
             raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}")
 
+    @trace_method
     async def compile_async(self, tool_usage_rules=None, sources=None, max_files_open=None) -> str:
         """Async version of compile that doesn't block the event loop"""
         try:
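
trace_method is imported from letta.otel.tracing; its implementation is not part of this diff. As orientation only, a decorator of this shape typically opens a tracing span around both the sync and async methods it is applied to above. The sketch below is an assumption built on the public OpenTelemetry API, not the actual Letta decorator.

import functools
import inspect

from opentelemetry import trace

_tracer = trace.get_tracer(__name__)


def trace_method(func):
    """Wrap a sync or async callable in a tracing span (illustrative sketch, not Letta's implementation)."""
    if inspect.iscoroutinefunction(func):

        @functools.wraps(func)
        async def async_wrapper(*args, **kwargs):
            with _tracer.start_as_current_span(func.__qualname__):
                return await func(*args, **kwargs)

        return async_wrapper

    @functools.wraps(func)
    def sync_wrapper(*args, **kwargs):
        with _tracer.start_as_current_span(func.__qualname__):
            return func(*args, **kwargs)

    return sync_wrapper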

letta/services/agent_manager.py

Lines changed: 35 additions & 4 deletions
@@ -86,8 +86,10 @@
     calculate_multi_agent_tools,
     check_supports_structured_output,
     compile_system_message,
+    compile_system_message_async,
     derive_system_message,
     initialize_message_sequence,
+    initialize_message_sequence_async,
     package_initial_message_sequence,
     validate_agent_exists_async,
 )
@@ -621,7 +623,7 @@ async def create_agent_async(
 
         # initial message sequence (skip if _init_with_no_messages is True)
         if not _init_with_no_messages:
-            init_messages = self._generate_initial_message_sequence(
+            init_messages = await self._generate_initial_message_sequence_async(
                 actor,
                 agent_state=result,
                 supplied_initial_message_sequence=agent_create.initial_message_sequence,
@@ -666,6 +668,35 @@ def _generate_initial_message_sequence(
 
         return init_messages
 
+    @enforce_types
+    async def _generate_initial_message_sequence_async(
+        self, actor: PydanticUser, agent_state: PydanticAgentState, supplied_initial_message_sequence: Optional[List[MessageCreate]] = None
+    ) -> List[Message]:
+        init_messages = await initialize_message_sequence_async(
+            agent_state=agent_state, memory_edit_timestamp=get_utc_time(), include_initial_boot_message=True
+        )
+        if supplied_initial_message_sequence is not None:
+            # We always need the system prompt up front
+            system_message_obj = PydanticMessage.dict_to_message(
+                agent_id=agent_state.id,
+                model=agent_state.llm_config.model,
+                openai_message_dict=init_messages[0],
+            )
+            # Don't use anything else in the pregen sequence, instead use the provided sequence
+            init_messages = [system_message_obj]
+            init_messages.extend(
+                package_initial_message_sequence(
+                    agent_state.id, supplied_initial_message_sequence, agent_state.llm_config.model, agent_state.timezone, actor
+                )
+            )
+        else:
+            init_messages = [
+                PydanticMessage.dict_to_message(agent_id=agent_state.id, model=agent_state.llm_config.model, openai_message_dict=msg)
+                for msg in init_messages
+            ]
+
+        return init_messages
+
     @enforce_types
     @trace_method
     def append_initial_message_sequence_to_in_context_messages(
@@ -679,7 +710,7 @@ def append_initial_message_sequence_to_in_context_messages(
     async def append_initial_message_sequence_to_in_context_messages_async(
         self, actor: PydanticUser, agent_state: PydanticAgentState, initial_message_sequence: Optional[List[MessageCreate]] = None
     ) -> PydanticAgentState:
-        init_messages = self._generate_initial_message_sequence(actor, agent_state, initial_message_sequence)
+        init_messages = await self._generate_initial_message_sequence_async(actor, agent_state, initial_message_sequence)
         return await self.append_to_in_context_messages_async(init_messages, agent_id=agent_state.id, actor=actor)
 
     @enforce_types
@@ -1674,7 +1705,7 @@ async def rebuild_system_prompt_async(
 
         # update memory (TODO: potentially update recall/archival stats separately)
 
-        new_system_message_str = compile_system_message(
+        new_system_message_str = await compile_system_message_async(
            system_prompt=agent_state.system,
            in_context_memory=agent_state.memory,
            in_context_memory_last_edit=memory_edit_timestamp,
@@ -1809,7 +1840,7 @@ async def reset_messages_async(
 
         # Optionally add default initial messages after the system message
        if add_default_initial_messages:
-            init_messages = initialize_message_sequence(
+            init_messages = await initialize_message_sequence_async(
                agent_state=agent_state, memory_edit_timestamp=get_utc_time(), include_initial_boot_message=True
            )
            # Skip index 0 (system message) since we preserved the original
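
For orientation, initialize_message_sequence_async (added to the helper module below) returns OpenAI-style message dicts rather than Message objects, so the new async call sites convert them. A condensed restatement of the default path of _generate_initial_message_sequence_async, using only names that appear in the hunks above:

# Default path (no user-supplied initial sequence), condensed from the new method above
init_dicts = await initialize_message_sequence_async(
    agent_state=agent_state,
    memory_edit_timestamp=get_utc_time(),
    include_initial_boot_message=True,
)
# init_dicts is a List[dict] shaped like:
#   [{"role": "system", "content": full_system_message},
#    *initial_boot_messages,
#    {"role": "user", "content": first_user_message}]
init_messages = [
    PydanticMessage.dict_to_message(
        agent_id=agent_state.id,
        model=agent_state.llm_config.model,
        openai_message_dict=msg,
    )
    for msg in init_dicts
]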

letta/services/helpers/agent_manager_helper.py

Lines changed: 152 additions & 0 deletions
@@ -248,6 +248,7 @@ def safe_format(template: str, variables: dict) -> str:
     return escaped.format_map(PreserveMapping(variables))
 
 
+@trace_method
 def compile_system_message(
     system_prompt: str,
     in_context_memory: Memory,
@@ -327,6 +328,87 @@ def compile_system_message(
     return formatted_prompt
 
 
+@trace_method
+async def compile_system_message_async(
+    system_prompt: str,
+    in_context_memory: Memory,
+    in_context_memory_last_edit: datetime,  # TODO move this inside of BaseMemory?
+    timezone: str,
+    user_defined_variables: Optional[dict] = None,
+    append_icm_if_missing: bool = True,
+    template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
+    previous_message_count: int = 0,
+    archival_memory_size: int = 0,
+    tool_rules_solver: Optional[ToolRulesSolver] = None,
+    sources: Optional[List] = None,
+    max_files_open: Optional[int] = None,
+) -> str:
+    """Prepare the final/full system message that will be fed into the LLM API
+
+    The base system message may be templated, in which case we need to render the variables.
+
+    The following are reserved variables:
+      - CORE_MEMORY: the in-context memory of the LLM
+    """
+
+    # Add tool rule constraints if available
+    tool_constraint_block = None
+    if tool_rules_solver is not None:
+        tool_constraint_block = tool_rules_solver.compile_tool_rule_prompts()
+
+    if user_defined_variables is not None:
+        # TODO eventually support the user defining their own variables to inject
+        raise NotImplementedError
+    else:
+        variables = {}
+
+    # Add the protected memory variable
+    if IN_CONTEXT_MEMORY_KEYWORD in variables:
+        raise ValueError(f"Found protected variable '{IN_CONTEXT_MEMORY_KEYWORD}' in user-defined vars: {str(user_defined_variables)}")
+    else:
+        # TODO should this all put into the memory.__repr__ function?
+        memory_metadata_string = compile_memory_metadata_block(
+            memory_edit_timestamp=in_context_memory_last_edit,
+            previous_message_count=previous_message_count,
+            archival_memory_size=archival_memory_size,
+            timezone=timezone,
+        )
+
+        memory_with_sources = await in_context_memory.compile_async(
+            tool_usage_rules=tool_constraint_block, sources=sources, max_files_open=max_files_open
+        )
+        full_memory_string = memory_with_sources + "\n\n" + memory_metadata_string
+
+    # Add to the variables list to inject
+    variables[IN_CONTEXT_MEMORY_KEYWORD] = full_memory_string
+
+    if template_format == "f-string":
+        memory_variable_string = "{" + IN_CONTEXT_MEMORY_KEYWORD + "}"
+
+        # Catch the special case where the system prompt is unformatted
+        if append_icm_if_missing:
+            if memory_variable_string not in system_prompt:
+                # In this case, append it to the end to make sure memory is still injected
+                # warnings.warn(f"{IN_CONTEXT_MEMORY_KEYWORD} variable was missing from system prompt, appending instead")
+                system_prompt += "\n\n" + memory_variable_string
+
+        # render the variables using the built-in templater
+        try:
+            if user_defined_variables:
+                formatted_prompt = safe_format(system_prompt, variables)
+            else:
+                formatted_prompt = system_prompt.replace(memory_variable_string, full_memory_string)
+        except Exception as e:
+            raise ValueError(f"Failed to format system prompt - {str(e)}. System prompt value:\n{system_prompt}")
+
+    else:
+        # TODO support for mustache and jinja2
+        raise NotImplementedError(template_format)
+
+    return formatted_prompt
+
+
+@trace_method
 def initialize_message_sequence(
     agent_state: AgentState,
     memory_edit_timestamp: Optional[datetime] = None,
@@ -396,6 +478,76 @@ def initialize_message_sequence(
     return messages
 
 
+@trace_method
+async def initialize_message_sequence_async(
+    agent_state: AgentState,
+    memory_edit_timestamp: Optional[datetime] = None,
+    include_initial_boot_message: bool = True,
+    previous_message_count: int = 0,
+    archival_memory_size: int = 0,
+) -> List[dict]:
+    if memory_edit_timestamp is None:
+        memory_edit_timestamp = get_local_time()
+
+    full_system_message = await compile_system_message_async(
+        system_prompt=agent_state.system,
+        in_context_memory=agent_state.memory,
+        in_context_memory_last_edit=memory_edit_timestamp,
+        timezone=agent_state.timezone,
+        user_defined_variables=None,
+        append_icm_if_missing=True,
+        previous_message_count=previous_message_count,
+        archival_memory_size=archival_memory_size,
+        sources=agent_state.sources,
+        max_files_open=agent_state.max_files_open,
+    )
+    first_user_message = get_login_event(agent_state.timezone)  # event letting Letta know the user just logged in
+
+    if include_initial_boot_message:
+        llm_config = agent_state.llm_config
+        uuid_str = str(uuid.uuid4())
+
+        # Some LMStudio models (e.g. ministral) require the tool call ID to be 9 alphanumeric characters
+        tool_call_id = uuid_str[:9] if llm_config.provider_name == "lmstudio_openai" else uuid_str
+
+        if agent_state.agent_type == AgentType.sleeptime_agent:
+            initial_boot_messages = []
+        elif llm_config.model is not None and "gpt-3.5" in llm_config.model:
+            initial_boot_messages = get_initial_boot_messages("startup_with_send_message_gpt35", agent_state.timezone, tool_call_id)
+        else:
+            initial_boot_messages = get_initial_boot_messages("startup_with_send_message", agent_state.timezone, tool_call_id)
+
+        # Some LMStudio models (e.g. meta-llama-3.1) require the user message before any tool calls
+        if llm_config.provider_name == "lmstudio_openai":
+            messages = (
+                [
+                    {"role": "system", "content": full_system_message},
+                ]
+                + [
+                    {"role": "user", "content": first_user_message},
+                ]
+                + initial_boot_messages
+            )
+        else:
+            messages = (
+                [
+                    {"role": "system", "content": full_system_message},
+                ]
+                + initial_boot_messages
+                + [
+                    {"role": "user", "content": first_user_message},
+                ]
+            )
+
+    else:
+        messages = [
+            {"role": "system", "content": full_system_message},
+            {"role": "user", "content": first_user_message},
+        ]
+
+    return messages
+
+
 def package_initial_message_sequence(
     agent_id: str, initial_message_sequence: List[MessageCreate], model: str, timezone: str, actor: User
 ) -> List[Message]:
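
Only the f-string template_format path is implemented in compile_system_message_async (mustache and jinja2 raise NotImplementedError): the memory placeholder is appended when missing and then injected by plain string replacement. A condensed, standalone sketch of that path, assuming IN_CONTEXT_MEMORY_KEYWORD resolves to "CORE_MEMORY" as the docstring's reserved-variable list suggests:

IN_CONTEXT_MEMORY_KEYWORD = "CORE_MEMORY"  # assumed value, inferred from the docstring above

full_memory_string = "<compiled memory blocks>\n\n<memory metadata block>"  # stand-in for the real compile output
system_prompt = "You are a helpful agent."  # note: no {CORE_MEMORY} placeholder present
memory_variable_string = "{" + IN_CONTEXT_MEMORY_KEYWORD + "}"

# append_icm_if_missing=True: the placeholder is appended so memory is still injected
if memory_variable_string not in system_prompt:
    system_prompt += "\n\n" + memory_variable_string

# with no user-defined variables, injection is a plain string replacement
formatted_prompt = system_prompt.replace(memory_variable_string, full_memory_string)
print(formatted_prompt)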

letta/templates/template_helper.py

Lines changed: 5 additions & 0 deletions
@@ -2,6 +2,8 @@
 
 from jinja2 import Environment, FileSystemLoader, StrictUndefined, Template
 
+from letta.otel.tracing import trace_method
+
 TEMPLATE_DIR = os.path.dirname(__file__)
 
 # Synchronous environment (for backward compatibility)
@@ -22,18 +24,21 @@
 )
 
 
+@trace_method
 def render_template(template_name: str, **kwargs):
     """Synchronous template rendering function (kept for backward compatibility)"""
     template = jinja_env.get_template(template_name)
     return template.render(**kwargs)
 
 
+@trace_method
 async def render_template_async(template_name: str, **kwargs):
     """Asynchronous template rendering function that doesn't block the event loop"""
     template = jinja_async_env.get_template(template_name)
     return await template.render_async(**kwargs)
 
 
+@trace_method
 async def render_string_async(template_string: str, **kwargs):
     """Asynchronously render a template from a string"""
     template = Template(template_string, enable_async=True)
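
The hunks above reference jinja_async_env and template.render_async without showing how the async environment is built. A plausible construction, assuming it mirrors the synchronous environment with Jinja2's enable_async flag turned on (an assumption; the real definition lives outside this diff):

import os

from jinja2 import Environment, FileSystemLoader, StrictUndefined

TEMPLATE_DIR = os.path.dirname(__file__)

# Async environment sketch: enable_async=True makes get_template(...).render_async(...) awaitable
jinja_async_env = Environment(
    loader=FileSystemLoader(TEMPLATE_DIR),
    undefined=StrictUndefined,  # assumed to match the sync environment's strictness, per the import list
    enable_async=True,
)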
