From c589714b9c215c8eaa56ad90024e204c45001989 Mon Sep 17 00:00:00 2001 From: leonbeyourside Date: Fri, 13 Mar 2026 18:08:23 +0800 Subject: [PATCH 1/2] =?UTF-8?q?fix:=20=E5=8E=8B=E7=BC=A9=E7=AE=97=E6=B3=95?= =?UTF-8?q?=E5=88=A0=E9=99=A4=20user=20=E6=B6=88=E6=81=AF=20Bug=20?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/agent/context/truncator.py | 110 +++++++++++++++++------- 1 file changed, 77 insertions(+), 33 deletions(-) diff --git a/astrbot/core/agent/context/truncator.py b/astrbot/core/agent/context/truncator.py index afd89f2bed..7360f753fa 100644 --- a/astrbot/core/agent/context/truncator.py +++ b/astrbot/core/agent/context/truncator.py @@ -12,6 +12,71 @@ def _has_tool_calls(self, message: Message) -> bool: and len(message.tool_calls) > 0 ) + def _split_system_and_rest( + self, messages: list[Message] + ) -> tuple[list[Message], list[Message]]: + """Split messages into system messages and the rest. + + Returns: + tuple: (system_messages, non_system_messages) + """ + first_non_system = 0 + for i, msg in enumerate(messages): + if msg.role != "system": + first_non_system = i + break + + return messages[:first_non_system], messages[first_non_system:] + + def _ensure_first_user_message( + self, + system_messages: list[Message], + non_system_messages: list[Message], + original_messages: list[Message], + ) -> list[Message]: + """Ensure the result always contains the first user message right after + system messages. This is required by many LLM APIs (e.g. Zhipu) that + mandate a ``user`` message immediately following the ``system`` message. + + If the truncated ``non_system_messages`` already starts with a ``user`` + message, the list is returned as-is (with ``fix_messages`` applied). + Otherwise the first ``user`` message from the *original* full message + list is located and prepended. + + Args: + system_messages: The system messages extracted earlier. + non_system_messages: The truncated non-system messages. + original_messages: The full, untruncated message list (used to + locate the original first ``user`` message when it has been + removed by truncation). + + Returns: + A well-formed message list: ``system + [first_user +] rest``. + """ + # Fast path: already starts with a user message – nothing to fix. + if non_system_messages and non_system_messages[0].role == "user": + return self.fix_messages(system_messages + non_system_messages) + + # Locate the first user message from the *original* list. + first_user_msg: Message | None = None + for msg in original_messages: + if msg.role == "user": + first_user_msg = msg + break + + if first_user_msg is None: + # Degenerate case: no user message exists at all. + return self.fix_messages(system_messages + non_system_messages) + + # Avoid duplicate: if the located message is already in the truncated + # list (identity check), don't prepend again. + if any(m is first_user_msg for m in non_system_messages): + return self.fix_messages(system_messages + non_system_messages) + + # Prepend the first user message so the sequence is valid. + result = system_messages + [first_user_msg] + non_system_messages + return self.fix_messages(result) + def fix_messages(self, messages: list[Message]) -> list[Message]: """修复消息列表,确保 tool call 和 tool response 的配对关系有效。 @@ -81,14 +146,7 @@ def truncate_by_turns( if keep_most_recent_turns == -1: return messages - first_non_system = 0 - for i, msg in enumerate(messages): - if msg.role != "system": - first_non_system = i - break - - system_messages = messages[:first_non_system] - non_system_messages = messages[first_non_system:] + system_messages, non_system_messages = self._split_system_and_rest(messages) if len(non_system_messages) // 2 <= keep_most_recent_turns: return messages @@ -107,9 +165,9 @@ def truncate_by_turns( if index is not None and index > 0: truncated_contexts = truncated_contexts[index:] - result = system_messages + truncated_contexts - - return self.fix_messages(result) + return self._ensure_first_user_message( + system_messages, truncated_contexts, messages + ) def truncate_by_dropping_oldest_turns( self, @@ -120,14 +178,7 @@ def truncate_by_dropping_oldest_turns( if drop_turns <= 0: return messages - first_non_system = 0 - for i, msg in enumerate(messages): - if msg.role != "system": - first_non_system = i - break - - system_messages = messages[:first_non_system] - non_system_messages = messages[first_non_system:] + system_messages, non_system_messages = self._split_system_and_rest(messages) if len(non_system_messages) // 2 <= drop_turns: truncated_non_system = [] @@ -143,9 +194,9 @@ def truncate_by_dropping_oldest_turns( elif truncated_non_system: truncated_non_system = [] - result = system_messages + truncated_non_system - - return self.fix_messages(result) + return self._ensure_first_user_message( + system_messages, truncated_non_system, messages + ) def truncate_by_halving( self, @@ -155,14 +206,7 @@ def truncate_by_halving( if len(messages) <= 2: return messages - first_non_system = 0 - for i, msg in enumerate(messages): - if msg.role != "system": - first_non_system = i - break - - system_messages = messages[:first_non_system] - non_system_messages = messages[first_non_system:] + system_messages, non_system_messages = self._split_system_and_rest(messages) messages_to_delete = len(non_system_messages) // 2 if messages_to_delete == 0: @@ -177,6 +221,6 @@ def truncate_by_halving( if index is not None: truncated_non_system = truncated_non_system[index:] - result = system_messages + truncated_non_system - - return self.fix_messages(result) + return self._ensure_first_user_message( + system_messages, truncated_non_system, messages + ) From 0f309c8d3bef9d67f109370c34823c39bde759d0 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Fri, 20 Mar 2026 10:34:26 +0800 Subject: [PATCH 2/2] perf: improve truncate algo --- astrbot/core/agent/context/truncator.py | 116 ++++++++++-------------- 1 file changed, 46 insertions(+), 70 deletions(-) diff --git a/astrbot/core/agent/context/truncator.py b/astrbot/core/agent/context/truncator.py index 7360f753fa..9abf574336 100644 --- a/astrbot/core/agent/context/truncator.py +++ b/astrbot/core/agent/context/truncator.py @@ -12,8 +12,9 @@ def _has_tool_calls(self, message: Message) -> bool: and len(message.tool_calls) > 0 ) - def _split_system_and_rest( - self, messages: list[Message] + @staticmethod + def _split_system_rest( + messages: list[Message], ) -> tuple[list[Message], list[Message]]: """Split messages into system messages and the rest. @@ -25,66 +26,36 @@ def _split_system_and_rest( if msg.role != "system": first_non_system = i break - return messages[:first_non_system], messages[first_non_system:] - def _ensure_first_user_message( - self, + @staticmethod + def _ensure_user_message( system_messages: list[Message], - non_system_messages: list[Message], + truncated: list[Message], original_messages: list[Message], ) -> list[Message]: """Ensure the result always contains the first user message right after system messages. This is required by many LLM APIs (e.g. Zhipu) that mandate a ``user`` message immediately following the ``system`` message. - - If the truncated ``non_system_messages`` already starts with a ``user`` - message, the list is returned as-is (with ``fix_messages`` applied). - Otherwise the first ``user`` message from the *original* full message - list is located and prepended. - - Args: - system_messages: The system messages extracted earlier. - non_system_messages: The truncated non-system messages. - original_messages: The full, untruncated message list (used to - locate the original first ``user`` message when it has been - removed by truncation). - - Returns: - A well-formed message list: ``system + [first_user +] rest``. """ - # Fast path: already starts with a user message – nothing to fix. - if non_system_messages and non_system_messages[0].role == "user": - return self.fix_messages(system_messages + non_system_messages) + if truncated and truncated[0].role == "user": + return system_messages + truncated # Locate the first user message from the *original* list. - first_user_msg: Message | None = None - for msg in original_messages: - if msg.role == "user": - first_user_msg = msg - break + first_user = next((m for m in original_messages if m.role == "user"), None) + if first_user is None: + return system_messages + truncated - if first_user_msg is None: - # Degenerate case: no user message exists at all. - return self.fix_messages(system_messages + non_system_messages) - - # Avoid duplicate: if the located message is already in the truncated - # list (identity check), don't prepend again. - if any(m is first_user_msg for m in non_system_messages): - return self.fix_messages(system_messages + non_system_messages) - - # Prepend the first user message so the sequence is valid. - result = system_messages + [first_user_msg] + non_system_messages - return self.fix_messages(result) + return system_messages + [first_user] + truncated def fix_messages(self, messages: list[Message]) -> list[Message]: - """修复消息列表,确保 tool call 和 tool response 的配对关系有效。 + """Fix the message list to ensure the validity of tool call and tool response pairing. - 此方法确保: - 1. 每个 `tool` 消息前面都有一个包含 tool_calls 的 `assistant` 消息 - 2. 每个包含 tool_calls 的 `assistant` 消息后面都有对应的 `tool` 响应 + This method ensures that: + 1. Each `tool` message is preceded by an `assistant` message containing `tool_calls`. + 2. Each `assistant` message containing `tool_calls` is followed by corresponding ` - 这是 OpenAI Chat Completions API 规范的要求(Gemini 对此执行严格检查)。 + This is a requirement of the OpenAI Chat Completions API specification (Gemini enforces this strictly). """ if not messages: return messages @@ -103,24 +74,25 @@ def flush_pending_if_valid() -> None: for msg in messages: if msg.role == "tool": - # 只有在有挂起的 assistant(tool_calls) 时才记录 tool 响应 + # Only record tool responses when there is a pending assistant(tool_calls) if pending_assistant is not None: pending_tools.append(msg) - # else: 孤立的 tool 消息,直接忽略 + # Isolated tool messages without a preceding assistant(tool_calls) are ignored continue if self._has_tool_calls(msg): - # 遇到新的 assistant(tool_calls) 前,先处理旧的 pending 链 + # When encountering a new assistant(tool_calls), first process the old pending chain flush_pending_if_valid() pending_assistant = msg continue - # 非 tool,且不含 tool_calls 的消息 - # 先结束任何 pending 链,再正常追加 + # Non-tool messages that do not contain tool_calls will break the pending chain. + # Flush any pending chain first, then append the current message normally. flush_pending_if_valid() fixed_messages.append(msg) - # 结束时处理最后一个 pending 链 + # Flush the last pending chain at the end, + # ensuring that any remaining valid assistant(tool_calls) and its tools are included in the final list. flush_pending_if_valid() return fixed_messages @@ -131,22 +103,23 @@ def truncate_by_turns( keep_most_recent_turns: int, drop_turns: int = 1, ) -> list[Message]: - """截断上下文列表,确保不超过最大长度。 - 一个 turn 包含一个 user 消息和一个 assistant 消息。 - 这个方法会保证截断后的上下文列表符合 OpenAI 的上下文格式。 + """ + Turn-based truncation strategy, which drops the oldest turns while keeping the most recent N turns. + A turn consists of a user message and an assistant message. + This method ensures that the truncated context list conforms to OpenAI's context format. Args: - messages: 上下文列表 - keep_most_recent_turns: 保留最近的对话轮数 - drop_turns: 一次性丢弃的对话轮数 + messages: The original list of messages in the context. + keep_most_recent_turns: The number of most recent turns to keep. If set to -1, it means keeping all turns (no truncation). + drop_turns: The number of turns to drop from the beginning. Returns: - 截断后的上下文列表 + The truncated list of messages. """ if keep_most_recent_turns == -1: return messages - system_messages, non_system_messages = self._split_system_and_rest(messages) + system_messages, non_system_messages = self._split_system_rest(messages) if len(non_system_messages) // 2 <= keep_most_recent_turns: return messages @@ -157,7 +130,7 @@ def truncate_by_turns( else: truncated_contexts = non_system_messages[-num_to_keep * 2 :] - # 找到第一个 role 为 user 的索引,确保上下文格式正确 + # Find the first user message index = next( (i for i, item in enumerate(truncated_contexts) if item.role == "user"), None, @@ -165,48 +138,49 @@ def truncate_by_turns( if index is not None and index > 0: truncated_contexts = truncated_contexts[index:] - return self._ensure_first_user_message( + result = self._ensure_user_message( system_messages, truncated_contexts, messages ) + return self.fix_messages(result) def truncate_by_dropping_oldest_turns( self, messages: list[Message], drop_turns: int = 1, ) -> list[Message]: - """丢弃最旧的 N 个对话轮次。""" + """Drop the oldest N turns, regardless of the number of turns to keep.""" if drop_turns <= 0: return messages - system_messages, non_system_messages = self._split_system_and_rest(messages) + system_messages, non_system_messages = self._split_system_rest(messages) if len(non_system_messages) // 2 <= drop_turns: truncated_non_system = [] else: truncated_non_system = non_system_messages[drop_turns * 2 :] + # Find the first user message index = next( (i for i, item in enumerate(truncated_non_system) if item.role == "user"), None, ) if index is not None: truncated_non_system = truncated_non_system[index:] - elif truncated_non_system: - truncated_non_system = [] - return self._ensure_first_user_message( + result = self._ensure_user_message( system_messages, truncated_non_system, messages ) + return self.fix_messages(result) def truncate_by_halving( self, messages: list[Message], ) -> list[Message]: - """对半砍策略,删除 50% 的消息""" + """Halve the number of messages, keeping the most recent ones.""" if len(messages) <= 2: return messages - system_messages, non_system_messages = self._split_system_and_rest(messages) + system_messages, non_system_messages = self._split_system_rest(messages) messages_to_delete = len(non_system_messages) // 2 if messages_to_delete == 0: @@ -214,6 +188,7 @@ def truncate_by_halving( truncated_non_system = non_system_messages[messages_to_delete:] + # Find the first user message index = next( (i for i, item in enumerate(truncated_non_system) if item.role == "user"), None, @@ -221,6 +196,7 @@ def truncate_by_halving( if index is not None: truncated_non_system = truncated_non_system[index:] - return self._ensure_first_user_message( + result = self._ensure_user_message( system_messages, truncated_non_system, messages ) + return self.fix_messages(result)