Skip to content

Commit d242620

Browse files
authored
feat: refactor streaming route logic (#4369)
1 parent 0393084 commit d242620

File tree

1 file changed

+37
-69
lines changed

1 file changed

+37
-69
lines changed

letta/server/rest_api/routers/v1/agents.py

Lines changed: 37 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1223,6 +1223,9 @@ async def send_message_streaming(
12231223
request_start_timestamp_ns = get_utc_timestamp_ns()
12241224
MetricRegistry().user_message_counter.add(1, get_ctx_attributes())
12251225

1226+
# TODO (cliandy): clean this up
1227+
redis_client = await get_redis_client()
1228+
12261229
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
12271230
# TODO: This is redundant, remove soon
12281231
agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor, include_relationships=["multi_agent_group"])
@@ -1263,14 +1266,11 @@ async def send_message_streaming(
12631266
),
12641267
actor=actor,
12651268
)
1269+
job_update_metadata = None
1270+
await redis_client.set(f"{REDIS_RUN_ID_PREFIX}:{agent_id}", run.id if run else None)
12661271
else:
12671272
run = None
12681273

1269-
job_update_metadata = None
1270-
# TODO (cliandy): clean this up
1271-
redis_client = await get_redis_client()
1272-
await redis_client.set(f"{REDIS_RUN_ID_PREFIX}:{agent_id}", run.id if run else None)
1273-
12741274
try:
12751275
if agent_eligible and model_compatible:
12761276
if agent.enable_sleeptime and agent.agent_type != AgentType.voice_convo_agent:
@@ -1308,6 +1308,23 @@ async def send_message_streaming(
13081308
),
13091309
)
13101310

1311+
if request.stream_tokens and model_compatible_token_streaming:
1312+
raw_stream = agent_loop.step_stream(
1313+
input_messages=request.messages,
1314+
max_steps=request.max_steps,
1315+
use_assistant_message=request.use_assistant_message,
1316+
request_start_timestamp_ns=request_start_timestamp_ns,
1317+
include_return_message_types=request.include_return_message_types,
1318+
)
1319+
else:
1320+
raw_stream = agent_loop.step_stream_no_tokens(
1321+
request.messages,
1322+
max_steps=request.max_steps,
1323+
use_assistant_message=request.use_assistant_message,
1324+
request_start_timestamp_ns=request_start_timestamp_ns,
1325+
include_return_message_types=request.include_return_message_types,
1326+
)
1327+
13111328
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode, add_keepalive_to_stream
13121329

13131330
if request.background and settings.track_agent_run:
@@ -1321,23 +1338,6 @@ async def send_message_streaming(
13211338
),
13221339
)
13231340

1324-
if request.stream_tokens and model_compatible_token_streaming:
1325-
raw_stream = agent_loop.step_stream(
1326-
input_messages=request.messages,
1327-
max_steps=request.max_steps,
1328-
use_assistant_message=request.use_assistant_message,
1329-
request_start_timestamp_ns=request_start_timestamp_ns,
1330-
include_return_message_types=request.include_return_message_types,
1331-
)
1332-
else:
1333-
raw_stream = agent_loop.step_stream_no_tokens(
1334-
request.messages,
1335-
max_steps=request.max_steps,
1336-
use_assistant_message=request.use_assistant_message,
1337-
request_start_timestamp_ns=request_start_timestamp_ns,
1338-
include_return_message_types=request.include_return_message_types,
1339-
)
1340-
13411341
asyncio.create_task(
13421342
create_background_stream_processor(
13431343
stream_generator=raw_stream,
@@ -1346,55 +1346,21 @@ async def send_message_streaming(
13461346
)
13471347
)
13481348

1349-
stream = redis_sse_stream_generator(
1349+
raw_stream = redis_sse_stream_generator(
13501350
redis_client=redis_client,
13511351
run_id=run.id,
13521352
)
13531353

1354-
if request.include_pings and settings.enable_keepalive:
1355-
stream = add_keepalive_to_stream(stream, keepalive_interval=settings.keepalive_interval)
1356-
1357-
return StreamingResponseWithStatusCode(
1358-
stream,
1359-
media_type="text/event-stream",
1360-
)
1361-
1362-
if request.stream_tokens and model_compatible_token_streaming:
1363-
raw_stream = agent_loop.step_stream(
1364-
input_messages=request.messages,
1365-
max_steps=request.max_steps,
1366-
use_assistant_message=request.use_assistant_message,
1367-
request_start_timestamp_ns=request_start_timestamp_ns,
1368-
include_return_message_types=request.include_return_message_types,
1369-
)
1370-
# Conditionally wrap with keepalive based on request parameter
1371-
if request.include_pings and settings.enable_keepalive:
1372-
stream = add_keepalive_to_stream(raw_stream, keepalive_interval=settings.keepalive_interval)
1373-
else:
1374-
stream = raw_stream
1375-
1376-
result = StreamingResponseWithStatusCode(
1377-
stream,
1378-
media_type="text/event-stream",
1379-
)
1354+
# Conditionally wrap with keepalive based on request parameter
1355+
if request.include_pings and settings.enable_keepalive:
1356+
stream = add_keepalive_to_stream(raw_stream, keepalive_interval=settings.keepalive_interval)
13801357
else:
1381-
raw_stream = agent_loop.step_stream_no_tokens(
1382-
request.messages,
1383-
max_steps=request.max_steps,
1384-
use_assistant_message=request.use_assistant_message,
1385-
request_start_timestamp_ns=request_start_timestamp_ns,
1386-
include_return_message_types=request.include_return_message_types,
1387-
)
1388-
# Conditionally wrap with keepalive based on request parameter
1389-
if request.include_pings and settings.enable_keepalive:
1390-
stream = add_keepalive_to_stream(raw_stream, keepalive_interval=settings.keepalive_interval)
1391-
else:
1392-
stream = raw_stream
1393-
1394-
result = StreamingResponseWithStatusCode(
1395-
stream,
1396-
media_type="text/event-stream",
1397-
)
1358+
stream = raw_stream
1359+
1360+
result = StreamingResponseWithStatusCode(
1361+
stream,
1362+
media_type="text/event-stream",
1363+
)
13981364
else:
13991365
result = await server.send_message_to_agent(
14001366
agent_id=agent_id,
@@ -1409,11 +1375,13 @@ async def send_message_streaming(
14091375
request_start_timestamp_ns=request_start_timestamp_ns,
14101376
include_return_message_types=request.include_return_message_types,
14111377
)
1412-
job_status = JobStatus.running
1378+
if settings.track_agent_run:
1379+
job_status = JobStatus.running
14131380
return result
14141381
except Exception as e:
1415-
job_update_metadata = {"error": str(e)}
1416-
job_status = JobStatus.failed
1382+
if settings.track_agent_run:
1383+
job_update_metadata = {"error": str(e)}
1384+
job_status = JobStatus.failed
14171385
raise
14181386
finally:
14191387
if settings.track_agent_run:

0 commit comments

Comments
 (0)