@@ -641,6 +641,7 @@ def __init__(
641641
642642 self ._started = 0
643643 self ._function_calls_in_progress : Dict [str , Optional [FunctionCallInProgressFrame ]] = {}
644+ self ._function_calls_image_results : Dict [str , UserImageRawFrame ] = {}
644645 self ._context_updated_tasks : Set [asyncio .Task ] = set ()
645646
646647 self ._assistant_turn_start_timestamp = ""
@@ -820,6 +821,15 @@ async def _handle_function_call_result(self, frame: FunctionCallResultFrame):
820821
821822 run_llm = False
822823
824+ # Append any images that were generated by function calls.
825+ if frame .tool_call_id in self ._function_calls_image_results :
826+ image_frame = self ._function_calls_image_results [frame .tool_call_id ]
827+
828+ del self ._function_calls_image_results [frame .tool_call_id ]
829+
830+ # If an image frame has been added to the context, let's run inference.
831+ run_llm = await self ._maybe_append_image_to_context (image_frame )
832+
823833 # Run inference if the function call result requires it.
824834 if frame .result :
825835 if properties and properties .run_llm is not None :
@@ -856,31 +866,24 @@ async def _handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
856866 self ._update_function_call_result (frame .function_name , frame .tool_call_id , "CANCELLED" )
857867 del self ._function_calls_in_progress [frame .tool_call_id ]
858868
859- def _update_function_call_result (self , function_name : str , tool_call_id : str , result : Any ):
860- for message in self ._context .get_messages ():
861- if (
862- not isinstance (message , LLMSpecificMessage )
863- and message ["role" ] == "tool"
864- and message ["tool_call_id" ]
865- and message ["tool_call_id" ] == tool_call_id
866- ):
867- message ["content" ] = result
868-
869869 async def _handle_user_image_frame (self , frame : UserImageRawFrame ):
870- if not frame .append_to_context :
871- return
872-
873- logger .debug (f"{ self } Appending UserImageRawFrame to LLM context (size: { frame .size } )" )
874-
875- await self ._context .add_image_frame_message (
876- format = frame .format ,
877- size = frame .size ,
878- image = frame .image ,
879- text = frame .text ,
880- )
870+ image_appended = False
871+
872+ # Check if this image is a result of a function call if so, let's cache.
873+ # TODO(aleix): The function call might have already been executed
874+ # because FunctionCallResultFrame was just faster, in that case we just
875+ # push the context frame now.
876+ if (
877+ frame .request
878+ and frame .request .tool_call_id
879+ and frame .request .tool_call_id in self ._function_calls_in_progress
880+ ):
881+ self ._function_calls_image_results [frame .request .tool_call_id ] = frame
882+ else :
883+ image_appended = await self ._maybe_append_image_to_context (frame )
881884
882- await self . _trigger_assistant_turn_stopped ()
883- await self .push_context_frame (FrameDirection .UPSTREAM )
885+ if image_appended :
886+ await self .push_context_frame (FrameDirection .UPSTREAM )
884887
885888 async def _handle_assistant_image_frame (self , frame : AssistantImageRawFrame ):
886889 logger .debug (f"{ self } Appending AssistantImageRawFrame to LLM context (size: { frame .size } )" )
@@ -970,6 +973,31 @@ async def _handle_thought_end(self, frame: LLMThoughtEndFrame):
970973
971974 await self ._call_event_handler ("on_assistant_thought" , message )
972975
976+ async def _maybe_append_image_to_context (self , frame : UserImageRawFrame ) -> bool :
977+ if not frame .append_to_context :
978+ return False
979+
980+ logger .debug (f"{ self } Appending UserImageRawFrame to LLM context (size: { frame .size } )" )
981+
982+ await self ._context .add_image_frame_message (
983+ format = frame .format ,
984+ size = frame .size ,
985+ image = frame .image ,
986+ text = frame .text ,
987+ )
988+
989+ return True
990+
991+ def _update_function_call_result (self , function_name : str , tool_call_id : str , result : Any ):
992+ for message in self ._context .get_messages ():
993+ if (
994+ not isinstance (message , LLMSpecificMessage )
995+ and message ["role" ] == "tool"
996+ and message ["tool_call_id" ]
997+ and message ["tool_call_id" ] == tool_call_id
998+ ):
999+ message ["content" ] = result
1000+
9731001 def _context_updated_task_finished (self , task : asyncio .Task ):
9741002 self ._context_updated_tasks .discard (task )
9751003
0 commit comments