diff --git a/src/renderer/src/pages/ChatPage/hooks/refactored/index.ts b/src/renderer/src/pages/ChatPage/hooks/refactored/index.ts new file mode 100644 index 00000000..dd46bde6 --- /dev/null +++ b/src/renderer/src/pages/ChatPage/hooks/refactored/index.ts @@ -0,0 +1,26 @@ +// リファクタリングされたフックのエクスポート + +export { useMessages } from './useMessages' +export { useRequestControl } from './useRequestControl' +export { useChatUIState } from './useChatUIState' +export { useSessionManager } from './useSessionManager' +export { useToolExecution } from './useToolExecution' +export { useStreamChat } from './useStreamChat' +export { useAgentChatRefactored } from './useAgentChatRefactored' + +// 型定義のエクスポート +export type { UseMessagesProps, UseMessagesReturn } from './useMessages' +export type { UseRequestControlReturn } from './useRequestControl' +export type { UseChatUIStateReturn } from './useChatUIState' +export type { UseSessionManagerProps, UseSessionManagerReturn } from './useSessionManager' +export type { UseToolExecutionProps, UseToolExecutionReturn } from './useToolExecution' +export type { UseStreamChatProps, UseStreamChatReturn } from './useStreamChat' + +export type { + ChatUIState, + SessionInfo, + ToolExecutionState, + MessageOperations, + StreamingState, + UseAgentChatReturn +} from './types' diff --git a/src/renderer/src/pages/ChatPage/hooks/refactored/types.ts b/src/renderer/src/pages/ChatPage/hooks/refactored/types.ts new file mode 100644 index 00000000..01654d81 --- /dev/null +++ b/src/renderer/src/pages/ChatPage/hooks/refactored/types.ts @@ -0,0 +1,59 @@ +// 共通の型定義ファイル + +import { IdentifiableMessage } from '@/types/chat/message' +import { ToolName } from '@/types/tools' + +// UI状態の型 +export interface ChatUIState { + loading: boolean + reasoning: boolean + executingTool: ToolName | null + latestReasoningText: string +} + +// セッション管理の型 +export interface SessionInfo { + currentSessionId: string | undefined + enableHistory: boolean +} + +// ツール実行の型 +export interface ToolExecutionState { + isExecuting: boolean + currentTool: ToolName | null +} + +// メッセージ管理の型 +export interface MessageOperations { + addMessage: (message: IdentifiableMessage) => Promise + persistMessage: (message: IdentifiableMessage) => Promise + clearMessages: () => void +} + +// ストリーミングの型 +export interface StreamingState { + isStreaming: boolean + abortController: AbortController | null +} + +// リファクタリング後のフックの戻り値の型 +export interface UseAgentChatReturn { + // メッセージ関連 + messages: IdentifiableMessage[] + setMessages: React.Dispatch> + + // UI状態 + loading: boolean + reasoning: boolean + executingTool: ToolName | null + latestReasoningText: string + + // セッション管理 + currentSessionId: string | undefined + setCurrentSessionId: (sessionId: string) => void + clearChat: () => Promise + + // メイン機能 + handleSubmit: (userInput: string, attachedImages?: any[]) => Promise + stopGeneration: () => void +} diff --git a/src/renderer/src/pages/ChatPage/hooks/refactored/useAgentChatRefactored.ts b/src/renderer/src/pages/ChatPage/hooks/refactored/useAgentChatRefactored.ts new file mode 100644 index 00000000..05215869 --- /dev/null +++ b/src/renderer/src/pages/ChatPage/hooks/refactored/useAgentChatRefactored.ts @@ -0,0 +1,337 @@ +import { useMemo, useCallback, useEffect } from 'react' +import { ImageFormat } from '@aws-sdk/client-bedrock-runtime' +import { ToolState } from '@/types/agent-chat' +import { useSettings } from '@renderer/contexts/SettingsContext' +import { isMcpTool } from '@/types/tools' +import { useAgentTools } from '../useAgentTools' +import toast from 'react-hot-toast' +import { useTranslation } from 'react-i18next' +import { notificationService } from '@renderer/services/NotificationService' +// Import type will be resolved from the original location +type AttachedImage = { + file: File + base64: string +} + +// 分割されたフックをインポート +import { useMessages } from './useMessages' +import { useRequestControl } from './useRequestControl' +import { useChatUIState } from './useChatUIState' +import { useSessionManager } from './useSessionManager' +import { useToolExecution } from './useToolExecution' +import { useStreamChat } from './useStreamChat' + +export const useAgentChatRefactored = ( + modelId: string, + systemPrompt?: string, + agentId?: string, + sessionId?: string, + options?: { + enableHistory?: boolean + tools?: ToolState[] + } +) => { + const { enableHistory = true, tools: explicitTools } = options || {} + const { t } = useTranslation() + + const { + notification, + contextLength, + guardrailSettings, + getAgentTools, + agents, + enablePromptCache + } = useSettings() + + // エージェントIDからツール設定を取得 + const rawEnabledTools = useMemo(() => { + if (explicitTools) { + return explicitTools.filter((tool) => tool.enabled) + } else if (agentId) { + const currentAgent = agents.find((a) => a.id === agentId) + const hasMcpServers = currentAgent?.mcpServers && currentAgent.mcpServers.length > 0 + const agentTools = getAgentTools(agentId).filter((tool) => tool.enabled) + + return agentTools.filter((tool) => { + const toolName = tool.toolSpec?.name + if (!toolName) return false + + if (toolName === 'tavilySearch') { + const tavilyApiKey = window.store.get('tavilySearch')?.apikey + return !!tavilyApiKey && tavilyApiKey.length > 0 + } + + if (isMcpTool(toolName)) { + if (!hasMcpServers) { + console.warn( + `MCP tool "${toolName}" is enabled but no MCP servers are configured. Tool will be disabled.` + ) + return false + } + } + return true + }) + } + return [] + }, [agentId, getAgentTools, explicitTools, agents]) + + const enabledTools = useAgentTools(rawEnabledTools) + + // 分割されたフックを使用 + const requestControl = useRequestControl() + const uiState = useChatUIState() + const sessionManager = useSessionManager({ + modelId, + systemPrompt, + sessionId, + enableHistory + }) + + const messagesHook = useMessages({ + enableHistory, + modelId, + enabledTools + }) + + const toolExecution = useToolExecution({ + guardrailSettings, + setExecutingTool: uiState.setExecutingTool + }) + + const streamChat = useStreamChat({ + contextLength, + enablePromptCache, + setMessages: messagesHook.setMessages, + setReasoning: uiState.setReasoning, + setLatestReasoningText: uiState.setLatestReasoningText, + persistMessage: async (msg, sessionId) => { + const targetSessionId = sessionId || sessionManager.currentSessionId + if (targetSessionId) { + return await messagesHook.persistMessage(msg, targetSessionId) + } + return msg + } + }) + + // セッション初期化時にメッセージを同期 + useEffect(() => { + if (sessionManager.currentSessionId && enableHistory) { + // セッションからメッセージを読み込む処理をここに追加 + // 現在の実装では useChatHistory から直接取得する必要がある + } + }, [sessionManager.currentSessionId, enableHistory]) + + const stopGeneration = useCallback(() => { + requestControl.abortCurrentRequest() + + if (messagesHook.messages.length > 0) { + // 不完全なtoolUse/toolResultペアを削除するロジック + const updatedMessages = [...messagesHook.messages] + const toolUseIds = new Map() + + updatedMessages.forEach((msg, msgIndex) => { + if (!msg.content) return + + msg.content.forEach((content) => { + if ('toolUse' in content && content.toolUse?.toolUseId) { + const toolUseId = content.toolUse.toolUseId + const entry = toolUseIds.get(toolUseId) || { useIndex: -1, resultIndex: -1 } + entry.useIndex = msgIndex + toolUseIds.set(toolUseId, entry) + } + + if ('toolResult' in content && content.toolResult?.toolUseId) { + const toolUseId = content.toolResult.toolUseId + const entry = toolUseIds.get(toolUseId) || { useIndex: -1, resultIndex: -1 } + entry.resultIndex = msgIndex + toolUseIds.set(toolUseId, entry) + } + }) + }) + + const indicesToDelete = new Set() + toolUseIds.forEach(({ useIndex, resultIndex }) => { + if (useIndex >= 0 && resultIndex === -1) { + indicesToDelete.add(useIndex) + } + }) + + const sortedIndicesToDelete = [...indicesToDelete].sort((a, b) => b - a) + if (sortedIndicesToDelete.length > 0) { + for (const index of sortedIndicesToDelete) { + updatedMessages.splice(index, 1) + } + messagesHook.setMessages(updatedMessages) + toast.success(t('Generation stopped')) + } else { + toast.success(t('Generation stopped')) + } + } + + uiState.setLoading(false) + uiState.setExecutingTool(null) + }, [requestControl, messagesHook, uiState, t]) + + const handleSubmit = useCallback( + async (userInput: string, attachedImages?: AttachedImage[]) => { + if (!userInput && (!attachedImages || attachedImages.length === 0)) { + return toast.error('Please enter a message or attach images') + } + + if (!modelId) { + return toast.error('Please select a model') + } + + try { + uiState.setLoading(true) + const currentMessages = [...messagesHook.messages] + + const imageContents: any = + attachedImages?.map((image) => ({ + image: { + format: image.file.type.split('/')[1] as ImageFormat, + source: { + bytes: image.base64 + } + } + })) ?? [] + + const textContent = guardrailSettings.enabled + ? { + guardContent: { + text: { + text: userInput + } + } + } + : { + text: userInput + } + + const content = imageContents.length > 0 ? [...imageContents, textContent] : [textContent] + + const userMessage = await messagesHook.addUserMessage( + content, + sessionManager.currentSessionId + ) + currentMessages.push(userMessage) + + const abortController = requestControl.createNewAbortController() + + await streamChat.streamChat( + { + messages: currentMessages, + modelId, + system: systemPrompt ? [{ text: systemPrompt }] : undefined, + toolConfig: enabledTools.length ? { tools: enabledTools } : undefined + }, + currentMessages, + abortController.signal + ) + + const lastMessage = currentMessages[currentMessages.length - 1] + if (lastMessage.content?.find((v) => v.toolUse)) { + if (lastMessage.content) { + await toolExecution.executeToolsRecursively( + lastMessage.content, + currentMessages, + async (msg) => { + if (sessionManager.currentSessionId) { + return await messagesHook.persistMessage(msg, sessionManager.currentSessionId) + } + return msg + }, + streamChat.streamChat, + modelId, + systemPrompt, + enabledTools + ) + } + } + + // タイトル生成のチェック + await sessionManager.generateTitleForCurrentSession(messagesHook.messages) + + // 通知の表示 + if (notification) { + const lastAssistantMessage = currentMessages + .filter((msg) => msg.role === 'assistant') + .pop() + let notificationBody = '' + + if (lastAssistantMessage?.content) { + const textContent = lastAssistantMessage.content + .filter((content) => 'text' in content) + .map((content) => (content as { text: string }).text) + .join(' ') + + notificationBody = textContent + .split(/[.。]/) + .filter((sentence) => sentence.trim().length > 0) + .slice(0, 2) + .join('. ') + .trim() + + if (notificationBody.length > 100) { + notificationBody = notificationBody.substring(0, 100) + '...' + } + } + + if (!notificationBody) { + notificationBody = t('notification.messages.chatComplete.body') + } + + await notificationService.showNotification( + t('notification.messages.chatComplete.title'), + { + body: notificationBody, + silent: false + } + ) + } + } catch (error: any) { + console.error('Error in handleSubmit:', error) + toast.error(error.message || 'An error occurred') + } finally { + uiState.setLoading(false) + uiState.setExecutingTool(null) + } + + return // Add explicit return for TypeScript + }, + [ + modelId, + systemPrompt, + enabledTools, + guardrailSettings, + messagesHook, + sessionManager, + requestControl, + streamChat, + toolExecution, + uiState, + notification, + t + ] + ) + + const clearChat = useCallback(async () => { + requestControl.abortCurrentRequest() + await sessionManager.clearSession() + messagesHook.clearMessages() + }, [requestControl, sessionManager, messagesHook]) + + return { + messages: messagesHook.messages, + loading: uiState.loading, + reasoning: uiState.reasoning, + executingTool: uiState.executingTool, + latestReasoningText: uiState.latestReasoningText, + handleSubmit, + setMessages: messagesHook.setMessages, + currentSessionId: sessionManager.currentSessionId, + setCurrentSessionId: sessionManager.setCurrentSessionId, + clearChat, + stopGeneration + } +} diff --git a/src/renderer/src/pages/ChatPage/hooks/refactored/useChatUIState.ts b/src/renderer/src/pages/ChatPage/hooks/refactored/useChatUIState.ts new file mode 100644 index 00000000..a6cc256f --- /dev/null +++ b/src/renderer/src/pages/ChatPage/hooks/refactored/useChatUIState.ts @@ -0,0 +1,41 @@ +import { useState, useCallback } from 'react' +import { ToolName } from '@/types/tools' + +export interface UseChatUIStateReturn { + loading: boolean + reasoning: boolean + executingTool: ToolName | null + latestReasoningText: string + setLoading: (value: boolean) => void + setReasoning: (value: boolean) => void + setExecutingTool: (tool: ToolName | null) => void + setLatestReasoningText: (text: string) => void + resetUIState: () => void +} + +export const useChatUIState = (): UseChatUIStateReturn => { + const [loading, setLoading] = useState(false) + const [reasoning, setReasoning] = useState(false) + const [executingTool, setExecutingTool] = useState(null) + const [latestReasoningText, setLatestReasoningText] = useState('') + + // UI状態をリセットする関数 + const resetUIState = useCallback(() => { + setLoading(false) + setReasoning(false) + setExecutingTool(null) + setLatestReasoningText('') + }, []) + + return { + loading, + reasoning, + executingTool, + latestReasoningText, + setLoading, + setReasoning, + setExecutingTool, + setLatestReasoningText, + resetUIState + } +} diff --git a/src/renderer/src/pages/ChatPage/hooks/refactored/useMessages.ts b/src/renderer/src/pages/ChatPage/hooks/refactored/useMessages.ts new file mode 100644 index 00000000..5db26a02 --- /dev/null +++ b/src/renderer/src/pages/ChatPage/hooks/refactored/useMessages.ts @@ -0,0 +1,107 @@ +import { useState, useCallback } from 'react' +import { generateMessageId } from '@/types/chat/metadata' +import { IdentifiableMessage } from '@/types/chat/message' +import { ChatMessage } from '@/types/chat/history' +import { useChatHistory } from '@renderer/contexts/ChatHistoryContext' + +export interface UseMessagesProps { + enableHistory?: boolean + modelId: string + enabledTools: any[] +} + +export interface UseMessagesReturn { + messages: IdentifiableMessage[] + setMessages: React.Dispatch> + persistMessage: (message: IdentifiableMessage, sessionId?: string) => Promise + addUserMessage: (content: any[], sessionId?: string) => Promise + addAssistantMessage: (content: any[], id?: string) => IdentifiableMessage + clearMessages: () => void +} + +export const useMessages = ({ + enableHistory = true, + modelId, + enabledTools +}: UseMessagesProps): UseMessagesReturn => { + const [messages, setMessages] = useState([]) + const { addMessage } = useChatHistory() + + // メッセージの永続化を行うラッパー関数 + const persistMessage = useCallback( + async (message: IdentifiableMessage, sessionId?: string) => { + if (!enableHistory) return message + if (!sessionId) return message + + if (message.role && message.content) { + // メッセージにIDがなければ生成する + if (!message.id) { + message.id = generateMessageId() + } + + const chatMessage: ChatMessage = { + id: message.id, + role: message.role, + content: message.content, + timestamp: Date.now(), + metadata: { + modelId, + tools: enabledTools, + converseMetadata: message.metadata?.converseMetadata + } + } + await addMessage(sessionId, chatMessage) + } + + return message + }, + [modelId, enabledTools, enableHistory, addMessage] + ) + + // ユーザーメッセージを追加 + const addUserMessage = useCallback( + async (content: any[], sessionId?: string): Promise => { + const userMessage: IdentifiableMessage = { + role: 'user', + content, + id: generateMessageId() + } + + setMessages((prev) => [...prev, userMessage]) + + if (sessionId) { + await persistMessage(userMessage, sessionId) + } + + return userMessage + }, + [persistMessage] + ) + + // アシスタントメッセージを追加 + const addAssistantMessage = useCallback((content: any[], id?: string): IdentifiableMessage => { + const messageId = id || generateMessageId() + const assistantMessage: IdentifiableMessage = { + role: 'assistant', + content, + id: messageId + } + + setMessages((prev) => [...prev, assistantMessage]) + return assistantMessage + }, []) + + // メッセージをクリア + const clearMessages = useCallback(() => { + setMessages([]) + }, []) + + return { + messages, + setMessages, + persistMessage, + addUserMessage, + addAssistantMessage, + clearMessages + } +} diff --git a/src/renderer/src/pages/ChatPage/hooks/refactored/useRequestControl.ts b/src/renderer/src/pages/ChatPage/hooks/refactored/useRequestControl.ts new file mode 100644 index 00000000..6abc9aaa --- /dev/null +++ b/src/renderer/src/pages/ChatPage/hooks/refactored/useRequestControl.ts @@ -0,0 +1,54 @@ +import { useRef, useCallback, useEffect } from 'react' + +export interface UseRequestControlReturn { + abortCurrentRequest: () => void + createNewAbortController: () => AbortController + getCurrentSignal: () => AbortSignal | null + isRequestActive: () => boolean +} + +export const useRequestControl = (): UseRequestControlReturn => { + const abortController = useRef(null) + + // 現在の通信を中断する関数 + const abortCurrentRequest = useCallback(() => { + if (abortController.current) { + abortController.current.abort() + abortController.current = null + } + }, []) + + // 新しい AbortController を作成する関数 + const createNewAbortController = useCallback(() => { + // 既存のコントローラーがあれば中断 + abortCurrentRequest() + + // 新しい AbortController を作成 + abortController.current = new AbortController() + return abortController.current + }, [abortCurrentRequest]) + + // 現在のシグナルを取得 + const getCurrentSignal = useCallback((): AbortSignal | null => { + return abortController.current?.signal || null + }, []) + + // リクエストがアクティブかどうかを判定 + const isRequestActive = useCallback((): boolean => { + return abortController.current !== null && !abortController.current.signal.aborted + }, []) + + // コンポーネントのアンマウント時にアクティブな通信を中断 + useEffect(() => { + return () => { + abortCurrentRequest() + } + }, [abortCurrentRequest]) + + return { + abortCurrentRequest, + createNewAbortController, + getCurrentSignal, + isRequestActive + } +} diff --git a/src/renderer/src/pages/ChatPage/hooks/refactored/useSessionManager.ts b/src/renderer/src/pages/ChatPage/hooks/refactored/useSessionManager.ts new file mode 100644 index 00000000..434b663d --- /dev/null +++ b/src/renderer/src/pages/ChatPage/hooks/refactored/useSessionManager.ts @@ -0,0 +1,133 @@ +import { useState, useEffect, useCallback, useRef } from 'react' +import { useChatHistory } from '@renderer/contexts/ChatHistoryContext' +import { generateSessionTitle } from '../../utils/titleGenerator' +import { useLightProcessingModel } from '@renderer/lib/modelSelection' +import { useTranslation } from 'react-i18next' +import { IdentifiableMessage } from '@/types/chat/message' + +export interface UseSessionManagerProps { + modelId: string + systemPrompt?: string + sessionId?: string + enableHistory?: boolean +} + +export interface UseSessionManagerReturn { + currentSessionId: string | undefined + setCurrentSessionId: (sessionId: string) => void + clearSession: () => Promise + initializeSession: () => Promise + generateTitleForCurrentSession: (messages: IdentifiableMessage[]) => Promise +} + +export const useSessionManager = ({ + modelId, + systemPrompt, + sessionId, + enableHistory = true +}: UseSessionManagerProps): UseSessionManagerReturn => { + const [currentSessionId, setCurrentSessionIdState] = useState(sessionId) + + // タイトル生成済みフラグ(同じセッションで複数回生成しないため) + const titleGenerated = useRef>(new Set()) + const MESSAGE_THRESHOLD = 4 // タイトル生成のためのメッセージ数閾値 + + const { t } = useTranslation() + const { getLightModelId } = useLightProcessingModel() + + const { getSession, createSession, updateSessionTitle, setActiveSession } = useChatHistory() + + // セッションの初期化 + const initializeSession = useCallback(async () => { + if (sessionId) { + const session = getSession(sessionId) + if (session) { + setCurrentSessionIdState(sessionId) + } + } else if (enableHistory) { + // 履歴保存が有効な場合のみ新しいセッションを作成 + const newSessionId = await createSession('defaultAgent', modelId, systemPrompt) + setCurrentSessionIdState(newSessionId) + } + }, [sessionId, enableHistory, getSession, createSession, modelId, systemPrompt]) + + // セッションを切り替える + const setCurrentSessionId = useCallback( + (newSessionId: string) => { + setCurrentSessionIdState(newSessionId) + if (newSessionId) { + setActiveSession(newSessionId) + } + }, + [setActiveSession] + ) + + // セッションをクリア(新しいセッションを作成) + const clearSession = useCallback(async () => { + if (!enableHistory) { + setCurrentSessionIdState(undefined) + return + } + + const newSessionId = await createSession('defaultAgent', modelId, systemPrompt) + setCurrentSessionIdState(newSessionId) + }, [enableHistory, createSession, modelId, systemPrompt]) + + // 現在のセッションにタイトルを生成する関数 + const generateTitleForCurrentSession = useCallback(async () => { + if (!currentSessionId || !enableHistory) return + + // このセッションIDをタイトル生成済みとしてマーク + titleGenerated.current.add(currentSessionId) + + try { + // セッションの詳細を取得 + const session = getSession(currentSessionId) + if (!session) return + + // セッションのタイトルが既にカスタマイズされている場合は生成しない + // "Chat "で始まるデフォルトタイトルのみ置き換える + if (!session.title.startsWith('Chat ')) return + + // 軽量処理用モデルIDを取得 + const lightModelId = getLightModelId() + + // 軽量モデルでタイトルを生成 + const newTitle = await generateSessionTitle(session, lightModelId, t) + if (newTitle) { + await updateSessionTitle(currentSessionId, newTitle) + } + } catch (error) { + console.error('Error generating title for current session:', error) + } + }, [currentSessionId, enableHistory, getSession, getLightModelId, t, updateSessionTitle]) + + // メッセージ数を監視してタイトル生成をトリガーする関数 + const checkAndGenerateTitle = useCallback( + async (messages: IdentifiableMessage[]) => { + // メッセージが閾値を超え、まだタイトルが生成されていない場合に実行 + if ( + messages.length > MESSAGE_THRESHOLD && + currentSessionId && + !titleGenerated.current.has(currentSessionId) && + enableHistory + ) { + await generateTitleForCurrentSession() + } + }, + [currentSessionId, generateTitleForCurrentSession, enableHistory] + ) + + // 初期化 + useEffect(() => { + initializeSession() + }, [initializeSession]) + + return { + currentSessionId, + setCurrentSessionId, + clearSession, + initializeSession, + generateTitleForCurrentSession: checkAndGenerateTitle + } +} diff --git a/src/renderer/src/pages/ChatPage/hooks/refactored/useStreamChat.ts b/src/renderer/src/pages/ChatPage/hooks/refactored/useStreamChat.ts new file mode 100644 index 00000000..5e0360ec --- /dev/null +++ b/src/renderer/src/pages/ChatPage/hooks/refactored/useStreamChat.ts @@ -0,0 +1,458 @@ +import { useRef, useCallback } from 'react' +import { + ConversationRole, + ContentBlock, + Message, + ToolUseBlockStart +} from '@aws-sdk/client-bedrock-runtime' +import { generateMessageId } from '@/types/chat/metadata' +import { IdentifiableMessage } from '@/types/chat/message' +import { StreamChatCompletionProps, streamChatCompletion } from '@renderer/lib/api' +import { limitContextLength } from '@renderer/lib/contextLength' +import { + addCachePointsToMessages, + addCachePointToSystem, + addCachePointToTools, + logCacheUsage +} from '@renderer/lib/promptCacheUtils' +import { calculateCost } from '@renderer/lib/pricing/modelPricing' +import toast from 'react-hot-toast' +import { useTranslation } from 'react-i18next' + +// メッセージの送信時に、Trace を全て載せると InputToken が逼迫するので取り除く +function removeTraces(messages: any[]) { + return messages.map((message) => { + if (message.content && Array.isArray(message.content)) { + return { + ...message, + content: message.content.map((item) => { + if (item.toolResult) { + return { + ...item, + toolResult: { + ...item.toolResult, + content: item.toolResult.content.map((c) => { + if (c?.json?.result?.completion) { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const { traces, ...restCompletion } = c.json.result.completion + return { + ...c, + json: { + ...c.json, + result: { + ...c.json.result, + completion: restCompletion + } + } + } + } + return c + }) + } + } + } + return item + }) + } + } + return message + }) +} + +export interface UseStreamChatProps { + contextLength: number + enablePromptCache: boolean + setMessages: React.Dispatch> + setReasoning: (value: boolean) => void + setLatestReasoningText: (text: string) => void + persistMessage: (message: IdentifiableMessage, sessionId?: string) => Promise +} + +export interface UseStreamChatReturn { + streamChat: ( + props: StreamChatCompletionProps, + currentMessages: Message[], + abortSignal: AbortSignal + ) => Promise +} + +export const useStreamChat = ({ + contextLength, + enablePromptCache, + setMessages, + setReasoning, + setLatestReasoningText, + persistMessage +}: UseStreamChatProps): UseStreamChatReturn => { + const { t } = useTranslation() + + // キャッシュポイントを保持するための参照 + const lastCachePoint = useRef(undefined) + const lastAssistantMessageId = useRef(null) + + const streamChat = useCallback( + async ( + props: StreamChatCompletionProps, + currentMessages: Message[], + abortSignal: AbortSignal + ) => { + // Context長に基づいてメッセージを制限 + const limitedMessages = removeTraces(limitContextLength(currentMessages, contextLength)) + + // キャッシュポイントを追加(前回のキャッシュポイントを引き継ぐ) + props.messages = enablePromptCache + ? addCachePointsToMessages(limitedMessages, props.modelId, lastCachePoint.current) + : limitedMessages + + // キャッシュポイントが更新された場合、次回の会話ためにキャッシュポイントのインデックスを更新 + if (props.messages[props.messages.length - 1].content?.some((b) => b.cachePoint?.type)) { + // 次回の会話のために現在のキャッシュポイントを更新 + // 現在のメッセージ配列の最後のインデックスを次回の最初のキャッシュポイントとして設定 + lastCachePoint.current = props.messages.length - 1 + } + + // システムプロンプトとツール設定にもキャッシュポイントを追加 + if (props.system && enablePromptCache) { + props.system = addCachePointToSystem(props.system, props.modelId) + } + + if (props.toolConfig && enablePromptCache) { + props.toolConfig = addCachePointToTools(props.toolConfig, props.modelId) + } + + const generator = streamChatCompletion(props, abortSignal) + + let s = '' + let reasoningContentText = '' + let reasoningContentSignature = '' + let redactedContent + let input = '' + let role: ConversationRole = 'assistant' + let toolUse: ToolUseBlockStart | undefined = undefined + let stopReason + const content: ContentBlock[] = [] + + let messageStart = false + try { + for await (const json of generator) { + if (json.messageStart) { + role = json.messageStart.role ?? 'assistant' + messageStart = true + } else if (json.messageStop) { + if (!messageStart) { + console.warn('messageStop without messageStart') + console.log(currentMessages) + await streamChat(props, currentMessages, abortSignal) + return + } + // 新しいメッセージIDを生成 + const messageId = generateMessageId() + const newMessage: IdentifiableMessage = { role, content, id: messageId } + + // アシスタントメッセージの場合、最後のメッセージIDを保持 + if (role === 'assistant') { + lastAssistantMessageId.current = messageId + } + + // UI表示のために即時メッセージを追加 + setMessages([...currentMessages, newMessage]) + currentMessages.push(newMessage) + + stopReason = json.messageStop.stopReason + } else if (json.contentBlockStart) { + toolUse = json.contentBlockStart.start?.toolUse + } else if (json.contentBlockStop) { + if (toolUse) { + let parseInput: string + try { + parseInput = JSON.parse(input) + } catch (e) { + parseInput = input + } + + content.push({ + toolUse: { name: toolUse?.name, toolUseId: toolUse?.toolUseId, input: parseInput } + }) + } else { + if (s.length > 0) { + const getReasoningBlock = () => { + if (reasoningContentText.length > 0) { + return { + reasoningContent: { + reasoningText: { + text: reasoningContentText, + signature: reasoningContentSignature + } + } + } + } else if (reasoningContentSignature.length > 0) { + return { + reasoningContent: { + redactedContent: redactedContent + } + } + } else { + return null + } + } + + const reasoningBlock = getReasoningBlock() + const contentBlocks = reasoningBlock ? [reasoningBlock, { text: s }] : [{ text: s }] + content.push(...contentBlocks) + } + } + input = '' + setReasoning(false) + } else if (json.contentBlockDelta) { + const text = json.contentBlockDelta.delta?.text + if (text) { + s = s + text + + const getContentBlocks = () => { + if (redactedContent) { + return [ + { + reasoningContent: { + redactedContent: redactedContent + } + }, + { text: s } + ] + } else if (reasoningContentText.length > 0) { + return [ + { + reasoningContent: { + reasoningText: { + text: reasoningContentText, + signature: reasoningContentSignature + } + } + }, + { text: s } + ] + } else { + return [{ text: s }] + } + } + + const contentBlocks = getContentBlocks() + setMessages([...currentMessages, { role, content: contentBlocks }]) + } + + const reasoningContent = json.contentBlockDelta.delta?.reasoningContent + if (reasoningContent) { + setReasoning(true) + if (reasoningContent?.text || reasoningContent?.signature) { + reasoningContentText = reasoningContentText + (reasoningContent?.text || '') + reasoningContentSignature = reasoningContent?.signature || '' + + // 最新のreasoningTextを状態として保持 + if (reasoningContent?.text) { + setLatestReasoningText(reasoningContentText) + } + + setMessages([ + ...currentMessages, + { + role: 'assistant', + content: [ + { + reasoningContent: { + reasoningText: { + text: reasoningContentText, + signature: reasoningContentSignature + } + } + }, + { text: s } + ] + } + ]) + } else if (reasoningContent.redactedContent) { + redactedContent = reasoningContent.redactedContent + setMessages([ + ...currentMessages, + { + role: 'assistant', + content: [ + { + reasoningContent: { + redactedContent: reasoningContent.redactedContent + } + }, + { text: s } + ] + } + ]) + } + } + + if (toolUse) { + input = input + json.contentBlockDelta.delta?.toolUse?.input + + const getContentBlocks = () => { + if (redactedContent) { + return [ + { + reasoningContent: { + redactedContent: redactedContent + } + }, + { text: s }, + { + toolUse: { name: toolUse?.name, toolUseId: toolUse?.toolUseId, input: input } + } + ] + } else if (reasoningContentText.length > 0) { + return [ + { + reasoningContent: { + reasoningText: { + text: reasoningContentText, + signature: reasoningContentSignature + } + } + }, + { text: s }, + { + toolUse: { name: toolUse?.name, toolUseId: toolUse?.toolUseId, input: input } + } + ] + } else { + return [ + { text: s }, + { + toolUse: { name: toolUse?.name, toolUseId: toolUse?.toolUseId, input: input } + } + ] + } + } + + setMessages([ + ...currentMessages, + { + role, + content: getContentBlocks() + } + ]) + } + } else if (json.metadata) { + // Metadataを処理 + const metadata: IdentifiableMessage['metadata'] = { + converseMetadata: {}, + sessionCost: undefined + } + metadata.converseMetadata = json.metadata + + let sessionCost: number + // モデルIDがある場合、コストを計算 + if ( + props.modelId && + metadata.converseMetadata.usage && + metadata.converseMetadata.usage.inputTokens && + metadata.converseMetadata.usage.outputTokens + ) { + try { + sessionCost = calculateCost( + props.modelId, + metadata.converseMetadata.usage.inputTokens, + metadata.converseMetadata.usage.outputTokens, + metadata.converseMetadata.usage.cacheReadInputTokens, + metadata.converseMetadata.usage.cacheWriteInputTokens + ) + metadata.sessionCost = sessionCost + } catch (error) { + console.error('Error calculating cost:', error) + } + } + + // Prompt Cacheの使用状況をログ出力 + logCacheUsage(metadata.converseMetadata, props.modelId) + + // 直近のアシスタントメッセージにメタデータを関連付ける + if (lastAssistantMessageId.current) { + // メッセージ配列からIDが一致するメッセージを見つけてメタデータを追加 + setMessages((prevMessages) => { + return prevMessages.map((msg) => { + if (msg.id === lastAssistantMessageId.current) { + return { + ...msg, + metadata: { + ...msg.metadata, + converseMetadata: metadata.converseMetadata, + sessionCost: metadata.sessionCost + } + } + } + return msg + }) + }) + + // currentMessagesの最後(直近のメッセージ)を永続化する + const lastMessageIndex = currentMessages.length - 1 + const lastMessage = currentMessages[lastMessageIndex] + + if ( + lastMessage && + 'id' in lastMessage && + lastMessage.id === lastAssistantMessageId.current + ) { + // 型を明確にしてメタデータを追加 + const updatedMessage: IdentifiableMessage = { + ...(lastMessage as IdentifiableMessage), + metadata: { + ...(lastMessage as any).metadata, + converseMetadata: metadata.converseMetadata, + sessionCost: metadata.sessionCost + } + } + + // 配列の最後のメッセージを更新 + currentMessages[lastMessageIndex] = updatedMessage + + // メタデータを受信した時点で永続化を行う + await persistMessage(updatedMessage) + } + } + } else { + console.error('unexpected json:', json) + } + } + + return stopReason + } catch (error: any) { + if (error.name === 'AbortError') { + console.log('Chat stream aborted') + return + } + console.error({ streamChatRequestError: error }) + toast.error(t('request error')) + const messageId = generateMessageId() + const errorMessage: IdentifiableMessage = { + role: 'assistant' as const, + content: [{ text: error.message }], + id: messageId + } + + // エラーメッセージIDを記録 + lastAssistantMessageId.current = messageId + setMessages([...currentMessages, errorMessage]) + await persistMessage(errorMessage) + throw error + } + }, + [ + contextLength, + enablePromptCache, + setMessages, + setReasoning, + setLatestReasoningText, + persistMessage, + t + ] + ) + + return { + streamChat + } +} diff --git a/src/renderer/src/pages/ChatPage/hooks/refactored/useToolExecution.ts b/src/renderer/src/pages/ChatPage/hooks/refactored/useToolExecution.ts new file mode 100644 index 00000000..f93317bc --- /dev/null +++ b/src/renderer/src/pages/ChatPage/hooks/refactored/useToolExecution.ts @@ -0,0 +1,210 @@ +import { useCallback } from 'react' +import { ContentBlock } from '@aws-sdk/client-bedrock-runtime' +import { generateMessageId } from '@/types/chat/metadata' +import { IdentifiableMessage } from '@/types/chat/message' +import { ToolName } from '@/types/tools' +import toast from 'react-hot-toast' +import { useTranslation } from 'react-i18next' + +export interface UseToolExecutionProps { + guardrailSettings: { + enabled: boolean + guardrailIdentifier?: string + guardrailVersion?: string + } + setExecutingTool: (tool: ToolName | null) => void +} + +export interface UseToolExecutionReturn { + executeToolsRecursively: ( + contentBlocks: ContentBlock[], + currentMessages: any[], + persistMessage: (message: IdentifiableMessage, sessionId?: string) => Promise, + streamChat: (props: any, messages: any[], abortSignal: AbortSignal) => Promise, + modelId: string, + systemPrompt?: string, + enabledTools?: any[] + ) => Promise +} + +export const useToolExecution = ({ + guardrailSettings, + setExecutingTool +}: UseToolExecutionProps): UseToolExecutionReturn => { + const { t } = useTranslation() + + const executeToolsRecursively = useCallback( + async ( + contentBlocks: ContentBlock[], + currentMessages: any[], + persistMessage: (message: IdentifiableMessage, sessionId?: string) => Promise, + streamChat: (props: any, messages: any[], abortSignal: AbortSignal) => Promise, + modelId: string, + systemPrompt?: string, + enabledTools: any[] = [] + ) => { + const contentBlock = contentBlocks.find((block) => block.toolUse) + if (!contentBlock) { + return + } + + const toolResults: ContentBlock[] = [] + for (const contentBlock of contentBlocks) { + if (Object.keys(contentBlock).includes('toolUse')) { + const toolUse = contentBlock.toolUse + if (toolUse?.name) { + try { + const toolInput = { + type: toolUse.name, + ...(toolUse.input as any) + } + setExecutingTool(toolInput.type) + const toolResult = await window.api.bedrock.executeTool(toolInput) + setExecutingTool(null) + + // ツール実行結果用のContentBlockを作成 + let resultContentBlock: ContentBlock + if (Object.prototype.hasOwnProperty.call(toolResult, 'name')) { + resultContentBlock = { + toolResult: { + toolUseId: toolUse.toolUseId, + content: [{ json: toolResult as any }], + status: 'success' + } + } + } else { + resultContentBlock = { + toolResult: { + toolUseId: toolUse.toolUseId, + content: [{ text: toolResult as any }], + status: 'success' + } + } + } + + // GuardrailがActive状態であればチェック実行 + if ( + guardrailSettings.enabled && + guardrailSettings.guardrailIdentifier && + guardrailSettings.guardrailVersion + ) { + try { + console.log('Applying guardrail to tool result') + // ツール結果をガードレールで検証 + const toolResultText = + typeof toolResult === 'string' ? toolResult : JSON.stringify(toolResult) + + console.log({ toolResultText }) + // ツール結果をGuardrailで評価 + const guardrailResult = await window.api.bedrock.applyGuardrail({ + guardrailIdentifier: guardrailSettings.guardrailIdentifier, + guardrailVersion: guardrailSettings.guardrailVersion, + source: 'OUTPUT', // ツールからの出力をチェック + content: [ + { + text: { + text: toolResultText + } + } + ] + }) + console.log({ guardrailResult }) + + // ガードレールが介入した場合は代わりにエラーメッセージを使用 + if (guardrailResult.action === 'GUARDRAIL_INTERVENED') { + console.warn('Guardrail intervened for tool result', guardrailResult) + let errorMessage = t('guardrail.toolResult.blocked') + + // もしガードレールが出力を提供していれば、それを使用 + if (guardrailResult.outputs && guardrailResult.outputs.length > 0) { + const output = guardrailResult.outputs[0] + if (output.text) { + errorMessage = output.text + } + } + + // エラーステータスのツール結果を作成 + resultContentBlock = { + toolResult: { + toolUseId: toolUse.toolUseId, + content: [{ text: errorMessage }], + status: 'error' + } + } + + toast(t('guardrail.intervention'), { + icon: '⚠️', + style: { + backgroundColor: '#FEF3C7', // Light yellow background + color: '#92400E', // Amber text color + border: '1px solid #F59E0B' // Amber border + } + }) + } + } catch (guardrailError) { + console.error('Error applying guardrail to tool result:', guardrailError) + // ガードレールエラー時は元のツール結果を使用し続ける + } + } + + // 最終的なツール結果をコレクションに追加 + toolResults.push(resultContentBlock) + } catch (e: any) { + console.error(e) + toolResults.push({ + toolResult: { + toolUseId: toolUse.toolUseId, + content: [{ text: e.toString() }], + status: 'error' + } + }) + } + } + } + } + + const toolResultMessage: IdentifiableMessage = { + role: 'user', + content: toolResults, + id: generateMessageId() + } + currentMessages.push(toolResultMessage) + await persistMessage(toolResultMessage) + + // Create a new AbortController for this specific call + const abortController = new AbortController() + + const stopReason = await streamChat( + { + messages: currentMessages, + modelId, + system: systemPrompt ? [{ text: systemPrompt }] : undefined, + toolConfig: enabledTools.length ? { tools: enabledTools } : undefined + }, + currentMessages, + abortController.signal + ) + + if (stopReason === 'tool_use') { + const lastMessage = currentMessages[currentMessages.length - 1].content + if (lastMessage) { + await executeToolsRecursively( + lastMessage, + currentMessages, + persistMessage, + streamChat, + modelId, + systemPrompt, + enabledTools + ) + return + } + } + }, + [guardrailSettings, setExecutingTool, t] + ) + + return { + executeToolsRecursively + } +}