Skip to content

Commit e268c73

Browse files
committed
LLMAssistantAggregator: cache function call requested images
1 parent d3c57e2 commit e268c73

1 file changed

Lines changed: 51 additions & 23 deletions

File tree

src/pipecat/processors/aggregators/llm_response_universal.py

Lines changed: 51 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,7 @@ def __init__(
641641

642642
self._started = 0
643643
self._function_calls_in_progress: Dict[str, Optional[FunctionCallInProgressFrame]] = {}
644+
self._function_calls_image_results: Dict[str, UserImageRawFrame] = {}
644645
self._context_updated_tasks: Set[asyncio.Task] = set()
645646

646647
self._assistant_turn_start_timestamp = ""
@@ -820,6 +821,15 @@ async def _handle_function_call_result(self, frame: FunctionCallResultFrame):
820821

821822
run_llm = False
822823

824+
# Append any images that were generated by function calls.
825+
if frame.tool_call_id in self._function_calls_image_results:
826+
image_frame = self._function_calls_image_results[frame.tool_call_id]
827+
828+
del self._function_calls_image_results[frame.tool_call_id]
829+
830+
# If an image frame has been added to the context, let's run inference.
831+
run_llm = await self._maybe_append_image_to_context(image_frame)
832+
823833
# Run inference if the function call result requires it.
824834
if frame.result:
825835
if properties and properties.run_llm is not None:
@@ -856,31 +866,24 @@ async def _handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
856866
self._update_function_call_result(frame.function_name, frame.tool_call_id, "CANCELLED")
857867
del self._function_calls_in_progress[frame.tool_call_id]
858868

859-
def _update_function_call_result(self, function_name: str, tool_call_id: str, result: Any):
860-
for message in self._context.get_messages():
861-
if (
862-
not isinstance(message, LLMSpecificMessage)
863-
and message["role"] == "tool"
864-
and message["tool_call_id"]
865-
and message["tool_call_id"] == tool_call_id
866-
):
867-
message["content"] = result
868-
869869
async def _handle_user_image_frame(self, frame: UserImageRawFrame):
    """Handle an incoming user image, caching it when it belongs to a pending function call.

    If the image was requested by a function call that is still in
    progress, it is cached until the matching FunctionCallResultFrame
    arrives. Otherwise it is appended to the LLM context immediately and,
    when actually appended, a context frame is pushed upstream.

    TODO(aleix): The function call might have already been executed
    because FunctionCallResultFrame was just faster, in that case we just
    push the context frame now.
    """
    request = frame.request
    is_pending_call_result = bool(
        request
        and request.tool_call_id
        and request.tool_call_id in self._function_calls_in_progress
    )

    if is_pending_call_result:
        # Cache the image; it gets appended when the function call result arrives.
        self._function_calls_image_results[request.tool_call_id] = frame
        return

    # Not tied to an in-progress function call: try to append right away.
    if await self._maybe_append_image_to_context(frame):
        await self.push_context_frame(FrameDirection.UPSTREAM)
884887

885888
async def _handle_assistant_image_frame(self, frame: AssistantImageRawFrame):
886889
logger.debug(f"{self} Appending AssistantImageRawFrame to LLM context (size: {frame.size})")
@@ -970,6 +973,31 @@ async def _handle_thought_end(self, frame: LLMThoughtEndFrame):
970973

971974
await self._call_event_handler("on_assistant_thought", message)
972975

976+
async def _maybe_append_image_to_context(self, frame: UserImageRawFrame) -> bool:
    """Append *frame* to the LLM context if it is flagged for inclusion.

    Args:
        frame: The user image frame to (possibly) add to the context.

    Returns:
        True when the image was added to the context, False when the frame
        was not marked with ``append_to_context``.
    """
    if frame.append_to_context:
        logger.debug(f"{self} Appending UserImageRawFrame to LLM context (size: {frame.size})")
        await self._context.add_image_frame_message(
            format=frame.format,
            size=frame.size,
            image=frame.image,
            text=frame.text,
        )
        return True

    return False
990+
991+
def _update_function_call_result(self, function_name: str, tool_call_id: str, result: Any):
    """Update the content of the tool message matching *tool_call_id*.

    Scans the context messages for the dict-style tool message created for
    this function call and replaces its content with *result* (e.g. the
    actual result value, or markers such as "CANCELLED").

    Args:
        function_name: Name of the called function (currently unused; kept
            for interface compatibility with existing callers).
        tool_call_id: Unique id of the tool call whose message to update.
        result: New content for the tool message.
    """
    for message in self._context.get_messages():
        # LLM-specific messages are opaque objects, not dicts; skip them.
        if isinstance(message, LLMSpecificMessage):
            continue
        # Use .get() so a tool message without a "tool_call_id" key cannot
        # raise KeyError; it also subsumes the previous redundant
        # truthiness check on the key's value.
        if message["role"] == "tool" and message.get("tool_call_id") == tool_call_id:
            message["content"] = result
1000+
9731001
def _context_updated_task_finished(self, task: asyncio.Task):
9741002
self._context_updated_tasks.discard(task)
9751003

0 commit comments

Comments
 (0)