Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/add-custom-fetch.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"workers-ai-provider": patch
---

Add optional `fetch` parameter to credentials mode for request interception and testing. Available when using `accountId + apiKey` (not with bindings). Matches the pattern used by `@ai-sdk/openai` and `@ai-sdk/anthropic`.
4 changes: 2 additions & 2 deletions demos/mcp-server-bearer-auth/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,8 @@ export const renderApproveContent = async (
<div class="max-w-md mx-auto bg-white p-8 rounded-lg shadow-md text-center">
<div class="mb-4">
<span class="inline-block p-3 ${status === "success"
? "bg-green-100 text-green-800"
: "bg-red-100 text-red-800"} rounded-full">
? "bg-green-100 text-green-800"
: "bg-red-100 text-red-800"} rounded-full">
${status === "success" ? "✓" : "✗"}
</span>
</div>
Expand Down
6 changes: 5 additions & 1 deletion demos/remote-mcp-cf-access/src/access-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,11 @@ export async function handleAccessRequest(
let codeVerifier: string;

try {
const result = await validateOAuthState(request, env.OAUTH_KV, env.COOKIE_ENCRYPTION_KEY);
const result = await validateOAuthState(
request,
env.OAUTH_KV,
env.COOKIE_ENCRYPTION_KEY,
);
oauthReqInfo = result.oauthReqInfo;
codeVerifier = result.codeVerifier;
} catch (error: any) {
Expand Down
1 change: 0 additions & 1 deletion demos/remote-mcp-cf-access/src/workers-oauth-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,6 @@ async function generatePKCE(): Promise<{ codeVerifier: string; codeChallenge: st
return { codeVerifier, codeChallenge };
}


async function getApprovedClientsFromCookie(
request: Request,
cookieSecret: string,
Expand Down
4 changes: 2 additions & 2 deletions demos/remote-mcp-server-autorag/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,8 @@ export const renderApproveContent = async (
<div class="max-w-md mx-auto bg-white p-8 rounded-lg shadow-md text-center">
<div class="mb-4">
<span class="inline-block p-3 ${status === "success"
? "bg-green-100 text-green-800"
: "bg-red-100 text-red-800"} rounded-full">
? "bg-green-100 text-green-800"
: "bg-red-100 text-red-800"} rounded-full">
${status === "success" ? "✓" : "✗"}
</span>
</div>
Expand Down
4 changes: 2 additions & 2 deletions demos/remote-mcp-server/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,8 @@ export const renderApproveContent = async (
<div class="max-w-md mx-auto bg-white p-8 rounded-lg shadow-md text-center">
<div class="mb-4">
<span class="inline-block p-3 ${status === "success"
? "bg-green-100 text-green-800"
: "bg-red-100 text-red-800"} rounded-full">
? "bg-green-100 text-green-800"
: "bg-red-100 text-red-800"} rounded-full">
${status === "success" ? "✓" : "✗"}
</span>
</div>
Expand Down
9 changes: 8 additions & 1 deletion packages/workers-ai-provider/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@ export type WorkersAISettings = (
* Both binding must be absent if credentials are used directly.
*/
binding?: never;

/**
* Custom fetch implementation. You can use it as a middleware to
* intercept requests, or to provide a custom fetch implementation
* for e.g. testing. Only available in credentials mode.
*/
fetch?: typeof globalThis.fetch;
}
) & {
/**
Expand Down Expand Up @@ -159,7 +166,7 @@ export function createWorkersAI(options: WorkersAISettings): WorkersAI {
} else {
const { accountId, apiKey } = options;
binding = {
run: createRun({ accountId, apiKey }),
run: createRun({ accountId, apiKey, fetch: options.fetch }),
} as Ai;
}

Expand Down
7 changes: 5 additions & 2 deletions packages/workers-ai-provider/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ export interface CreateRunConfig {
accountId: string;
/** Cloudflare API token/key with appropriate permissions. */
apiKey: string;
/** Custom fetch implementation for intercepting requests. */
fetch?: typeof globalThis.fetch;
}

/**
Expand All @@ -68,6 +70,7 @@ export interface CreateRunConfig {
*/
export function createRun(config: CreateRunConfig): AiRun {
const { accountId, apiKey } = config;
const fetchFn = config.fetch ?? globalThis.fetch;

return async function run<Name extends keyof AiModels>(
model: Name,
Expand Down Expand Up @@ -141,7 +144,7 @@ export function createRun(config: CreateRunConfig): AiRun {

const body = JSON.stringify(inputs);

const response = await fetch(url, {
const response = await fetchFn(url, {
body,
headers,
method: "POST",
Expand Down Expand Up @@ -180,7 +183,7 @@ export function createRun(config: CreateRunConfig): AiRun {
// Retry without streaming so doStream's graceful degradation path can
// wrap the complete response as a synthetic stream.
// Use the same URL (gateway or direct) as the original request.
const retryResponse = await fetch(url, {
const retryResponse = await fetchFn(url, {
body: JSON.stringify({
...(inputs as Record<string, unknown>),
stream: false,
Expand Down
49 changes: 49 additions & 0 deletions packages/workers-ai-provider/test/utils.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -639,4 +639,53 @@ describe("createRun", () => {
}),
);
});

it("should use custom fetch when provided", async () => {
const customFetch = vi.fn().mockResolvedValue({
ok: true,
json: vi.fn().mockResolvedValue({ result: { response: "Hello" } }),
headers: new Headers({ "content-type": "application/json" }),
});

const run = createRun({
accountId: "test-account",
apiKey: "test-key",
fetch: customFetch as typeof globalThis.fetch,
});
await run("@cf/meta/llama-3.1-8b-instruct" as any, { prompt: "Hi" });

expect(customFetch).toHaveBeenCalledWith(
"https://api.cloudflare.com/client/v4/accounts/test-account/ai/run/@cf/meta/llama-3.1-8b-instruct",
expect.objectContaining({ method: "POST" }),
);
expect(globalThis.fetch).not.toHaveBeenCalled();
});

it("should use custom fetch for streaming retry fallback", async () => {
const customFetch = vi
.fn()
// First call: streaming request returns JSON instead of SSE (triggers retry)
.mockResolvedValueOnce({
ok: true,
headers: new Headers({ "content-type": "application/json" }),
body: null,
json: vi.fn().mockResolvedValue({ result: { response: "" } }),
})
// Second call: non-streaming retry
.mockResolvedValueOnce({
ok: true,
headers: new Headers({ "content-type": "application/json" }),
json: vi.fn().mockResolvedValue({ result: { response: "Hello" } }),
});

const run = createRun({
accountId: "test-account",
apiKey: "test-key",
fetch: customFetch as typeof globalThis.fetch,
});
await run("@cf/meta/llama-3.1-8b-instruct" as any, { prompt: "Hi", stream: true } as any);

expect(customFetch).toHaveBeenCalledTimes(2);
expect(globalThis.fetch).not.toHaveBeenCalled();
});
});
Loading