1- from typing import Dict
1+ from typing import Dict , Optional
22
33from marshmallow import fields , post_dump , pre_load
4+ from sqlalchemy import func
45from sqlalchemy .orm import sessionmaker
56
67import letta
1516from letta .serialize_schemas .marshmallow_message import SerializedMessageSchema
1617from letta .serialize_schemas .marshmallow_tag import SerializedAgentTagSchema
1718from letta .serialize_schemas .marshmallow_tool import SerializedToolSchema
19+ from letta .settings import DatabaseChoice , settings
1820
1921
2022class MarshmallowAgentSchema (BaseSchema ):
@@ -41,9 +43,10 @@ class MarshmallowAgentSchema(BaseSchema):
4143 tool_exec_environment_variables = fields .List (fields .Nested (SerializedAgentEnvironmentVariableSchema ))
4244 tags = fields .List (fields .Nested (SerializedAgentTagSchema ))
4345
44- def __init__ (self , * args , session : sessionmaker , actor : User , ** kwargs ):
46+ def __init__ (self , * args , session : sessionmaker , actor : User , max_steps : Optional [ int ] = None , ** kwargs ):
4547 super ().__init__ (* args , actor = actor , ** kwargs )
4648 self .session = session
49+ self .max_steps = max_steps
4750
4851 # Propagate session and actor to nested schemas automatically
4952 for field in self .fields .values ():
@@ -64,16 +67,103 @@ def attach_messages(self, data: Dict, **kwargs):
6467
6568 with db_registry .session () as session :
6669 agent_id = data .get ("id" )
67- msgs = (
68- session .query (MessageModel )
69- .filter (
70- MessageModel .agent_id == agent_id ,
71- MessageModel .organization_id == self .actor .organization_id ,
70+
71+ if self .max_steps is not None :
72+ # first, always get the system message
73+ system_msg = (
74+ session .query (MessageModel )
75+ .filter (
76+ MessageModel .agent_id == agent_id ,
77+ MessageModel .organization_id == self .actor .organization_id ,
78+ MessageModel .role == "system" ,
79+ )
80+ .order_by (MessageModel .sequence_id .asc ())
81+ .first ()
82+ )
83+
84+ if settings .database_engine is DatabaseChoice .POSTGRES :
85+ # efficient PostgreSQL approach using subquery
86+ user_msg_subquery = (
87+ session .query (MessageModel .sequence_id )
88+ .filter (
89+ MessageModel .agent_id == agent_id ,
90+ MessageModel .organization_id == self .actor .organization_id ,
91+ MessageModel .role == "user" ,
92+ )
93+ .order_by (MessageModel .sequence_id .desc ())
94+ .limit (self .max_steps )
95+ .subquery ()
96+ )
97+
98+ # get the minimum sequence_id from the subquery
99+ cutoff_sequence_id = session .query (func .min (user_msg_subquery .c .sequence_id )).scalar ()
100+
101+ if cutoff_sequence_id :
102+ # get messages from cutoff, excluding system message to avoid duplicates
103+ step_msgs = (
104+ session .query (MessageModel )
105+ .filter (
106+ MessageModel .agent_id == agent_id ,
107+ MessageModel .organization_id == self .actor .organization_id ,
108+ MessageModel .sequence_id >= cutoff_sequence_id ,
109+ MessageModel .role != "system" ,
110+ )
111+ .order_by (MessageModel .sequence_id .asc ())
112+ .all ()
113+ )
114+ # combine system message with step messages
115+ msgs = [system_msg ] + step_msgs if system_msg else step_msgs
116+ else :
117+ # no user messages, just return system message
118+ msgs = [system_msg ] if system_msg else []
119+ else :
120+ # sqlite approach: get all user messages first, then get messages from cutoff
121+ user_messages = (
122+ session .query (MessageModel .sequence_id )
123+ .filter (
124+ MessageModel .agent_id == agent_id ,
125+ MessageModel .organization_id == self .actor .organization_id ,
126+ MessageModel .role == "user" ,
127+ )
128+ .order_by (MessageModel .sequence_id .desc ())
129+ .limit (self .max_steps )
130+ .all ()
131+ )
132+
133+ if user_messages :
134+ # get the minimum sequence_id
135+ cutoff_sequence_id = min (msg .sequence_id for msg in user_messages )
136+
137+ # get messages from cutoff, excluding system message to avoid duplicates
138+ step_msgs = (
139+ session .query (MessageModel )
140+ .filter (
141+ MessageModel .agent_id == agent_id ,
142+ MessageModel .organization_id == self .actor .organization_id ,
143+ MessageModel .sequence_id >= cutoff_sequence_id ,
144+ MessageModel .role != "system" ,
145+ )
146+ .order_by (MessageModel .sequence_id .asc ())
147+ .all ()
148+ )
149+ # combine system message with step messages
150+ msgs = [system_msg ] + step_msgs if system_msg else step_msgs
151+ else :
152+ # no user messages, just return system message
153+ msgs = [system_msg ] if system_msg else []
154+ else :
155+ # if no limit, get all messages in ascending order
156+ msgs = (
157+ session .query (MessageModel )
158+ .filter (
159+ MessageModel .agent_id == agent_id ,
160+ MessageModel .organization_id == self .actor .organization_id ,
161+ )
162+ .order_by (MessageModel .sequence_id .asc ())
163+ .all ()
72164 )
73- .order_by (MessageModel .sequence_id .asc ())
74- .all ()
75- )
76- # overwrite the “messages” key with a fully serialized list
165+
166+ # overwrite the "messages" key with a fully serialized list
77167 data [self .FIELD_MESSAGES ] = [SerializedMessageSchema (session = self .session , actor = self .actor ).dump (m ) for m in msgs ]
78168
79169 return data
0 commit comments