Skip to content

Commit eacc6a7

Browse files
committed
feat(chat): add support for client tool execution after approval
- refactor(stream): improve tool call part updates and message handling
- test(chat): add tests for approval flow and full message sending

Resolves #225 by ensuring client-side tools execute immediately after approval and by fixing a state-overwrite bug in the stream processor.
1 parent 8e93ce2 commit eacc6a7

File tree

9 files changed

+419
-80
lines changed

9 files changed

+419
-80
lines changed

packages/typescript/ai-client/src/chat-client.ts

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -293,8 +293,7 @@ export class ChatClient {
293293
this.abortController = new AbortController()
294294

295295
try {
296-
// Get model messages for the LLM
297-
const modelMessages = this.processor.toModelMessages()
296+
const messages = this.processor.getMessages()
298297

299298
// Call onResponse callback
300299
await this.callbacksRef.current.onResponse()
@@ -307,7 +306,7 @@ export class ChatClient {
307306

308307
// Connect and stream
309308
const stream = this.connection.connect(
310-
modelMessages,
309+
messages,
311310
bodyWithConversationId,
312311
this.abortController.signal,
313312
)
@@ -417,6 +416,8 @@ export class ChatClient {
417416
// Find the tool call ID from the approval ID
418417
const messages = this.processor.getMessages()
419418
let foundToolCallId: string | undefined
419+
let foundToolName: string | undefined
420+
let foundToolInput: any | undefined
420421

421422
for (const msg of messages) {
422423
const toolCallPart = msg.parts.find(
@@ -425,6 +426,12 @@ export class ChatClient {
425426
)
426427
if (toolCallPart) {
427428
foundToolCallId = toolCallPart.id
429+
foundToolName = toolCallPart.name
430+
try {
431+
foundToolInput = JSON.parse(toolCallPart.arguments)
432+
} catch {
433+
// Ignore parse errors
434+
}
428435
break
429436
}
430437
}
@@ -440,6 +447,32 @@ export class ChatClient {
440447
// Add response via processor
441448
this.processor.addToolApprovalResponse(response.id, response.approved)
442449

450+
// Execute client-side tool if approved
451+
if (response.approved && foundToolCallId && foundToolName) {
452+
const clientTool = this.clientToolsRef.current.get(foundToolName)
453+
if (clientTool?.execute) {
454+
try {
455+
const output = await clientTool.execute(foundToolInput)
456+
await this.addToolResult({
457+
toolCallId: foundToolCallId,
458+
tool: foundToolName,
459+
output,
460+
state: 'output-available',
461+
})
462+
return
463+
} catch (error: any) {
464+
await this.addToolResult({
465+
toolCallId: foundToolCallId,
466+
tool: foundToolName,
467+
output: null,
468+
state: 'output-error',
469+
errorText: error.message,
470+
})
471+
return
472+
}
473+
}
474+
}
475+
443476
// If stream is in progress, queue continuation check for after it ends
444477
if (this.isLoading) {
445478
this.queuePostStreamAction(() => this.checkForContinuation())

packages/typescript/ai-client/src/connection-adapters.ts

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,12 @@ export interface FetchConnectionOptions {
8989
signal?: AbortSignal
9090
body?: Record<string, any>
9191
fetchClient?: typeof globalThis.fetch
92+
/**
93+
* Send full UIMessage objects (including `parts`) instead of ModelMessages.
94+
* Required for advanced server features that depend on UIMessage metadata
95+
* (e.g. tool approvals and client tool results tracked in parts).
96+
*/
97+
sendFullMessages?: boolean
9298
}
9399

94100
/**
@@ -138,7 +144,9 @@ export function fetchServerSentEvents(
138144
const resolvedOptions =
139145
typeof options === 'function' ? await options() : options
140146

141-
const modelMessages = convertMessagesToModelMessages(messages)
147+
const requestMessages = resolvedOptions.sendFullMessages
148+
? messages
149+
: convertMessagesToModelMessages(messages)
142150

143151
const requestHeaders: Record<string, string> = {
144152
'Content-Type': 'application/json',
@@ -147,7 +155,7 @@ export function fetchServerSentEvents(
147155

148156
// Merge body from options with messages and data
149157
const requestBody = {
150-
messages: modelMessages,
158+
messages: requestMessages,
151159
data,
152160
...resolvedOptions.body,
153161
}
@@ -238,8 +246,9 @@ export function fetchHttpStream(
238246
const resolvedOptions =
239247
typeof options === 'function' ? await options() : options
240248

241-
// Convert UIMessages to ModelMessages if needed
242-
const modelMessages = convertMessagesToModelMessages(messages)
249+
const requestMessages = resolvedOptions.sendFullMessages
250+
? messages
251+
: convertMessagesToModelMessages(messages)
243252

244253
const requestHeaders: Record<string, string> = {
245254
'Content-Type': 'application/json',
@@ -248,7 +257,7 @@ export function fetchHttpStream(
248257

249258
// Merge body from options with messages and data
250259
const requestBody = {
251-
messages: modelMessages,
260+
messages: requestMessages,
252261
data,
253262
...resolvedOptions.body,
254263
}
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import { describe, expect, it, vi } from 'vitest'
2+
import { ChatClient } from '../src/chat-client'
3+
import { stream } from '../src/connection-adapters'
4+
import type { StreamChunk } from '@tanstack/ai'
5+
6+
function createMockConnectionAdapter(options: { chunks: StreamChunk[] }) {
7+
return stream(async function* () {
8+
for (const chunk of options.chunks) {
9+
yield chunk
10+
}
11+
})
12+
}
13+
14+
function createApprovalToolCallChunks(
15+
toolCalls: Array<{
16+
id: string
17+
name: string
18+
arguments: string
19+
approvalId: string
20+
}>,
21+
): StreamChunk[] {
22+
const chunks: StreamChunk[] = []
23+
const timestamp = Date.now()
24+
25+
// Start assistant message
26+
chunks.push({
27+
type: 'content',
28+
id: 'msg-1',
29+
model: 'test-model',
30+
timestamp,
31+
delta: '',
32+
content: '',
33+
role: 'assistant',
34+
})
35+
36+
for (const toolCall of toolCalls) {
37+
// 1. Tool Call Chunk
38+
chunks.push({
39+
type: 'tool_call',
40+
id: 'msg-1',
41+
model: 'test-model',
42+
timestamp,
43+
toolCall: {
44+
id: toolCall.id,
45+
type: 'function',
46+
function: {
47+
name: toolCall.name,
48+
arguments: toolCall.arguments,
49+
},
50+
},
51+
index: 0,
52+
})
53+
54+
// 2. Approval Requested Chunk
55+
chunks.push({
56+
type: 'approval-requested',
57+
id: 'msg-1',
58+
model: 'test-model',
59+
timestamp,
60+
toolCallId: toolCall.id,
61+
toolName: toolCall.name,
62+
input: JSON.parse(toolCall.arguments),
63+
approval: {
64+
id: toolCall.approvalId,
65+
needsApproval: true,
66+
},
67+
} as any) // Cast to any if types are not perfectly aligned yet, or use correct type
68+
}
69+
70+
// Done chunk
71+
chunks.push({
72+
type: 'done',
73+
id: 'msg-1',
74+
model: 'test-model',
75+
timestamp,
76+
finishReason: 'tool_calls',
77+
})
78+
79+
return chunks
80+
}
81+
82+
describe('ChatClient Approval Flow', () => {
83+
it('should execute client tool when approved', async () => {
84+
const toolName = 'delete_local_data'
85+
const toolCallId = 'call_123'
86+
const approvalId = 'approval_123'
87+
const input = { key: 'test-key' }
88+
89+
const chunks = createApprovalToolCallChunks([
90+
{
91+
id: toolCallId,
92+
name: toolName,
93+
arguments: JSON.stringify(input),
94+
approvalId,
95+
},
96+
])
97+
98+
const adapter = createMockConnectionAdapter({ chunks })
99+
100+
const execute = vi.fn().mockResolvedValue({ deleted: true })
101+
const clientTool = {
102+
name: toolName,
103+
description: 'Delete data',
104+
execute,
105+
}
106+
107+
const client = new ChatClient({
108+
connection: adapter,
109+
tools: [clientTool],
110+
})
111+
112+
// Start the flow
113+
await client.sendMessage('Delete data')
114+
115+
// Wait for stream to finish (approval request should be pending)
116+
await new Promise((resolve) => setTimeout(resolve, 100))
117+
118+
// Verify tool execution hasn't happened yet
119+
expect(execute).not.toHaveBeenCalled()
120+
121+
// Approve the tool
122+
await client.addToolApprovalResponse({
123+
id: approvalId,
124+
approved: true,
125+
})
126+
127+
// Wait for execution (this is where it currently hangs/fails)
128+
await new Promise((resolve) => setTimeout(resolve, 100))
129+
130+
// Expect execute to have been called
131+
expect(execute).toHaveBeenCalledWith(input)
132+
})
133+
})

packages/typescript/ai-client/tests/connection-adapters.test.ts

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,107 @@ describe('connection-adapters', () => {
318318
expect(body.data).toEqual({ key: 'value' })
319319
})
320320

321+
it('should send model messages by default', async () => {
322+
const mockReader = {
323+
read: vi.fn().mockResolvedValue({ done: true, value: undefined }),
324+
releaseLock: vi.fn(),
325+
}
326+
327+
const mockResponse = {
328+
ok: true,
329+
body: {
330+
getReader: () => mockReader,
331+
},
332+
}
333+
334+
fetchMock.mockResolvedValue(mockResponse as any)
335+
336+
const adapter = fetchServerSentEvents('/api/chat')
337+
338+
for await (const _ of adapter.connect([
339+
{
340+
id: 'msg_1',
341+
role: 'assistant',
342+
parts: [
343+
{
344+
type: 'tool-call',
345+
id: 'tool_1',
346+
name: 'testTool',
347+
arguments: '{}',
348+
state: 'approval-responded',
349+
approval: { id: 'approval_tool_1', needsApproval: true, approved: true },
350+
},
351+
],
352+
createdAt: new Date(),
353+
},
354+
] as any)) {
355+
// Consume
356+
}
357+
358+
const call = fetchMock.mock.calls[0]
359+
const body = JSON.parse(call?.[1]?.body as string)
360+
expect(body.messages[0]).not.toHaveProperty('parts')
361+
expect(body.messages[0]).toMatchObject({
362+
role: 'assistant',
363+
})
364+
})
365+
366+
it('should send full UI messages when configured', async () => {
367+
const mockReader = {
368+
read: vi.fn().mockResolvedValue({ done: true, value: undefined }),
369+
releaseLock: vi.fn(),
370+
}
371+
372+
const mockResponse = {
373+
ok: true,
374+
body: {
375+
getReader: () => mockReader,
376+
},
377+
}
378+
379+
fetchMock.mockResolvedValue(mockResponse as any)
380+
381+
const adapter = fetchServerSentEvents('/api/chat', {
382+
sendFullMessages: true,
383+
})
384+
385+
const uiMessages = [
386+
{
387+
id: 'msg_1',
388+
role: 'assistant',
389+
parts: [
390+
{
391+
type: 'tool-call',
392+
id: 'tool_1',
393+
name: 'testTool',
394+
arguments: '{}',
395+
state: 'approval-responded',
396+
approval: { id: 'approval_tool_1', needsApproval: true, approved: true },
397+
},
398+
],
399+
createdAt: new Date(),
400+
},
401+
]
402+
403+
for await (const _ of adapter.connect(uiMessages as any)) {
404+
// Consume
405+
}
406+
407+
const call = fetchMock.mock.calls[0]
408+
const body = JSON.parse(call?.[1]?.body as string)
409+
expect(body.messages[0]).toHaveProperty('parts')
410+
expect(body.messages[0]).toMatchObject({
411+
role: 'assistant',
412+
parts: [
413+
{
414+
type: 'tool-call',
415+
id: 'tool_1',
416+
approval: { id: 'approval_tool_1', approved: true },
417+
},
418+
],
419+
})
420+
})
421+
321422
it('should use custom fetchClient when provided', async () => {
322423
const customFetch = vi.fn()
323424
const mockReader = {

0 commit comments

Comments (0)