|
22 | 22 | """ |
23 | 23 |
|
24 | 24 | import asyncio |
25 | | -from dataclasses import dataclass |
26 | 25 | import inspect |
| 26 | +from dataclasses import dataclass |
27 | 27 | from typing import Callable, Dict, List, Optional |
28 | 28 |
|
29 | 29 | from loguru import logger |
30 | 30 | from pipecat.frames.frames import ( |
| 31 | + BotStoppedSpeakingFrame, |
| 32 | + ControlFrame, |
31 | 33 | EndFrame, |
32 | 34 | TTSSpeakFrame, |
33 | 35 | ) |
34 | 36 | from pipecat.pipeline.task import PipelineTask |
35 | | -from pipecat.frames.frames import ControlFrame |
36 | 37 |
|
37 | 38 | from .exceptions import ActionError |
38 | 39 | from .types import ActionConfig, FlowActionHandler |
@@ -73,19 +74,23 @@ def __init__(self, task: PipelineTask, flow_manager: "FlowManager", tts=None): |
73 | 74 | self._flow_manager = flow_manager |
74 | 75 | self.tts = tts |
75 | 76 | self.function_finished_event = asyncio.Event() |
| 77 | + self._deferred_post_actions: List[ActionConfig] = [] |
76 | 78 |
|
77 | 79 | # Register built-in actions |
78 | 80 | self._register_action("tts_say", self._handle_tts_action) |
79 | 81 | self._register_action("end_conversation", self._handle_end_action) |
80 | 82 | self._register_action("function", self._handle_function_action) |
81 | 83 |
|
82 | 84 | # Wire up function actions |
83 | | - task.set_reached_downstream_filter((FunctionActionFrame,)) |
| 85 | + task.set_reached_downstream_filter((FunctionActionFrame, BotStoppedSpeakingFrame)) |
| 86 | + |
84 | 87 | @task.event_handler("on_frame_reached_downstream") |
85 | 88 | async def on_frame_reached_downstream(task, frame): |
86 | 89 | if isinstance(frame, FunctionActionFrame): |
87 | 90 | await frame.function(frame.action, flow_manager) |
88 | 91 | self.function_finished_event.set() |
| 92 | + elif isinstance(frame, BotStoppedSpeakingFrame): |
| 93 | + await self._execute_deferred_post_actions() |
89 | 94 |
|
90 | 95 | def _register_action(self, action_type: str, handler: Callable) -> None: |
91 | 96 | """Register a handler for a specific action type. |
@@ -159,6 +164,25 @@ async def execute_actions(self, actions: Optional[List[ActionConfig]]) -> None: |
159 | 164 | except Exception as e: |
160 | 165 | raise ActionError(f"Failed to execute action {action_type}: {str(e)}") from e |
161 | 166 |
|
| 167 | + def schedule_deferred_post_actions(self, post_actions: List[ActionConfig]) -> None: |
| 168 | + """Schedule "deferred" post-actions to be executed after next LLM completion. |
| 169 | +
|
| 170 | + Args: |
| 171 | + post_actions: List of actions to execute |
| 172 | + """ |
| 173 | + self._deferred_post_actions = post_actions |
| 174 | + |
| 175 | + def clear_deferred_post_actions(self) -> None: |
| 176 | + """Clear any scheduled deferred post-actions.""" |
| 177 | + self._deferred_post_actions = [] |
| 178 | + |
| 179 | + async def _execute_deferred_post_actions(self) -> None: |
| 180 | + """Execute deferred post-actions.""" |
| 181 | + actions = self._deferred_post_actions |
| 182 | + self._deferred_post_actions = [] |
| 183 | + if actions: |
| 184 | + await self.execute_actions(actions) |
| 185 | + |
162 | 186 | async def _handle_tts_action(self, action: dict) -> None: |
163 | 187 | """Built-in handler for TTS actions. |
164 | 188 |
|
@@ -209,7 +233,7 @@ async def _handle_function_action(self, action: dict) -> None: |
209 | 233 | if not handler: |
210 | 234 | logger.error("Function action missing 'handler' field") |
211 | 235 | return |
212 | | - # the reason we're queueing a frame here is to ensure it happens after bot turn is over in |
| 236 | + # the reason we're queueing a frame here is to ensure it happens after bot turn is over in |
213 | 237 | # post_actions |
214 | 238 | await self.task.queue_frame(FunctionActionFrame(action=action, function=handler)) |
215 | 239 | await self.function_finished_event.wait() |
|
0 commit comments