Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 201 additions & 0 deletions js/ai/src/model/middleware.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,20 @@
* limitations under the License.
*/

import { GenkitError, StatusName } from '@genkit-ai/core';
import { HasRegistry } from '@genkit-ai/core/registry';
import { Document } from '../document.js';
import { injectInstructions } from '../formats/index.js';
import { ModelArgument } from '../index.js';
import type {
MediaPart,
MessageData,
ModelInfo,
ModelMiddleware,
Part,
} from '../model.js';
import { resolveModel } from '../model.js';

/**
* Preprocess a GenerateRequest to download referenced http(s) media URLs and
* inline them as data URIs.
Expand Down Expand Up @@ -235,6 +240,202 @@ export function augmentWithContext(
};
}

/**
* Options for the `retry` middleware.
*/
/**
 * Options for the `retry` middleware.
 */
export interface RetryOptions {
  /**
   * The maximum number of times to retry a failed request, so at most
   * `maxRetries + 1` attempts are made in total.
   * @default 3
   */
  maxRetries?: number;
  /**
   * An array of `StatusName` values that should trigger a retry.
   * Errors that are not `GenkitError`s are always retried.
   * @default ['UNAVAILABLE', 'DEADLINE_EXCEEDED', 'RESOURCE_EXHAUSTED', 'ABORTED', 'INTERNAL']
   */
  statuses?: StatusName[];
  /**
   * The initial delay between retries, in milliseconds.
   * @default 1000
   */
  initialDelayMs?: number;
  /**
   * The maximum delay between retries, in milliseconds. The exponential
   * backoff is capped at this value.
   * @default 60000
   */
  maxDelayMs?: number;
  /**
   * The factor by which the delay increases after each retry (exponential backoff).
   * @default 2
   */
  backoffFactor?: number;
  /**
   * Whether to disable jitter on the delay. When jitter is enabled (the
   * default), each delay is multiplied by a random factor in [1, 2).
   * @default false
   */
  noJitter?: boolean;
  /**
   * A callback executed before each retry delay. `attempt` is 1-based
   * (1 for the first retry).
   */
  onError?: (error: Error, attempt: number) => void;
}

/**
 * Signature of the timer used to schedule retry delays; matches Node's
 * `setTimeout`. Factored out so the variable and its test override share
 * one declaration instead of two duplicated inline types.
 */
type RetryTimeoutFn = (
  callback: (...args: any[]) => void,
  ms?: number
) => NodeJS.Timeout;

// Injectable timer so tests can fake backoff delays; real `setTimeout` by default.
let __setTimeout: RetryTimeoutFn = setTimeout;

/**
 * Overrides the timer used by the `retry` middleware's backoff delays.
 * FOR TESTING ONLY.
 * @internal
 */
export function __setRetryTimeout(impl: RetryTimeoutFn) {
  __setTimeout = impl;
}

// Statuses treated as transient by default. Shared by both the `retry` and
// `fallback` middlewares when no `statuses` option is provided.
const DEFAULT_RETRY_STATUSES: StatusName[] = [
  'UNAVAILABLE',
  'DEADLINE_EXCEEDED',
  'RESOURCE_EXHAUSTED',
  'ABORTED',
  'INTERNAL',
];

/**
 * Creates a middleware that retries requests on specific error statuses.
 *
 * Failures that are not `GenkitError`s are always retried; `GenkitError`s are
 * retried only when their status appears in `statuses`. Waits between
 * attempts with exponential backoff (optionally jittered), with the delay
 * capped at `maxDelayMs`.
 *
 * ```ts
 * const { text } = await ai.generate({
 *   model: googleAI.model('gemini-2.5-pro'),
 *   prompt: 'You are a helpful AI assistant named Walt, say hello',
 *   use: [
 *     retry({
 *       maxRetries: 2,
 *       initialDelayMs: 1000,
 *       backoffFactor: 2,
 *     }),
 *   ],
 * });
 * ```
 */
