diff --git a/js/ai/src/model/middleware.ts b/js/ai/src/model/middleware.ts index 523bb7a07b..f7da057f78 100644 --- a/js/ai/src/model/middleware.ts +++ b/js/ai/src/model/middleware.ts @@ -14,8 +14,11 @@ * limitations under the License. */ +import { GenkitError, StatusName } from '@genkit-ai/core'; +import { HasRegistry } from '@genkit-ai/core/registry'; import { Document } from '../document.js'; import { injectInstructions } from '../formats/index.js'; +import { ModelArgument } from '../index.js'; import type { MediaPart, MessageData, @@ -23,6 +26,8 @@ import type { ModelMiddleware, Part, } from '../model.js'; +import { resolveModel } from '../model.js'; + /** * Preprocess a GenerateRequest to download referenced http(s) media URLs and * inline them as data URIs. @@ -235,6 +240,216 @@ export function augmentWithContext( }; } +/** + * Options for the `retry` middleware. + */ +export interface RetryOptions { + /** + * The maximum number of times to retry a failed request. + * @default 3 + */ + maxRetries?: number; + /** + * An array of `StatusName` values that should trigger a retry. + * @default ['UNAVAILABLE', 'DEADLINE_EXCEEDED', 'RESOURCE_EXHAUSTED', 'ABORTED', 'INTERNAL'] + */ + statuses?: StatusName[]; + /** + * The initial delay between retries, in milliseconds. + * @default 1000 + */ + initialDelayMs?: number; + /** + * The maximum delay between retries, in milliseconds. + * @default 60000 + */ + maxDelayMs?: number; + /** + * The factor by which the delay increases after each retry (exponential backoff). + * @default 2 + */ + backoffFactor?: number; + /** + * Whether to disable jitter on the delay. Jitter adds a random factor to the + * delay to help prevent a "thundering herd" of clients all retrying at the + * same time. + * @default false + */ + noJitter?: boolean; + /** + * A callback to be executed on each retry attempt. + */ + onError?: (error: Error, attempt: number) => void; +} + +let __setTimeout: ( + callback: (...args: any[]) => void, + ms?: number +) => NodeJS.Timeout = setTimeout; + +/** + * FOR TESTING ONLY. + * @internal + */ +export const TEST_ONLY = { + setRetryTimeout( + impl: (callback: (...args: any[]) => void, ms?: number) => NodeJS.Timeout + ) { + __setTimeout = impl; + }, +}; + +const DEFAULT_RETRY_STATUSES: StatusName[] = [ + 'UNAVAILABLE', + 'DEADLINE_EXCEEDED', + 'RESOURCE_EXHAUSTED', + 'ABORTED', + 'INTERNAL', +]; + +const DEFAULT_FALLBACK_STATUSES: StatusName[] = [ + 'UNAVAILABLE', + 'DEADLINE_EXCEEDED', + 'RESOURCE_EXHAUSTED', + 'ABORTED', + 'INTERNAL', + 'NOT_FOUND', + 'UNIMPLEMENTED', +]; + +/** + * Creates a middleware that retries requests on specific error statuses. + * + * ```ts + * const { text } = await ai.generate({ + * model: googleAI.model('gemini-2.5-pro'), + * prompt: 'You are a helpful AI assistant named Walt, say hello', + * use: [ + * retry({ + * maxRetries: 2, + * initialDelayMs: 1000, + * backoffFactor: 2, + * }), + * ], + * }); + * ``` + */ +export function retry(options: RetryOptions = {}): ModelMiddleware { + const { + maxRetries = 3, + statuses = DEFAULT_RETRY_STATUSES, + initialDelayMs = 1000, + maxDelayMs = 60000, + backoffFactor = 2, + noJitter = false, + onError, + } = options; + + return async (req, next) => { + let lastError: any; + let currentDelay = initialDelayMs; + for (let i = 0; i <= maxRetries; i++) { + try { + return await next(req); + } catch (e) { + lastError = e; + const error = e as Error; + if (i < maxRetries) { + let shouldRetry = false; + if (error instanceof GenkitError) { + if (statuses.includes(error.status)) { + shouldRetry = true; + } + } else { + shouldRetry = true; + } + + if (shouldRetry) { + onError?.(error, i + 1); + let delay = currentDelay; + if (!noJitter) { + delay = delay + 1000 * Math.pow(2, i) * Math.random(); + } + await new Promise((resolve) => __setTimeout(resolve, delay)); + currentDelay = Math.min(currentDelay * backoffFactor, maxDelayMs); + continue; + } + } + throw error; + } + } + throw lastError; + }; +} + +/** + * Options for the `fallback` middleware. + */ +export interface FallbackOptions { + /** + * An array of models to try in order. + */ + models: ModelArgument[]; + /** + * An array of `StatusName` values that should trigger a fallback. + * @default ['UNAVAILABLE', 'DEADLINE_EXCEEDED', 'RESOURCE_EXHAUSTED', 'ABORTED', 'INTERNAL', 'NOT_FOUND', 'UNIMPLEMENTED'] + */ + statuses?: StatusName[]; + /** + * A callback to be executed on each fallback attempt. + */ + onError?: (error: Error) => void; +} + +/** + * Creates a middleware that falls back to a different model on specific error statuses. + * + * ```ts + * const { text } = await ai.generate({ + * model: googleAI.model('gemini-2.5-pro'), + * prompt: 'You are a helpful AI assistant named Walt, say hello', + * use: [ + * fallback(ai, { + * models: [googleAI.model('gemini-2.5-flash')], + * statuses: ['RESOURCE_EXHAUSTED'], + * }), + * ], + * }); + * ``` + */ +export function fallback( + ai: HasRegistry, + options: FallbackOptions +): ModelMiddleware { + const { models, statuses = DEFAULT_FALLBACK_STATUSES, onError } = options; + + return async (req, next) => { + try { + return await next(req); + } catch (e) { + if (e instanceof GenkitError && statuses.includes(e.status)) { + onError?.(e); + let lastError: any = e; + for (const model of models) { + try { + const resolved = await resolveModel(ai.registry, model); + return await resolved.modelAction(req); + } catch (e2) { + lastError = e2; + if (e2 instanceof GenkitError && statuses.includes(e2.status)) { + onError?.(e2); + continue; + } + throw e2; + } + } + throw lastError; + } + throw e; + } + }; +} + export interface SimulatedConstrainedGenerationOptions { instructionsRenderer?: (schema: Record) => string; } diff --git a/js/ai/tests/helpers.ts b/js/ai/tests/helpers.ts index 8a0ded4ea9..0b883e3dab 100644 --- a/js/ai/tests/helpers.ts +++ b/js/ai/tests/helpers.ts @@ -38,24 +38,29 @@ export type ProgrammableModel = ModelAction & { ) => Promise; lastRequest?: GenerateRequest; + requestCount: number; }; export function defineProgrammableModel( registry: Registry, - info?: ModelInfo + info?: ModelInfo, + name?: string ): ProgrammableModel { const pm = defineModel( registry, { - ...info, - name: 'programmableModel', + apiVersion: 'v2', + ...(info as any), + name: name ?? 'programmableModel', }, - async (request, streamingCallback) => { + async (request, { sendChunk }) => { + pm.requestCount++; pm.lastRequest = JSON.parse(JSON.stringify(request)); - return pm.handleResponse(request, streamingCallback); + return pm.handleResponse(request, sendChunk); } ) as ProgrammableModel; + pm.requestCount = 0; return pm; } diff --git a/js/ai/tests/model/middleware_test.ts b/js/ai/tests/model/middleware_test.ts index 0d30801c93..65fbb2d4ea 100644 --- a/js/ai/tests/model/middleware_test.ts +++ b/js/ai/tests/model/middleware_test.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -import { z } from '@genkit-ai/core'; +import { GenkitError, z } from '@genkit-ai/core'; import { initNodeFeatures } from '@genkit-ai/core/node'; import { Registry } from '@genkit-ai/core/registry'; import * as assert from 'assert'; @@ -30,7 +30,10 @@ import { } from '../../src/model.js'; import { CONTEXT_PREFACE, + TEST_ONLY, augmentWithContext, + fallback, + retry, simulateConstrainedGeneration, simulateSystemPrompt, validateSupport, @@ -38,6 +41,8 @@ import { } from '../../src/model/middleware.js'; import { defineProgrammableModel } from '../helpers.js'; +const { setRetryTimeout } = TEST_ONLY; + initNodeFeatures(); describe('validateSupport', () => { @@ -140,6 +145,481 @@ describe('validateSupport', () => { }); }); +describe('retry', () => { + let registry: Registry; + + beforeEach(() => { + registry = new Registry(); + configureFormats(registry); + }); + + it('should not retry on success', async () => { + const pm = defineProgrammableModel(registry); + pm.handleResponse = async (req, sc) => { + return { + message: { + role: 'model', + content: [{ text: 'success' }], + }, + }; + }; + + const result = await generate(registry, { + model: 'programmableModel', + prompt: 'test', + use: [retry()], + }); + + assert.strictEqual(pm.requestCount, 1); + assert.strictEqual(result.text, 'success'); + }); + + it('should retry on a retryable GenkitError', async () => { + const pm = defineProgrammableModel(registry); + let requestCount = 0; + pm.handleResponse = async (req, sc) => { + requestCount++; + if (requestCount < 3) { + throw new GenkitError({ status: 'UNAVAILABLE', message: 'test' }); + } + return { + message: { + role: 'model', + content: [{ text: 'success' }], + }, + }; + }; + + setRetryTimeout((callback, ms) => { + callback(); + return 0 as any; + }); + + const result = await generate(registry, { + model: 'programmableModel', + prompt: 'test', + use: [retry({ maxRetries: 3 })], + }); + + assert.strictEqual(requestCount, 3); + assert.strictEqual(result.text, 'success'); + }); + + it('should retry on a non-GenkitError', async () => { + const pm = defineProgrammableModel(registry); + let requestCount = 0; + pm.handleResponse = async (req, sc) => { + requestCount++; + if (requestCount < 2) { + throw new Error('generic error'); + } + return { + message: { + role: 'model', + content: [{ text: 'success' }], + }, + }; + }; + + setRetryTimeout((callback, ms) => { + callback(); + return 0 as any; + }); + + const result = await generate(registry, { + model: 'programmableModel', + prompt: 'test', + use: [retry({ maxRetries: 2 })], + }); + + assert.strictEqual(requestCount, 2); + assert.strictEqual(result.text, 'success'); + }); + + it('should throw after exhausting retries', async () => { + const pm = defineProgrammableModel(registry); + let requestCount = 0; + pm.handleResponse = async (req, sc) => { + requestCount++; + throw new GenkitError({ status: 'UNAVAILABLE', message: 'test' }); + }; + + setRetryTimeout((callback, ms) => { + callback(); + return 0 as any; + }); + + await assert.rejects( + generate(registry, { + model: 'programmableModel', + prompt: 'test', + use: [retry({ maxRetries: 2 })], + }), + /UNAVAILABLE: test/ + ); + + assert.strictEqual(requestCount, 3); + }); + + it('should call onError callback', async () => { + const pm = defineProgrammableModel(registry); + let requestCount = 0; + pm.handleResponse = async (req, sc) => { + requestCount++; + throw new Error('test error'); + }; + + setRetryTimeout((callback, ms) => { + callback(); + return 0 as any; + }); + + let errorCount = 0; + let lastError: Error | undefined; + await assert.rejects( + generate(registry, { + model: 'programmableModel', + prompt: 'test', + use: [ + retry({ + maxRetries: 2, + onError: (err, attempt) => { + errorCount++; + lastError = err; + assert.strictEqual(attempt, errorCount); + }, + }), + ], + }), + /test error/ + ); + + assert.strictEqual(requestCount, 3); + assert.strictEqual(errorCount, 2); + assert.ok(lastError); + assert.strictEqual(lastError!.message, 'test error'); + }); + + it('should not retry on non-retryable status', async () => { + const pm = defineProgrammableModel(registry); + let requestCount = 0; + pm.handleResponse = async (req, sc) => { + requestCount++; + throw new GenkitError({ status: 'INVALID_ARGUMENT', message: 'test' }); + }; + + await assert.rejects( + generate(registry, { + model: 'programmableModel', + prompt: 'test', + use: [retry({ maxRetries: 2 })], + }), + /INVALID_ARGUMENT: test/ + ); + + assert.strictEqual(requestCount, 1); + }); + + it('should respect initial delay', async () => { + const pm = defineProgrammableModel(registry); + let requestCount = 0; + pm.handleResponse = async (req, sc) => { + requestCount++; + if (requestCount < 2) { + throw new Error('generic error'); + } + return { + message: { + role: 'model', + content: [{ text: 'success' }], + }, + }; + }; + + let totalDelay = 0; + setRetryTimeout((callback, ms) => { + totalDelay += ms!; + callback(); + return 0 as any; + }); + + const result = await generate(registry, { + model: 'programmableModel', + prompt: 'test', + use: [retry({ maxRetries: 2, initialDelayMs: 50, noJitter: true })], + }); + + assert.strictEqual(requestCount, 2); + assert.strictEqual(result.text, 'success'); + assert.strictEqual(totalDelay, 50); + }); + + it('should respect backoff factor', async () => { + const pm = defineProgrammableModel(registry); + let requestCount = 0; + pm.handleResponse = async (req, sc) => { + requestCount++; + if (requestCount < 3) { + throw new Error('generic error'); + } + return { + message: { + role: 'model', + content: [{ text: 'success' }], + }, + }; + }; + + let totalDelay = 0; + setRetryTimeout((callback, ms) => { + totalDelay += ms!; + callback(); + return 0 as any; + }); + + const result = await generate(registry, { + model: 'programmableModel', + prompt: 'test', + use: [ + retry({ + maxRetries: 3, + initialDelayMs: 20, + backoffFactor: 2, + noJitter: true, + }), + ], + }); + + assert.strictEqual(requestCount, 3); + assert.strictEqual(result.text, 'success'); + assert.strictEqual(totalDelay, 20 + 40); + }); + + it('should apply jitter', async () => { + const pm = defineProgrammableModel(registry); + let requestCount = 0; + pm.handleResponse = async (req, sc) => { + requestCount++; + if (requestCount < 2) { + throw new Error('generic error'); + } + return { + message: { + role: 'model', + content: [{ text: 'success' }], + }, + }; + }; + + let totalDelay = 0; + setRetryTimeout((callback, ms) => { + totalDelay += ms!; + callback(); + return 0 as any; + }); + + const result = await generate(registry, { + model: 'programmableModel', + prompt: 'test', + use: [ + retry({ + maxRetries: 2, + initialDelayMs: 50, + noJitter: false, // do jitter + }), + ], + }); + + assert.strictEqual(requestCount, 2); + assert.strictEqual(result.text, 'success'); + assert.ok(totalDelay >= 50); + assert.ok(totalDelay <= 1050); + }); + + it('should respect max delay', async () => { + const pm = defineProgrammableModel(registry); + let requestCount = 0; + pm.handleResponse = async (req, sc) => { + requestCount++; + if (requestCount < 3) { + throw new Error('generic error'); + } + return { + message: { + role: 'model', + content: [{ text: 'success' }], + }, + }; + }; + + let totalDelay = 0; + setRetryTimeout((callback, ms) => { + totalDelay += ms!; + callback(); + return 0 as any; + }); + + const result = await generate(registry, { + model: 'programmableModel', + prompt: 'test', + use: [ + retry({ + maxRetries: 3, + initialDelayMs: 20, + backoffFactor: 2, + maxDelayMs: 30, + noJitter: true, + }), + ], + }); + + assert.strictEqual(requestCount, 3); + assert.strictEqual(result.text, 'success'); + assert.strictEqual(totalDelay, 20 + 30); + }); +}); + +describe('fallback', () => { + let registry: Registry; + + beforeEach(() => { + registry = new Registry(); + configureFormats(registry); + }); + + it('should not fallback on success', async () => { + const pm = defineProgrammableModel(registry, {}, 'programmableModel'); + pm.handleResponse = async (req, sc) => { + return { + message: { + role: 'model', + content: [{ text: 'success' }], + }, + }; + }; + + const fallbackPm = defineProgrammableModel(registry, {}, 'fallbackModel'); + + const result = await generate(registry, { + model: 'programmableModel', + prompt: 'test', + use: [fallback({ registry }, { models: ['fallbackModel'] })], + }); + + assert.strictEqual(pm.requestCount, 1); + assert.strictEqual(fallbackPm.requestCount, 0); + }); + + it('should call onError callback', async () => { + const pm = defineProgrammableModel(registry, {}, 'programmableModel'); + pm.handleResponse = async () => { + throw new GenkitError({ status: 'UNAVAILABLE', message: 'test' }); + }; + + const fallbackPm = defineProgrammableModel(registry, {}, 'fallbackModel'); + fallbackPm.handleResponse = async () => { + throw new GenkitError({ status: 'INTERNAL', message: 'fallback fail' }); + }; + + let errorCount = 0; + let lastError: Error | undefined; + await assert.rejects( + generate(registry, { + model: 'programmableModel', + prompt: 'test', + use: [ + fallback({ registry } as any, { + models: ['fallbackModel'], + onError: (err) => { + errorCount++; + lastError = err; + }, + }), + ], + }), + /INTERNAL: fallback fail/ + ); + + assert.strictEqual(pm.requestCount, 1); + assert.strictEqual(fallbackPm.requestCount, 1); + assert.strictEqual(errorCount, 2); + assert.ok(lastError); + assert.strictEqual(lastError!.message, 'INTERNAL: fallback fail'); + }); + + it('should fallback on a fallbackable error', async () => { + const pm = defineProgrammableModel(registry, {}, 'programmableModel'); + pm.handleResponse = async () => { + throw new GenkitError({ status: 'UNAVAILABLE', message: 'test' }); + }; + + const fallbackPm = defineProgrammableModel(registry, {}, 'fallbackModel'); + fallbackPm.handleResponse = async () => { + return { + message: { + role: 'model', + content: [{ text: 'fallback success' }], + }, + }; + }; + + const result = await generate(registry, { + model: 'programmableModel', + prompt: 'test', + use: [fallback({ registry }, { models: ['fallbackModel'] })], + }); + + assert.strictEqual(pm.requestCount, 1); + assert.strictEqual(fallbackPm.requestCount, 1); + assert.strictEqual(result.text, 'fallback success'); + }); + + it('should throw after all fallbacks fail', async () => { + const pm = defineProgrammableModel(registry, {}, 'programmableModel'); + pm.handleResponse = async (req, sc) => { + throw new GenkitError({ status: 'UNAVAILABLE', message: 'test' }); + }; + + const fallbackPm = defineProgrammableModel(registry, {}, 'fallbackModel'); + fallbackPm.handleResponse = async (req, sc) => { + throw new GenkitError({ status: 'INTERNAL', message: 'fallback fail' }); + }; + + await assert.rejects( + generate(registry, { + model: 'programmableModel', + prompt: 'test', + use: [fallback({ registry }, { models: ['fallbackModel'] })], + }), + /INTERNAL: fallback fail/ + ); + + assert.strictEqual(pm.requestCount, 1); + assert.strictEqual(fallbackPm.requestCount, 1); + }); + + it('should not fallback on non-fallbackable error', async () => { + const pm = defineProgrammableModel(registry, {}, 'programmableModel'); + pm.handleResponse = async (req, sc) => { + throw new GenkitError({ status: 'INVALID_ARGUMENT', message: 'test' }); + }; + + const fallbackPm = defineProgrammableModel(registry, {}, 'fallbackModel'); + + await assert.rejects( + generate(registry, { + model: 'programmableModel', + prompt: 'test', + use: [fallback({ registry }, { models: ['fallbackModel'] })], + }), + /INVALID_ARGUMENT: test/ + ); + + assert.strictEqual(pm.requestCount, 1); + assert.strictEqual(fallbackPm.requestCount, 0); + }); +}); + const registry = new Registry(); configureFormats(registry); diff --git a/js/genkit/README.md b/js/genkit/README.md index 7cd9ee26a8..b251533c3e 100644 --- a/js/genkit/README.md +++ b/js/genkit/README.md @@ -26,7 +26,25 @@ const ai = genkit({ plugins: [googleAI()] }); const { text } = await ai.generate({ model: googleAI.model('gemini-2.5-flash'), - prompt: 'Why is Firebase awesome?' + prompt: 'Why is Genkit awesome?' +}); +``` + +Genkit also provides middleware to add common functionality to your AI requests. For example, you can use the `retry` middleware to automatically retry failed requests: + +```ts +import { retry } from 'genkit/model/middleware'; + +const { text } = await ai.generate({ + model: googleAI.model('gemini-2.5-flash'), + prompt: 'Why is Genkit awesome?', + use: [ + retry({ + maxRetries: 3, + initialDelayMs: 1000, + backoffFactor: 2, + }), + ], }); ``` diff --git a/js/genkit/src/middleware.ts b/js/genkit/src/middleware.ts index ff31561e2d..f7d6e9f287 100644 --- a/js/genkit/src/middleware.ts +++ b/js/genkit/src/middleware.ts @@ -17,6 +17,8 @@ export { augmentWithContext, downloadRequestMedia, + fallback, + retry, simulateSystemPrompt, validateSupport, type AugmentWithContextOptions, diff --git a/js/plugins/compat-oai/src/model.ts b/js/plugins/compat-oai/src/model.ts index d909bb5f42..1544d5d0c6 100644 --- a/js/plugins/compat-oai/src/model.ts +++ b/js/plugins/compat-oai/src/model.ts @@ -26,10 +26,17 @@ import type { StreamingCallback, ToolRequestPart, } from 'genkit'; -import { GenerationCommonConfigSchema, Message, modelRef, z } from 'genkit'; +import { + GenerationCommonConfigSchema, + GenkitError, + Message, + StatusName, + modelRef, + z, +} from 'genkit'; import type { ModelAction, ModelInfo, ToolDefinition } from 'genkit/model'; import { model } from 'genkit/plugin'; -import type OpenAI from 'openai'; +import OpenAI, { APIError } from 'openai'; import type { ChatCompletion, ChatCompletionChunk, @@ -459,50 +466,75 @@ export function openAIModelRunner( abortSignal?: AbortSignal; } ): Promise => { - let response: ChatCompletion; - const body = toOpenAIRequestBody(name, request, requestBuilder); - if (options?.streamingRequested) { - const stream = client.beta.chat.completions.stream( - { - ...body, - stream: true, - stream_options: { - include_usage: true, + try { + let response: ChatCompletion; + const body = toOpenAIRequestBody(name, request, requestBuilder); + if (options?.streamingRequested) { + const stream = client.beta.chat.completions.stream( + { + ...body, + stream: true, + stream_options: { + include_usage: true, + }, }, - }, - { signal: options?.abortSignal } - ); - for await (const chunk of stream) { - chunk.choices?.forEach((chunk) => { - const c = fromOpenAIChunkChoice(chunk); - options?.sendChunk!({ - index: chunk.index, - content: c.message?.content ?? [], + { signal: options?.abortSignal } + ); + for await (const chunk of stream) { + chunk.choices?.forEach((chunk) => { + const c = fromOpenAIChunkChoice(chunk); + options?.sendChunk!({ + index: chunk.index, + content: c.message?.content ?? [], + }); }); + } + response = await stream.finalChatCompletion(); + } else { + response = await client.chat.completions.create(body, { + signal: options?.abortSignal, }); } - response = await stream.finalChatCompletion(); - } else { - response = await client.chat.completions.create(body, { - signal: options?.abortSignal, - }); - } - const standardResponse: GenerateResponseData = { - usage: { - inputTokens: response.usage?.prompt_tokens, - outputTokens: response.usage?.completion_tokens, - totalTokens: response.usage?.total_tokens, - }, - raw: response, - }; - if (response.choices.length === 0) { - return standardResponse; - } else { - const choice = response.choices[0]; - return { - ...fromOpenAIChoice(choice, request.output?.format === 'json'), - ...standardResponse, + const standardResponse: GenerateResponseData = { + usage: { + inputTokens: response.usage?.prompt_tokens, + outputTokens: response.usage?.completion_tokens, + totalTokens: response.usage?.total_tokens, + }, + raw: response, }; + if (response.choices.length === 0) { + return standardResponse; + } else { + const choice = response.choices[0]; + return { + ...fromOpenAIChoice(choice, request.output?.format === 'json'), + ...standardResponse, + }; + } + } catch (e) { + if (e instanceof APIError) { + let status: StatusName = 'UNKNOWN'; + switch (e.status) { + case 429: + status = 'RESOURCE_EXHAUSTED'; + break; + case 400: + status = 'INVALID_ARGUMENT'; + break; + case 500: + status = 'INTERNAL'; + break; + case 503: + status = 'UNAVAILABLE'; + break; + } + throw new GenkitError({ + status, + message: e.message, + }); + } + throw e; } }; } diff --git a/js/plugins/compat-oai/tests/compat_oai_test.ts b/js/plugins/compat-oai/tests/compat_oai_test.ts index c89630bf22..04c3a7cacb 100644 --- a/js/plugins/compat-oai/tests/compat_oai_test.ts +++ b/js/plugins/compat-oai/tests/compat_oai_test.ts @@ -31,6 +31,7 @@ import type { ChatCompletionRole, } from 'openai/resources/index.mjs'; +import { APIError } from 'openai'; import { fromOpenAIChoice, fromOpenAIChunkChoice, @@ -43,10 +44,16 @@ import { toOpenAITextAndMedia, } from '../src/model'; -jest.mock('@genkit-ai/ai/model', () => ({ - ...(jest.requireActual('@genkit-ai/ai/model') as Record), - defineModel: jest.fn(), -})); +jest.mock('genkit/model', () => { + const originalModule = + jest.requireActual('genkit/model'); + return { + ...originalModule, + defineModel: jest.fn((_, runner) => { + return runner; + }), + }; +}); describe('toOpenAIRole', () => { const testCases: { @@ -1505,4 +1512,92 @@ describe('openAIModelRunner', () => { { signal: undefined } ); }); + + describe('error handling', () => { + const testCases = [ + { + name: '429', + error: new APIError( + 429, + { error: { message: 'Rate limit exceeded' } }, + '', + {} + ), + expectedStatus: 'RESOURCE_EXHAUSTED', + }, + { + name: '400', + error: new APIError( + 400, + { error: { message: 'Invalid request' } }, + '', + {} + ), + expectedStatus: 'INVALID_ARGUMENT', + }, + { + name: '500', + error: new APIError( + 500, + { error: { message: 'Internal server error' } }, + '', + {} + ), + expectedStatus: 'INTERNAL', + }, + { + name: '503', + error: new APIError( + 503, + { error: { message: 'Service unavailable' } }, + '', + {} + ), + expectedStatus: 'UNAVAILABLE', + }, + ]; + + for (const tc of testCases) { + it(`should convert ${tc.name} error to GenkitError`, async () => { + const openaiClient = { + chat: { + completions: { + create: jest.fn(async () => { + throw tc.error; + }), + }, + }, + }; + const runner = openAIModelRunner( + 'gpt-4o', + openaiClient as unknown as OpenAI + ); + await expect(runner({ messages: [] })).rejects.toThrow( + expect.objectContaining({ + status: tc.expectedStatus, + }) + ); + }); + } + + it('should re-throw non-APIError', async () => { + const error = new Error('Some other error'); + const openaiClient = { + chat: { + completions: { + create: jest.fn(async () => { + throw error; + }), + }, + }, + }; + const runner = openAIModelRunner( + 'gpt-4o', + openaiClient as unknown as OpenAI + ); + await expect(runner({ messages: [] })).rejects.toThrow( + 'Some other error' + ); + }); + }); }); diff --git a/js/plugins/google-genai/src/googleai/client.ts b/js/plugins/google-genai/src/googleai/client.ts index cb97710be0..c8b406e36a 100644 --- a/js/plugins/google-genai/src/googleai/client.ts +++ b/js/plugins/google-genai/src/googleai/client.ts @@ -14,6 +14,8 @@ * limitations under the License. */ +import { GenkitError, StatusName } from 'genkit'; +import { logger } from 'genkit/logging'; import { extractErrMsg, getGenkitClientHeader, @@ -352,13 +354,32 @@ async function makeRequest( } catch (e) { // Not JSON or expected format, use the raw text } - throw new Error( - `Error fetching from ${url}: [${response.status} ${response.statusText}] ${errorMessage}` - ); + let status: StatusName = 'UNKNOWN'; + switch (response.status) { + case 429: + status = 'RESOURCE_EXHAUSTED'; + break; + case 400: + status = 'INVALID_ARGUMENT'; + break; + case 500: + status = 'INTERNAL'; + break; + case 503: + status = 'UNAVAILABLE'; + break; + } + throw new GenkitError({ + status, + message: `Error fetching from ${url}: [${response.status} ${response.statusText}] ${errorMessage}`, + }); } return response; } catch (e: unknown) { - console.error(e); + logger.error(e); + if (e instanceof GenkitError) { + throw e; + } throw new Error(`Failed to fetch from ${url}: ${extractErrMsg(e)}`); } } diff --git a/js/plugins/google-genai/src/vertexai/client.ts b/js/plugins/google-genai/src/vertexai/client.ts index 20f7a2aa3c..eaa40e6b86 100644 --- a/js/plugins/google-genai/src/vertexai/client.ts +++ b/js/plugins/google-genai/src/vertexai/client.ts @@ -14,6 +14,8 @@ * limitations under the License. */ +import { GenkitError, StatusName } from 'genkit'; +import { logger } from 'genkit/logging'; import { GoogleAuth } from 'google-auth-library'; import { extractErrMsg, @@ -369,13 +371,32 @@ async function makeRequest( } catch (e) { // Not JSON or expected format, use the raw text } - throw new Error( - `Error fetching from ${url}: [${response.status} ${response.statusText}] ${errorMessage}` - ); + let status: StatusName = 'UNKNOWN'; + switch (response.status) { + case 429: + status = 'RESOURCE_EXHAUSTED'; + break; + case 400: + status = 'INVALID_ARGUMENT'; + break; + case 500: + status = 'INTERNAL'; + break; + case 503: + status = 'UNAVAILABLE'; + break; + } + throw new GenkitError({ + status, + message: `Error fetching from ${url}: [${response.status} ${response.statusText}] ${errorMessage}`, + }); } return response; } catch (e: unknown) { - console.error(e); + logger.error(e); + if (e instanceof GenkitError) { + throw e; + } throw new Error(`Failed to fetch from ${url}: ${extractErrMsg(e)}`); } } diff --git a/js/plugins/google-genai/tests/googleai/client_test.ts b/js/plugins/google-genai/tests/googleai/client_test.ts index 44100f97ca..bd322017b5 100644 --- a/js/plugins/google-genai/tests/googleai/client_test.ts +++ b/js/plugins/google-genai/tests/googleai/client_test.ts @@ -188,10 +188,15 @@ describe('Google AI Client', () => { const errorResponse = { error: { message: 'Internal Error' } }; mockFetchResponse(errorResponse, false, 500, 'Internal Server Error'); - await assert.rejects( - listModels(apiKey), - /Failed to fetch from .* Error fetching from .* \[500 Internal Server Error\] Internal Error/ - ); + await assert.rejects(listModels(apiKey), (err: any) => { + assert.strictEqual(err.name, 'GenkitError'); + assert.strictEqual(err.status, 'INTERNAL'); + assert.match( + err.message, + /Error fetching from .* \[500 Internal Server Error\] Internal Error/ + ); + return true; + }); }); it('should throw an error if fetch fails with non-JSON error', async () => { @@ -203,19 +208,41 @@ describe('Google AI Client', () => { 'text/html' ); - await assert.rejects( - listModels(apiKey), - /Failed to fetch from .* Error fetching from .* \[500 Internal Server Error\]

Server Error<\/h1><\/body><\/html>/ - ); + await assert.rejects(listModels(apiKey), (err: any) => { + assert.strictEqual(err.name, 'GenkitError'); + assert.strictEqual(err.status, 'INTERNAL'); + assert.match( + err.message, + /Error fetching from .* \[500 Internal Server Error\]

Server Error<\/h1><\/body><\/html>/ + ); + return true; + }); }); it('should throw an error if fetch fails with empty response body', async () => { mockFetchResponse(null, false, 502, 'Bad Gateway'); - await assert.rejects( - listModels(apiKey), - /Failed to fetch from .* Error fetching from .* \[502 Bad Gateway\] $/ - ); + await assert.rejects(listModels(apiKey), (err: any) => { + assert.strictEqual(err.name, 'GenkitError'); + assert.strictEqual(err.status, 'UNKNOWN'); + assert.match(err.message, /Error fetching from .* \[502 Bad Gateway\]/); + return true; + }); + }); + + it('should throw a resource exhausted error on 429', async () => { + const errorResponse = { error: { message: 'Too many requests' } }; + mockFetchResponse(errorResponse, false, 429, 'Too Many Requests'); + + await assert.rejects(listModels(apiKey), (err: any) => { + assert.strictEqual(err.name, 'GenkitError'); + assert.strictEqual(err.status, 'RESOURCE_EXHAUSTED'); + assert.match( + err.message, + /Error fetching from .* \[429 Too Many Requests\] Too many requests/ + ); + return true; + }); }); it('should throw an error on network failure', async () => { @@ -276,7 +303,15 @@ describe('Google AI Client', () => { await assert.rejects( generateContent(apiKey, model, request), - /Failed to fetch from .* Error fetching from .* \[400 Bad Request\] Invalid Request/ + (err: any) => { + assert.strictEqual(err.name, 'GenkitError'); + assert.strictEqual(err.status, 'INVALID_ARGUMENT'); + assert.match( + err.message, + /Error fetching from .* \[400 Bad Request\] Invalid Request/ + ); + return true; + } ); }); @@ -285,7 +320,15 @@ describe('Google AI Client', () => { await assert.rejects( generateContent(apiKey, model, request), - /Failed to fetch from .* Error fetching from .* \[400 Bad Request\] Bad Request/ + (err: any) => { + assert.strictEqual(err.name, 'GenkitError'); + assert.strictEqual(err.status, 'INVALID_ARGUMENT'); + assert.match( + err.message, + /Error fetching from .* \[400 Bad Request\] Bad Request/ + ); + return true; + } ); }); @@ -329,10 +372,15 @@ describe('Google AI Client', () => { const errorResponse = { error: { message: 'Embedding failed' } }; mockFetchResponse(errorResponse, false, 500, 'Internal Server Error'); - await assert.rejects( - embedContent(apiKey, model, request), - /Failed to fetch from .* Error fetching from .* \[500 Internal Server Error\] Embedding failed/ - ); + await assert.rejects(embedContent(apiKey, model, request), (err: any) => { + assert.strictEqual(err.name, 'GenkitError'); + assert.strictEqual(err.status, 'INTERNAL'); + assert.match( + err.message, + /Error fetching from .* \[500 Internal Server Error\] Embedding failed/ + ); + return true; + }); }); it('should throw on API error with non-JSON body', async () => { @@ -344,10 +392,15 @@ describe('Google AI Client', () => { 'text/plain' ); - await assert.rejects( - embedContent(apiKey, model, request), - /Failed to fetch from .* Error fetching from .* \[500 Internal Server Error\] Internal Server Error/ - ); + await assert.rejects(embedContent(apiKey, model, request), (err: any) => { + assert.strictEqual(err.name, 'GenkitError'); + assert.strictEqual(err.status, 'INTERNAL'); + assert.match( + err.message, + /Error fetching from .* \[500 Internal Server Error\] Internal Server Error/ + ); + return true; + }); }); it('should throw on network failure', async () => { @@ -677,7 +730,15 @@ describe('Google AI Client', () => { await assert.rejects( imagenPredict(apiKey, model, request), - /Failed to fetch from .* Error fetching from .* \[400 Bad Request\] Imagen failed/ + (err: any) => { + assert.strictEqual(err.name, 'GenkitError'); + assert.strictEqual(err.status, 'INVALID_ARGUMENT'); + assert.match( + err.message, + /Error fetching from .* \[400 Bad Request\] Imagen failed/ + ); + return true; + } ); }); @@ -723,10 +784,15 @@ describe('Google AI Client', () => { const errorResponse = { error: { message: 'Veo failed to start' } }; mockFetchResponse(errorResponse, false, 500, 'Internal Server Error'); - await assert.rejects( - veoPredict(apiKey, model, request), - /Failed to fetch from .* Error fetching from .* \[500 Internal Server Error\] Veo failed to start/ - ); + await assert.rejects(veoPredict(apiKey, model, request), (err: any) => { + assert.strictEqual(err.name, 'GenkitError'); + assert.strictEqual(err.status, 'INTERNAL'); + assert.match( + err.message, + /Error fetching from .* \[500 Internal Server Error\] Veo failed to start/ + ); + return true; + }); }); }); @@ -765,7 +831,15 @@ describe('Google AI Client', () => { await assert.rejects( veoCheckOperation(apiKey, operationName), - /Failed to fetch from .* Error fetching from .* \[404 Not Found\] Operation not found/ + (err: any) => { + assert.strictEqual(err.name, 'GenkitError'); + assert.strictEqual(err.status, 'UNKNOWN'); + assert.match( + err.message, + /Error fetching from .* \[404 Not Found\] Operation not found/ + ); + return true; + } ); }); }); diff --git a/js/plugins/google-genai/tests/vertexai/client_test.ts b/js/plugins/google-genai/tests/vertexai/client_test.ts index bbe78972d0..ea5d2e485f 100644 --- a/js/plugins/google-genai/tests/vertexai/client_test.ts +++ b/js/plugins/google-genai/tests/vertexai/client_test.ts @@ -577,7 +577,15 @@ describe('Vertex AI Client', () => { await assert.rejects( generateContent(model, request, currentOptions), - /Failed to fetch from .* \[403 Forbidden\] Permission denied/ + (err: any) => { + assert.strictEqual(err.name, 'GenkitError'); + assert.strictEqual(err.status, 'UNKNOWN'); + assert.match( + err.message, + /Error fetching from .* \[403 Forbidden\] Permission denied/ + ); + return true; + } ); }); }); @@ -822,19 +830,26 @@ describe('Vertex AI Client', () => { 'text/html' ); - await assert.rejects( - listModels(regionalClientOptions), - /Failed to fetch from .* \[504 Gateway Timeout\]

Gateway Timeout<\/h1>/ - ); + await assert.rejects(listModels(regionalClientOptions), (err: any) => { + assert.strictEqual(err.name, 'GenkitError'); + assert.strictEqual(err.status, 'UNKNOWN'); + assert.match( + err.message, + /Error fetching from .* \[504 Gateway Timeout\]

Gateway Timeout<\/h1>/ + ); + return true; + }); }); it('listModels should throw an error if fetch fails with empty response body', async () => { mockFetchResponse(null, false, 502, 'Bad Gateway'); - await assert.rejects( - listModels(regionalClientOptions), - /Failed to fetch from .* \[502 Bad Gateway\] $/ - ); + await assert.rejects(listModels(regionalClientOptions), (err: any) => { + assert.strictEqual(err.name, 'GenkitError'); + assert.strictEqual(err.status, 'UNKNOWN'); + assert.match(err.message, /Error fetching from .* \[502 Bad Gateway\]/); + return true; + }); }); it('listModels should throw an error on network failure', async () => { @@ -844,6 +859,21 @@ describe('Vertex AI Client', () => { /Failed to fetch from .* Network Error/ ); }); + + it('should throw a resource exhausted error on 429', async () => { + const errorResponse = { error: { message: 'Too many requests' } }; + mockFetchResponse(errorResponse, false, 429, 'Too Many Requests'); + + await assert.rejects(listModels(regionalClientOptions), (err: any) => { + assert.strictEqual(err.name, 'GenkitError'); + assert.strictEqual(err.status, 'RESOURCE_EXHAUSTED'); + assert.match( + err.message, + /Error fetching from .* \[429 Too Many Requests\] Too many requests/ + ); + return true; + }); + }); }); describe('generateContentStream full aggregation tests', () => { diff --git a/js/plugins/google-genai/tests/vertexai/imagen_test.ts b/js/plugins/google-genai/tests/vertexai/imagen_test.ts index c37ef92dfe..df930ff30b 100644 --- a/js/plugins/google-genai/tests/vertexai/imagen_test.ts +++ b/js/plugins/google-genai/tests/vertexai/imagen_test.ts @@ -253,12 +253,11 @@ describe('Vertex AI Imagen', () => { fetchStub.rejects(error); const modelRunner = captureModelRunner(clientOptions); - await assert.rejects( - modelRunner(request, {}), - new RegExp( - `^Error: Failed to fetch from ${escapeRegExp(expectedUrl)}: Network Error` - ) - ); + await assert.rejects(modelRunner(request, {}), (err: any) => { + assert.strictEqual(err.name, 'Error'); + assert.match(err.message, /Network Error/); + return true; + }); }); it(`should handle API error response for ${clientOptions.kind}`, async () => { @@ -270,13 +269,35 @@ describe('Vertex AI Imagen', () => { mockFetchResponse(errorBody, 400); const modelRunner = captureModelRunner(clientOptions); - let expectedUrlRegex = escapeRegExp(expectedUrl); - await assert.rejects( - modelRunner(request, {}), - new RegExp( - `^Error: Failed to fetch from ${expectedUrlRegex}: Error fetching from ${expectedUrlRegex}: \\[400 Error\\] ${errorMsg}` - ) - ); + await assert.rejects(modelRunner(request, {}), (err: any) => { + assert.strictEqual(err.name, 'GenkitError'); + assert.strictEqual(err.status, 'INVALID_ARGUMENT'); + assert.match( + err.message, + /Error fetching from .* \[400 Error\] Invalid argument/ + ); + return true; + }); + }); + + it(`should throw a resource exhausted error on 429 for ${clientOptions.kind}`, async () => { + const request: GenerateRequest = { + messages: [{ role: 'user', content: [{ text: 'A bird' }] }], + }; + const errorMsg = 'Too many requests'; + const errorBody = { error: { message: errorMsg, code: 429 } }; + mockFetchResponse(errorBody, 429); + + const modelRunner = captureModelRunner(clientOptions); + await assert.rejects(modelRunner(request, {}), (err: any) => { + assert.strictEqual(err.name, 'GenkitError'); + assert.strictEqual(err.status, 'RESOURCE_EXHAUSTED'); + assert.match( + err.message, + /Error fetching from .* \[429 Error\] Too many requests/ + ); + return true; + }); }); } diff --git a/js/testapps/basic-gemini/src/index.ts b/js/testapps/basic-gemini/src/index.ts index b810ab1017..babe783c10 100644 --- a/js/testapps/basic-gemini/src/index.ts +++ b/js/testapps/basic-gemini/src/index.ts @@ -24,6 +24,7 @@ import { type Part, type StreamingCallback, } from 'genkit'; +import { fallback, retry } from 'genkit/model/middleware'; import { Readable } from 'stream'; import wav from 'wav'; @@ -43,6 +44,36 @@ ai.defineFlow('basic-hi', async () => { return text; }); +ai.defineFlow('basic-hi-with-retry', async () => { + const { text } = await ai.generate({ + model: googleAI.model('gemini-2.5-pro'), + prompt: 'You are a helpful AI assistant named Walt, say hello', + use: [ + retry({ + maxRetries: 2, + onError: (e, attempt) => console.log('--- oops ', attempt, e), + }), + ], + }); + + return text; +}); + +ai.defineFlow('basic-hi-with-fallback', async () => { + const { text } = await ai.generate({ + model: googleAI.model('gemini-2.5-soemthing-that-does-not-exist'), + prompt: 'You are a helpful AI assistant named Walt, say hello', + use: [ + fallback(ai, { + models: [googleAI.model('gemini-2.5-flash')], + statuses: ['UNKNOWN'], + }), + ], + }); + + return text; +}); + // Multimodal input ai.defineFlow('multimodal-input', async () => { const photoBase64 = fs.readFileSync('photo.jpg', { encoding: 'base64' });