export function retry(options: RetryOptions = {}): ModelMiddleware {
  const {
    maxRetries = 3,
    statuses = DEFAULT_RETRY_STATUSES,
    initialDelayMs = 1000,
    maxDelayMs = 60000,
    backoffFactor = 2,
    noJitter = false,
    onError,
  } = options;

  return async (req, next) => {
    let currentDelay = initialDelayMs;
    // At most maxRetries + 1 total attempts.
    for (let attempt = 0; ; attempt++) {
      try {
        return await next(req);
      } catch (e) {
        const error = e as Error;
        // Out of attempts: propagate the most recent error unchanged.
        if (attempt >= maxRetries) {
          throw error;
        }
        // Non-GenkitError failures (e.g. raw network errors) are always
        // retried; GenkitErrors only when their status is retryable.
        const shouldRetry =
          !(error instanceof GenkitError) || statuses.includes(error.status);
        if (!shouldRetry) {
          throw error;
        }
        onError?.(error, attempt + 1);
        let delay = currentDelay;
        if (!noJitter) {
          // Jitter multiplies the delay by a factor in [1, 2). Cap the result
          // so the documented `maxDelayMs` ceiling is honored — previously a
          // jittered delay could reach 2 * maxDelayMs.
          delay = Math.min(delay * (1 + Math.random()), maxDelayMs);
        }
        await new Promise((resolve) => __setTimeout(resolve, delay));
        currentDelay = Math.min(currentDelay * backoffFactor, maxDelayMs);
      }
    }
  };
}

/**
* Options for the `fallback` middleware.
*/
/**
 * Options for the `fallback` middleware.
 */
export interface FallbackOptions {
  /**
   * An array of models to try, in order, after the primary model fails.
   */
  models: ModelArgument[];
  /**
   * An array of `StatusName` values that should trigger a fallback. Errors
   * that are not `GenkitError`s (or have other statuses) are rethrown
   * immediately without falling back further.
   * @default ['UNAVAILABLE', 'DEADLINE_EXCEEDED', 'RESOURCE_EXHAUSTED', 'ABORTED', 'INTERNAL']
   */
  statuses?: StatusName[];
  /**
   * A callback executed for each failure that triggers a fallback attempt.
   */
  onError?: (error: Error) => void;
}

/**
* Creates a middleware that falls back to a different model on specific error statuses.
*
* ```ts
* const { text } = await ai.generate({
* model: googleAI.model('gemini-2.5-pro'),
* prompt: 'You are a helpful AI assistant named Walt, say hello',
* use: [
* fallback(ai, {
* models: [googleAI.model('gemini-2.5-flash')],
* statuses: ['RESOURCE_EXHAUSTED'],
* }),
* ],
* });
* ```
*/
export function fallback(
ai: HasRegistry,
options: FallbackOptions
): ModelMiddleware {
const { models, statuses = DEFAULT_RETRY_STATUSES, onError } = options;

return async (req, next) => {
try {
return await next(req);
} catch (e) {
if (e instanceof GenkitError && statuses.includes(e.status)) {
onError?.(e);
let lastError: any = e;
for (const model of models) {
try {
const resolved = await resolveModel(ai.registry, model);
return await resolved.modelAction(req);
} catch (e2) {
lastError = e2;
if (e2 instanceof GenkitError && statuses.includes(e2.status)) {
onError?.(e2);
continue;
}
throw e2;
}
}
throw lastError;
}
throw e;
}
};
}

/**
 * Options for simulated constrained generation.
 */
export interface SimulatedConstrainedGenerationOptions {
  // Renders a JSON schema object into instruction text.
  // NOTE(review): exact injection semantics inferred from the name and the
  // `injectInstructions` import — confirm against the middleware body below.
  instructionsRenderer?: (schema: Record<string, any>) => string;
}
Expand Down
15 changes: 10 additions & 5 deletions js/ai/tests/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,24 +38,29 @@ export type ProgrammableModel = ModelAction & {
) => Promise<GenerateResponseData>;

lastRequest?: GenerateRequest;
requestCount: number;
};

/**
 * Defines a test model whose responses are scripted by assigning
 * `handleResponse` on the returned handle. Also records the most recent
 * request (`lastRequest`, deep-copied) and the total number of requests
 * received (`requestCount`).
 *
 * @param registry Registry to define the model in.
 * @param info Optional model metadata spread into the definition.
 * @param name Optional model name; defaults to 'programmableModel'.
 */
export function defineProgrammableModel(
  registry: Registry,
  info?: ModelInfo,
  name?: string
): ProgrammableModel {
  const pm = defineModel(
    registry,
    {
      // NOTE(review): `as any` sidesteps a type mismatch between ModelInfo
      // and the definition options — confirm whether a precise type works.
      ...(info as any),
      name: name ?? 'programmableModel',
    },
    async (request, { sendChunk }) => {
      pm.requestCount++;
      // Deep copy so later mutation of the live request doesn't alter the record.
      pm.lastRequest = JSON.parse(JSON.stringify(request));
      // Delegate to the test-assigned handler.
      return pm.handleResponse(request, sendChunk);
    }
  ) as ProgrammableModel;

  pm.requestCount = 0;
  return pm;
}

Expand Down
Loading