From f6f2708cd86868ca71917062b2f3128d309e7c10 Mon Sep 17 00:00:00 2001 From: Gregor Martynus <39992+gr2m@users.noreply.github.com> Date: Wed, 30 Apr 2025 19:59:32 -0700 Subject: [PATCH 1/5] fix (ai)!: move image model settings into providerOptions --- .../generate-image/generate-image.test.ts | 24 ++++--- .../ai/core/generate-image/generate-image.ts | 19 ++++-- .../src/bedrock-image-model.test.ts | 59 ++++++----------- .../amazon-bedrock/src/bedrock-image-model.ts | 6 +- .../src/bedrock-image-settings.ts | 8 --- .../src/bedrock-provider.test.ts | 10 +-- .../amazon-bedrock/src/bedrock-provider.ts | 29 ++++----- packages/azure/src/azure-openai-provider.ts | 15 ++--- .../src/deepinfra-image-model.test.ts | 21 +----- .../deepinfra/src/deepinfra-image-model.ts | 11 +--- .../deepinfra/src/deepinfra-image-settings.ts | 7 -- .../deepinfra/src/deepinfra-provider.test.ts | 6 +- packages/deepinfra/src/deepinfra-provider.ts | 23 ++----- packages/fal/src/fal-image-model.test.ts | 19 +----- packages/fal/src/fal-image-model.ts | 12 +--- packages/fal/src/fal-image-settings.ts | 7 -- packages/fal/src/fal-provider.test.ts | 20 ------ packages/fal/src/fal-provider.ts | 17 ++--- .../src/fireworks-image-model.test.ts | 40 +++--------- .../fireworks/src/fireworks-image-model.ts | 11 +--- .../fireworks/src/fireworks-image-options.ts | 7 -- .../fireworks/src/fireworks-provider.test.ts | 8 +-- packages/fireworks/src/fireworks-provider.ts | 23 ++----- .../src/google-vertex-image-model.test.ts | 31 ++------- .../src/google-vertex-image-model.ts | 13 +--- .../src/google-vertex-image-settings.ts | 7 -- .../src/google-vertex-provider.test.ts | 23 ------- .../src/google-vertex-provider.ts | 21 ++---- packages/luma/src/luma-image-model.test.ts | 46 +++++++------ packages/luma/src/luma-image-model.ts | 34 +++++----- packages/luma/src/luma-image-settings.ts | 5 -- packages/luma/src/luma-provider.test.ts | 20 ------ packages/luma/src/luma-provider.ts | 17 ++--- packages/openai-compatible/src/index.ts | 1 - .../src/openai-compatible-image-model.test.ts | 65 +++++++------------ .../src/openai-compatible-image-model.ts | 12 +--- .../src/openai-compatible-image-settings.ts | 13 ---- .../src/openai-compatible-provider.ts | 12 +--- .../openai/src/openai-image-model.test.ts | 25 +++---- packages/openai/src/openai-image-model.ts | 6 +- packages/openai/src/openai-image-settings.ts | 8 --- packages/openai/src/openai-provider.ts | 23 ++----- .../src/replicate-image-model.test.ts | 2 - .../replicate/src/replicate-image-model.ts | 12 +--- .../replicate/src/replicate-image-settings.ts | 7 -- packages/replicate/src/replicate-provider.ts | 23 ++----- .../src/togetherai-image-model.test.ts | 39 +++-------- .../togetherai/src/togetherai-image-model.ts | 11 +--- .../src/togetherai-image-settings.ts | 7 -- .../src/togetherai-provider.test.ts | 5 +- .../togetherai/src/togetherai-provider.ts | 26 +++----- packages/xai/src/xai-image-settings.ts | 7 -- packages/xai/src/xai-provider.test.ts | 10 ++- packages/xai/src/xai-provider.ts | 17 ++--- 54 files changed, 252 insertions(+), 698 deletions(-) diff --git a/packages/ai/core/generate-image/generate-image.test.ts b/packages/ai/core/generate-image/generate-image.test.ts index a2e44cd575e2..a6938d0246fe 100644 --- a/packages/ai/core/generate-image/generate-image.test.ts +++ b/packages/ai/core/generate-image/generate-image.test.ts @@ -61,7 +61,13 @@ describe('generateImage', () => { size: '1024x1024', aspectRatio: '16:9', seed: 12345, - providerOptions: { openai: { style: 'vivid' } }, + providerOptions: { + 'mock-provider': { + style: 'vivid', + // maxImagesPerCall is not passed to doGenerate + maxImagesPerCall: 3, + }, + }, headers: { 'custom-request-header': 'request-header-value' }, abortSignal, }); @@ -72,7 +78,7 @@ describe('generateImage', () => { size: '1024x1024', aspectRatio: '16:9', seed: 12345, - providerOptions: { openai: { style: 'vivid' } }, + providerOptions: { 'mock-provider': { style: 'vivid' } }, headers: { 'custom-request-header': 'request-header-value' }, abortSignal, }); @@ -211,7 +217,9 @@ describe('generateImage', () => { seed: 12345, size: '1024x1024', aspectRatio: '16:9', - providerOptions: { openai: { style: 'vivid' } }, + providerOptions: { + 'mock-provider': { style: 'vivid' }, + }, headers: { 'custom-request-header': 'request-header-value' }, abortSignal: undefined, }); @@ -225,7 +233,7 @@ describe('generateImage', () => { seed: 12345, size: '1024x1024', aspectRatio: '16:9', - providerOptions: { openai: { style: 'vivid' } }, + providerOptions: { 'mock-provider': { style: 'vivid' } }, headers: { 'custom-request-header': 'request-header-value' }, abortSignal: undefined, }); @@ -242,7 +250,7 @@ describe('generateImage', () => { size: '1024x1024', aspectRatio: '16:9', seed: 12345, - providerOptions: { openai: { style: 'vivid' } }, + providerOptions: { 'mock-provider': { style: 'vivid' } }, headers: { 'custom-request-header': 'request-header-value' }, }); @@ -268,7 +276,7 @@ describe('generateImage', () => { seed: 12345, size: '1024x1024', aspectRatio: '16:9', - providerOptions: { openai: { style: 'vivid' } }, + providerOptions: { 'mock-provider': { style: 'vivid' } }, headers: { 'custom-request-header': 'request-header-value' }, abortSignal: undefined, }); @@ -283,7 +291,7 @@ describe('generateImage', () => { seed: 12345, size: '1024x1024', aspectRatio: '16:9', - providerOptions: { openai: { style: 'vivid' } }, + providerOptions: { 'mock-provider': { style: 'vivid' } }, headers: { 'custom-request-header': 'request-header-value' }, abortSignal: undefined, }); @@ -301,7 +309,7 @@ describe('generateImage', () => { size: '1024x1024', aspectRatio: '16:9', seed: 12345, - providerOptions: { openai: { style: 'vivid' } }, + providerOptions: { 'mock-provider': { style: 'vivid' } }, headers: { 'custom-request-header': 'request-header-value' }, }); diff --git a/packages/ai/core/generate-image/generate-image.ts b/packages/ai/core/generate-image/generate-image.ts index 86da2d793bc4..54fd9a29f125 100644 --- a/packages/ai/core/generate-image/generate-image.ts +++ b/packages/ai/core/generate-image/generate-image.ts @@ -110,20 +110,27 @@ Only applicable for HTTP-based providers. }): Promise { const { retry } = prepareRetries({ maxRetries: maxRetriesArg }); + // extract maxImagesPerCall from providerOptions as it's not meant to be passed to `doGenerate()` + const [ + [providerName, { maxImagesPerCall, ...generateProviderOptions } = {}], + ] = Object.entries(providerOptions ?? {}); + // default to 1 if the model has not specified limits on // how many images can be generated in a single call - const maxImagesPerCall = model.maxImagesPerCall ?? 1; + const maxImagesPerCallWithDefault = + (maxImagesPerCall as number) ?? model.maxImagesPerCall ?? 1; // parallelize calls to the model: - const callCount = Math.ceil(n / maxImagesPerCall); + const callCount = Math.ceil(n / maxImagesPerCallWithDefault); const callImageCounts = Array.from({ length: callCount }, (_, i) => { if (i < callCount - 1) { - return maxImagesPerCall; + return maxImagesPerCallWithDefault; } - const remainder = n % maxImagesPerCall; - return remainder === 0 ? maxImagesPerCall : remainder; + const remainder = n % maxImagesPerCallWithDefault; + return remainder === 0 ? maxImagesPerCallWithDefault : remainder; }); + const results = await Promise.all( callImageCounts.map(async callImageCount => retry(() => @@ -135,7 +142,7 @@ Only applicable for HTTP-based providers. size, aspectRatio, seed, - providerOptions: providerOptions ?? {}, + providerOptions: { [providerName]: generateProviderOptions }, }), ), ), diff --git a/packages/amazon-bedrock/src/bedrock-image-model.test.ts b/packages/amazon-bedrock/src/bedrock-image-model.test.ts index 3b083750032e..4a2bb55c2cc2 100644 --- a/packages/amazon-bedrock/src/bedrock-image-model.test.ts +++ b/packages/amazon-bedrock/src/bedrock-image-model.test.ts @@ -34,15 +34,11 @@ describe('doGenerate', () => { }, }); - const model = new BedrockImageModel( - 'amazon.nova-canvas-v1:0', - {}, - { - baseUrl: () => 'https://bedrock-runtime.us-east-1.amazonaws.com', - headers: mockConfigHeaders, - fetch: fakeFetchWithAuth, - }, - ); + const model = new BedrockImageModel('amazon.nova-canvas-v1:0', { + baseUrl: () => 'https://bedrock-runtime.us-east-1.amazonaws.com', + headers: mockConfigHeaders, + fetch: fakeFetchWithAuth, + }); it('should pass the model and the settings', async () => { await model.doGenerate({ @@ -84,21 +80,17 @@ describe('doGenerate', () => { 'shared-header': 'options-shared', }; - const modelWithHeaders = new BedrockImageModel( - 'amazon.nova-canvas-v1:0', - {}, - { - baseUrl: () => 'https://bedrock-runtime.us-east-1.amazonaws.com', - headers: { - 'model-header': 'model-value', - 'shared-header': 'model-shared', - }, - fetch: injectFetchHeaders({ - 'signed-header': 'signed-value', - authorization: 'AWS4-HMAC-SHA256...', - }), + const modelWithHeaders = new BedrockImageModel('amazon.nova-canvas-v1:0', { + baseUrl: () => 'https://bedrock-runtime.us-east-1.amazonaws.com', + headers: { + 'model-header': 'model-value', + 'shared-header': 'model-shared', }, - ); + fetch: injectFetchHeaders({ + 'signed-header': 'signed-value', + authorization: 'AWS4-HMAC-SHA256...', + }), + }); await modelWithHeaders.doGenerate({ prompt, @@ -119,11 +111,6 @@ describe('doGenerate', () => { }); it('should respect maxImagesPerCall setting', async () => { - const customModel = provider.image('amazon.nova-canvas-v1:0', { - maxImagesPerCall: 2, - }); - expect(customModel.maxImagesPerCall).toBe(2); - const defaultModel = provider.image('amazon.nova-canvas-v1:0'); expect(defaultModel.maxImagesPerCall).toBe(5); // 'amazon.nova-canvas-v1:0','s default from settings @@ -167,17 +154,13 @@ describe('doGenerate', () => { it('should include response data with timestamp, modelId and headers', async () => { const testDate = new Date('2024-03-15T12:00:00Z'); - const customModel = new BedrockImageModel( - 'amazon.nova-canvas-v1:0', - {}, - { - baseUrl: () => 'https://bedrock-runtime.us-east-1.amazonaws.com', - headers: () => ({}), - _internal: { - currentDate: () => testDate, - }, + const customModel = new BedrockImageModel('amazon.nova-canvas-v1:0', { + baseUrl: () => 'https://bedrock-runtime.us-east-1.amazonaws.com', + headers: () => ({}), + _internal: { + currentDate: () => testDate, }, - ); + }); const result = await customModel.doGenerate({ prompt, diff --git a/packages/amazon-bedrock/src/bedrock-image-model.ts b/packages/amazon-bedrock/src/bedrock-image-model.ts index 246b204eb15e..381d38c51188 100644 --- a/packages/amazon-bedrock/src/bedrock-image-model.ts +++ b/packages/amazon-bedrock/src/bedrock-image-model.ts @@ -10,7 +10,6 @@ import { } from '@ai-sdk/provider-utils'; import { BedrockImageModelId, - BedrockImageSettings, modelMaxImagesPerCall, } from './bedrock-image-settings'; import { BedrockErrorSchema } from './bedrock-error'; @@ -30,9 +29,7 @@ export class BedrockImageModel implements ImageModelV2 { readonly provider = 'amazon-bedrock'; get maxImagesPerCall(): number { - return ( - this.settings.maxImagesPerCall ?? modelMaxImagesPerCall[this.modelId] ?? 1 - ); + return modelMaxImagesPerCall[this.modelId] ?? 1; } private getUrl(modelId: string): string { @@ -42,7 +39,6 @@ export class BedrockImageModel implements ImageModelV2 { constructor( readonly modelId: BedrockImageModelId, - private readonly settings: BedrockImageSettings, private readonly config: BedrockImageModelConfig, ) {} diff --git a/packages/amazon-bedrock/src/bedrock-image-settings.ts b/packages/amazon-bedrock/src/bedrock-image-settings.ts index 6caf11c27ab9..213d97eec586 100644 --- a/packages/amazon-bedrock/src/bedrock-image-settings.ts +++ b/packages/amazon-bedrock/src/bedrock-image-settings.ts @@ -4,11 +4,3 @@ export type BedrockImageModelId = 'amazon.nova-canvas-v1:0' | (string & {}); export const modelMaxImagesPerCall: Record = { 'amazon.nova-canvas-v1:0': 5, }; - -export interface BedrockImageSettings { - /** - * Override the maximum number of images per call (default is dependent on the - * model, or 1 for an unknown model). - */ - maxImagesPerCall?: number; -} diff --git a/packages/amazon-bedrock/src/bedrock-provider.test.ts b/packages/amazon-bedrock/src/bedrock-provider.test.ts index 54cc43962be2..76eefcc3d112 100644 --- a/packages/amazon-bedrock/src/bedrock-provider.test.ts +++ b/packages/amazon-bedrock/src/bedrock-provider.test.ts @@ -140,13 +140,10 @@ describe('AmazonBedrockProvider', () => { const provider = createAmazonBedrock(); const modelId = 'amazon.titan-image-generator'; - const model = provider.image(modelId, { - maxImagesPerCall: 5, - }); + const model = provider.image(modelId); const constructorCall = BedrockImageModelMock.mock.calls[0]; expect(constructorCall[0]).toBe(modelId); - expect(constructorCall[1]).toEqual({ maxImagesPerCall: 5 }); expect(model).toBeInstanceOf(BedrockImageModel); }); @@ -154,13 +151,10 @@ describe('AmazonBedrockProvider', () => { const provider = createAmazonBedrock(); const modelId = 'amazon.titan-image-generator'; - const model = provider.imageModel(modelId, { - maxImagesPerCall: 5, - }); + const model = provider.imageModel(modelId); const constructorCall = BedrockImageModelMock.mock.calls[0]; expect(constructorCall[0]).toBe(modelId); - expect(constructorCall[1]).toEqual({ maxImagesPerCall: 5 }); expect(model).toBeInstanceOf(BedrockImageModel); }); }); diff --git a/packages/amazon-bedrock/src/bedrock-provider.ts b/packages/amazon-bedrock/src/bedrock-provider.ts index 7226a9f7b0ae..18830ae188b8 100644 --- a/packages/amazon-bedrock/src/bedrock-provider.ts +++ b/packages/amazon-bedrock/src/bedrock-provider.ts @@ -16,10 +16,7 @@ import { BedrockChatModelId } from './bedrock-chat-options'; import { BedrockEmbeddingModel } from './bedrock-embedding-model'; import { BedrockEmbeddingModelId } from './bedrock-embedding-options'; import { BedrockImageModel } from './bedrock-image-model'; -import { - BedrockImageModelId, - BedrockImageSettings, -} from './bedrock-image-settings'; +import { BedrockImageModelId } from './bedrock-image-settings'; import { BedrockCredentials, createSigV4FetchFunction, @@ -85,15 +82,16 @@ export interface AmazonBedrockProvider extends ProviderV2 { embedding(modelId: BedrockEmbeddingModelId): EmbeddingModelV2; - image( - modelId: BedrockImageModelId, - settings?: BedrockImageSettings, - ): ImageModelV2; + /** +Creates a model for image generation. +@deprecated Use `imageModel` instead. + */ + image(modelId: BedrockImageModelId): ImageModelV2; - imageModel( - modelId: BedrockImageModelId, - settings?: BedrockImageSettings, - ): ImageModelV2; + /** +Creates a model for image generation. + */ + imageModel(modelId: BedrockImageModelId): ImageModelV2; } /** @@ -173,11 +171,8 @@ export function createAmazonBedrock( fetch: sigv4Fetch, }); - const createImageModel = ( - modelId: BedrockImageModelId, - settings: BedrockImageSettings = {}, - ) => - new BedrockImageModel(modelId, settings, { + const createImageModel = (modelId: BedrockImageModelId) => + new BedrockImageModel(modelId, { baseUrl: getBaseUrl, headers: options.headers ?? {}, fetch: sigv4Fetch, diff --git a/packages/azure/src/azure-openai-provider.ts b/packages/azure/src/azure-openai-provider.ts index f2dcc9eafe12..ae033990b32e 100644 --- a/packages/azure/src/azure-openai-provider.ts +++ b/packages/azure/src/azure-openai-provider.ts @@ -3,7 +3,6 @@ import { OpenAICompletionLanguageModel, OpenAIEmbeddingModel, OpenAIImageModel, - OpenAIImageSettings, OpenAIResponsesLanguageModel, OpenAITranscriptionModel, } from '@ai-sdk/openai/internal'; @@ -48,15 +47,12 @@ Creates an Azure OpenAI completion model for text generation. * Creates an Azure OpenAI DALL-E model for image generation. * @deprecated Use `imageModel` instead. */ - image(deploymentId: string, settings?: OpenAIImageSettings): ImageModelV2; + image(deploymentId: string): ImageModelV2; /** * Creates an Azure OpenAI DALL-E model for image generation. */ - imageModel( - deploymentId: string, - settings?: OpenAIImageSettings, - ): ImageModelV2; + imageModel(deploymentId: string): ImageModelV2; /** @deprecated Use `textEmbeddingModel` instead. @@ -183,11 +179,8 @@ export function createAzure( fetch: options.fetch, }); - const createImageModel = ( - modelId: string, - settings: OpenAIImageSettings = {}, - ) => - new OpenAIImageModel(modelId, settings, { + const createImageModel = (modelId: string) => + new OpenAIImageModel(modelId, { provider: 'azure.image', url, headers: getHeaders, diff --git a/packages/deepinfra/src/deepinfra-image-model.test.ts b/packages/deepinfra/src/deepinfra-image-model.test.ts index 4e383f67312f..1ad8d441026b 100644 --- a/packages/deepinfra/src/deepinfra-image-model.test.ts +++ b/packages/deepinfra/src/deepinfra-image-model.test.ts @@ -1,7 +1,6 @@ import { createTestServer } from '@ai-sdk/provider-utils/test'; import { describe, expect, it } from 'vitest'; import { DeepInfraImageModel } from './deepinfra-image-model'; -import { DeepInfraImageSettings } from './deepinfra-image-settings'; import { FetchFunction } from '@ai-sdk/provider-utils'; const prompt = 'A cute baby sea otter'; @@ -10,14 +9,12 @@ function createBasicModel({ headers, fetch, currentDate, - settings, }: { headers?: () => Record; fetch?: FetchFunction; currentDate?: () => Date; - settings?: DeepInfraImageSettings; } = {}) { - return new DeepInfraImageModel('stability-ai/sdxl', settings ?? {}, { + return new DeepInfraImageModel('stability-ai/sdxl', { provider: 'deepinfra', baseURL: 'https://api.example.com', headers: headers ?? (() => ({ 'api-key': 'test-key' })), @@ -234,21 +231,5 @@ describe('DeepInfraImageModel', () => { expect(model.specificationVersion).toBe('v2'); expect(model.maxImagesPerCall).toBe(1); }); - - it('should use maxImagesPerCall from settings', () => { - const model = createBasicModel({ - settings: { - maxImagesPerCall: 4, - }, - }); - - expect(model.maxImagesPerCall).toBe(4); - }); - - it('should default maxImagesPerCall to 1 when not specified', () => { - const model = createBasicModel(); - - expect(model.maxImagesPerCall).toBe(1); - }); }); }); diff --git a/packages/deepinfra/src/deepinfra-image-model.ts b/packages/deepinfra/src/deepinfra-image-model.ts index f3ba26a228b8..8e06e308a612 100644 --- a/packages/deepinfra/src/deepinfra-image-model.ts +++ b/packages/deepinfra/src/deepinfra-image-model.ts @@ -6,10 +6,7 @@ import { createJsonResponseHandler, postJsonToApi, } from '@ai-sdk/provider-utils'; -import { - DeepInfraImageModelId, - DeepInfraImageSettings, -} from './deepinfra-image-settings'; +import { DeepInfraImageModelId } from './deepinfra-image-settings'; import { z } from 'zod'; interface DeepInfraImageModelConfig { @@ -24,18 +21,14 @@ interface DeepInfraImageModelConfig { export class DeepInfraImageModel implements ImageModelV2 { readonly specificationVersion = 'v2'; + readonly maxImagesPerCall = 1; get provider(): string { return this.config.provider; } - get maxImagesPerCall(): number { - return this.settings.maxImagesPerCall ?? 1; - } - constructor( readonly modelId: DeepInfraImageModelId, - readonly settings: DeepInfraImageSettings, private config: DeepInfraImageModelConfig, ) {} diff --git a/packages/deepinfra/src/deepinfra-image-settings.ts b/packages/deepinfra/src/deepinfra-image-settings.ts index b7f62d270a23..05c59c383f90 100644 --- a/packages/deepinfra/src/deepinfra-image-settings.ts +++ b/packages/deepinfra/src/deepinfra-image-settings.ts @@ -8,10 +8,3 @@ export type DeepInfraImageModelId = | 'stabilityai/sd3.5-medium' | 'stabilityai/sdxl-turbo' | (string & {}); - -export interface DeepInfraImageSettings { - /** - * Override the maximum number of images per call (default 1) - */ - maxImagesPerCall?: number; -} diff --git a/packages/deepinfra/src/deepinfra-provider.test.ts b/packages/deepinfra/src/deepinfra-provider.test.ts index b9c6a2c0e10d..53699538546f 100644 --- a/packages/deepinfra/src/deepinfra-provider.test.ts +++ b/packages/deepinfra/src/deepinfra-provider.test.ts @@ -136,14 +136,12 @@ describe('DeepInfraProvider', () => { it('should construct an image model with correct configuration', () => { const provider = createDeepInfra(); const modelId = 'deepinfra-image-model'; - const settings = { maxImagesPerCall: 2 }; - const model = provider.image(modelId, settings); + const model = provider.image(modelId); expect(model).toBeInstanceOf(DeepInfraImageModel); expect(DeepInfraImageModel).toHaveBeenCalledWith( modelId, - settings, expect.objectContaining({ provider: 'deepinfra.image', baseURL: 'https://api.deepinfra.com/v1/inference', @@ -160,7 +158,6 @@ describe('DeepInfraProvider', () => { expect(model).toBeInstanceOf(DeepInfraImageModel); expect(DeepInfraImageModel).toHaveBeenCalledWith( modelId, - {}, expect.any(Object), ); }); @@ -174,7 +171,6 @@ describe('DeepInfraProvider', () => { expect(DeepInfraImageModel).toHaveBeenCalledWith( modelId, - expect.any(Object), expect.objectContaining({ baseURL: `${customBaseURL}/inference`, }), diff --git a/packages/deepinfra/src/deepinfra-provider.ts b/packages/deepinfra/src/deepinfra-provider.ts index 5ad15894d351..b22cbf92ea16 100644 --- a/packages/deepinfra/src/deepinfra-provider.ts +++ b/packages/deepinfra/src/deepinfra-provider.ts @@ -17,10 +17,7 @@ import { import { DeepInfraChatModelId } from './deepinfra-chat-options'; import { DeepInfraEmbeddingModelId } from './deepinfra-embedding-options'; import { DeepInfraCompletionModelId } from './deepinfra-completion-options'; -import { - DeepInfraImageModelId, - DeepInfraImageSettings, -} from './deepinfra-image-settings'; +import { DeepInfraImageModelId } from './deepinfra-image-settings'; import { DeepInfraImageModel } from './deepinfra-image-model'; export interface DeepInfraProviderSettings { @@ -56,19 +53,14 @@ Creates a chat model for text generation. /** Creates a model for image generation. +@deprecated Use `imageModel` instead. */ - image( - modelId: DeepInfraImageModelId, - settings?: DeepInfraImageSettings, - ): ImageModelV2; + image(modelId: DeepInfraImageModelId): ImageModelV2; /** Creates a model for image generation. */ - imageModel( - modelId: DeepInfraImageModelId, - settings?: DeepInfraImageSettings, - ): ImageModelV2; + imageModel(modelId: DeepInfraImageModelId): ImageModelV2; /** Creates a chat model for text generation. @@ -136,11 +128,8 @@ export function createDeepInfra( getCommonModelConfig('embedding'), ); - const createImageModel = ( - modelId: DeepInfraImageModelId, - settings: DeepInfraImageSettings = {}, - ) => - new DeepInfraImageModel(modelId, settings, { + const createImageModel = (modelId: DeepInfraImageModelId) => + new DeepInfraImageModel(modelId, { ...getCommonModelConfig('image'), baseURL: baseURL ? `${baseURL}/inference` diff --git a/packages/fal/src/fal-image-model.test.ts b/packages/fal/src/fal-image-model.test.ts index d27ad2b9092b..fbbda82e0e26 100644 --- a/packages/fal/src/fal-image-model.test.ts +++ b/packages/fal/src/fal-image-model.test.ts @@ -9,14 +9,13 @@ function createBasicModel({ headers, fetch, currentDate, - settings, }: { headers?: Record; fetch?: FetchFunction; currentDate?: () => Date; settings?: any; } = {}) { - return new FalImageModel('stable-diffusion-xl', settings ?? {}, { + return new FalImageModel('stable-diffusion-xl', { provider: 'fal', baseURL: 'https://api.example.com', headers: headers ?? { 'api-key': 'test-key' }, @@ -185,22 +184,6 @@ describe('FalImageModel', () => { expect(model.specificationVersion).toBe('v2'); expect(model.maxImagesPerCall).toBe(1); }); - - it('should use maxImagesPerCall from settings', () => { - const model = createBasicModel({ - settings: { - maxImagesPerCall: 4, - }, - }); - - expect(model.maxImagesPerCall).toBe(4); - }); - - it('should default maxImagesPerCall to 1 when not specified', () => { - const model = createBasicModel(); - - expect(model.maxImagesPerCall).toBe(1); - }); }); describe('response schema validation', () => { diff --git a/packages/fal/src/fal-image-model.ts b/packages/fal/src/fal-image-model.ts index 0b212c1200ae..37403820378f 100644 --- a/packages/fal/src/fal-image-model.ts +++ b/packages/fal/src/fal-image-model.ts @@ -12,11 +12,7 @@ import { resolve, } from '@ai-sdk/provider-utils'; import { z } from 'zod'; -import { - FalImageModelId, - FalImageSettings, - FalImageSize, -} from './fal-image-settings'; +import { FalImageModelId, FalImageSize } from './fal-image-settings'; interface FalImageModelConfig { provider: string; @@ -30,18 +26,14 @@ interface FalImageModelConfig { export class FalImageModel implements ImageModelV2 { readonly specificationVersion = 'v2'; + readonly maxImagesPerCall = 1; get provider(): string { return this.config.provider; } - get maxImagesPerCall(): number { - return this.settings.maxImagesPerCall ?? 1; - } - constructor( readonly modelId: FalImageModelId, - private readonly settings: FalImageSettings, private readonly config: FalImageModelConfig, ) {} diff --git a/packages/fal/src/fal-image-settings.ts b/packages/fal/src/fal-image-settings.ts index 059f5c360b77..47d6e29b9abf 100644 --- a/packages/fal/src/fal-image-settings.ts +++ b/packages/fal/src/fal-image-settings.ts @@ -62,10 +62,3 @@ export type FalImageSize = width: number; height: number; }; - -export interface FalImageSettings { - /** -Override the maximum number of images per call (default 1). - */ - maxImagesPerCall?: number; -} diff --git a/packages/fal/src/fal-provider.test.ts b/packages/fal/src/fal-provider.test.ts index 4dfcfd9f707a..e4b6a6ec1133 100644 --- a/packages/fal/src/fal-provider.test.ts +++ b/packages/fal/src/fal-provider.test.ts @@ -21,25 +21,6 @@ describe('createFal', () => { expect(model).toBeInstanceOf(FalImageModel); expect(FalImageModel).toHaveBeenCalledWith( modelId, - {}, - expect.objectContaining({ - provider: 'fal.image', - baseURL: 'https://fal.run', - }), - ); - }); - - it('should construct an image model with custom settings', () => { - const provider = createFal(); - const modelId = 'fal-ai/flux/dev'; - const settings = { maxImagesPerCall: 3 }; - - const model = provider.image(modelId, settings); - - expect(model).toBeInstanceOf(FalImageModel); - expect(FalImageModel).toHaveBeenCalledWith( - modelId, - settings, expect.objectContaining({ provider: 'fal.image', baseURL: 'https://fal.run', @@ -64,7 +45,6 @@ describe('createFal', () => { expect(FalImageModel).toHaveBeenCalledWith( modelId, - {}, expect.objectContaining({ baseURL: customBaseURL, headers: expect.any(Function), diff --git a/packages/fal/src/fal-provider.ts b/packages/fal/src/fal-provider.ts index ed6989a5e9cd..75b7c10894c7 100644 --- a/packages/fal/src/fal-provider.ts +++ b/packages/fal/src/fal-provider.ts @@ -7,7 +7,7 @@ import { import type { FetchFunction } from '@ai-sdk/provider-utils'; import { withoutTrailingSlash } from '@ai-sdk/provider-utils'; import { FalImageModel } from './fal-image-model'; -import { FalImageModelId, FalImageSettings } from './fal-image-settings'; +import { FalImageModelId } from './fal-image-settings'; import { FalTranscriptionModelId } from './fal-transcription-options'; import { FalTranscriptionModel } from './fal-transcription-model'; @@ -39,16 +39,14 @@ requests, or to provide a custom fetch implementation for e.g. testing. export interface FalProvider extends ProviderV2 { /** Creates a model for image generation. +@deprecated Use `imageModel` instead. */ - image(modelId: FalImageModelId, settings?: FalImageSettings): ImageModelV2; + image(modelId: FalImageModelId): ImageModelV2; /** Creates a model for image generation. */ - imageModel( - modelId: FalImageModelId, - settings?: FalImageSettings, - ): ImageModelV2; + imageModel(modelId: FalImageModelId): ImageModelV2; /** Creates a model for transcription. @@ -111,11 +109,8 @@ export function createFal(options: FalProviderSettings = {}): FalProvider { ...options.headers, }); - const createImageModel = ( - modelId: FalImageModelId, - settings: FalImageSettings = {}, - ) => - new FalImageModel(modelId, settings, { + const createImageModel = (modelId: FalImageModelId) => + new FalImageModel(modelId, { provider: 'fal.image', baseURL: baseURL ?? defaultBaseURL, headers: getHeaders, diff --git a/packages/fireworks/src/fireworks-image-model.test.ts b/packages/fireworks/src/fireworks-image-model.test.ts index ed3b68dba398..a85a62e393a5 100644 --- a/packages/fireworks/src/fireworks-image-model.test.ts +++ b/packages/fireworks/src/fireworks-image-model.test.ts @@ -2,7 +2,6 @@ import { FetchFunction } from '@ai-sdk/provider-utils'; import { createTestServer } from '@ai-sdk/provider-utils/test'; import { describe, expect, it } from 'vitest'; import { FireworksImageModel } from './fireworks-image-model'; -import { FireworksImageSettings } from './fireworks-image-options'; const prompt = 'A cute baby sea otter'; @@ -10,32 +9,25 @@ function createBasicModel({ headers, fetch, currentDate, - settings, }: { headers?: () => Record; fetch?: FetchFunction; currentDate?: () => Date; - settings?: FireworksImageSettings; } = {}) { - return new FireworksImageModel( - 'accounts/fireworks/models/flux-1-dev-fp8', - settings ?? {}, - { - provider: 'fireworks', - baseURL: 'https://api.example.com', - headers: headers ?? (() => ({ 'api-key': 'test-key' })), - fetch, - _internal: { - currentDate, - }, + return new FireworksImageModel('accounts/fireworks/models/flux-1-dev-fp8', { + provider: 'fireworks', + baseURL: 'https://api.example.com', + headers: headers ?? (() => ({ 'api-key': 'test-key' })), + fetch, + _internal: { + currentDate, }, - ); + }); } function createSizeModel() { return new FireworksImageModel( 'accounts/fireworks/models/playground-v2-5-1024px-aesthetic', - {}, { provider: 'fireworks', baseURL: 'https://api.size-example.com', @@ -362,21 +354,5 @@ describe('FireworksImageModel', () => { expect(model.specificationVersion).toBe('v2'); expect(model.maxImagesPerCall).toBe(1); }); - - it('should use maxImagesPerCall from settings', () => { - const model = createBasicModel({ - settings: { - maxImagesPerCall: 4, - }, - }); - - expect(model.maxImagesPerCall).toBe(4); - }); - - it('should default maxImagesPerCall to 1 when not specified', () => { - const model = createBasicModel(); - - expect(model.maxImagesPerCall).toBe(1); - }); }); }); diff --git a/packages/fireworks/src/fireworks-image-model.ts b/packages/fireworks/src/fireworks-image-model.ts index 224c0c65afb9..8a86cb34100f 100644 --- a/packages/fireworks/src/fireworks-image-model.ts +++ b/packages/fireworks/src/fireworks-image-model.ts @@ -6,10 +6,7 @@ import { FetchFunction, postJsonToApi, } from '@ai-sdk/provider-utils'; -import { - FireworksImageModelId, - FireworksImageSettings, -} from './fireworks-image-options'; +import { FireworksImageModelId } from './fireworks-image-options'; interface FireworksImageModelBackendConfig { urlFormat: 'workflows' | 'image_generation'; @@ -72,18 +69,14 @@ interface FireworksImageModelConfig { export class FireworksImageModel implements ImageModelV2 { readonly specificationVersion = 'v2'; + readonly maxImagesPerCall = 1; get provider(): string { return this.config.provider; } - get maxImagesPerCall(): number { - return this.settings.maxImagesPerCall ?? 1; - } - constructor( readonly modelId: FireworksImageModelId, - readonly settings: FireworksImageSettings, private config: FireworksImageModelConfig, ) {} diff --git a/packages/fireworks/src/fireworks-image-options.ts b/packages/fireworks/src/fireworks-image-options.ts index 909d6005910b..f720bb1fe1c2 100644 --- a/packages/fireworks/src/fireworks-image-options.ts +++ b/packages/fireworks/src/fireworks-image-options.ts @@ -8,10 +8,3 @@ export type FireworksImageModelId = | 'accounts/fireworks/models/SSD-1B' | 'accounts/fireworks/models/stable-diffusion-xl-1024-v1-0' | (string & {}); - -export interface FireworksImageSettings { - /** -Override the maximum number of images per call (default 1) - */ - maxImagesPerCall?: number; -} diff --git a/packages/fireworks/src/fireworks-provider.test.ts b/packages/fireworks/src/fireworks-provider.test.ts index b76ff1b3fceb..fa85a572f7ae 100644 --- a/packages/fireworks/src/fireworks-provider.test.ts +++ b/packages/fireworks/src/fireworks-provider.test.ts @@ -130,14 +130,12 @@ describe('FireworksProvider', () => { it('should construct an image model with correct configuration', () => { const provider = createFireworks(); const modelId = 'accounts/fireworks/models/flux-1-dev-fp8'; - const settings = { maxImagesPerCall: 2 }; - const model = provider.image(modelId, settings); + const model = provider.image(modelId); expect(model).toBeInstanceOf(FireworksImageModel); expect(FireworksImageModel).toHaveBeenCalledWith( modelId, - settings, expect.objectContaining({ provider: 'fireworks.image', baseURL: 'https://api.fireworks.ai/inference/v1', @@ -154,7 +152,6 @@ describe('FireworksProvider', () => { expect(model).toBeInstanceOf(FireworksImageModel); expect(FireworksImageModel).toHaveBeenCalledWith( modelId, - {}, expect.any(Object), ); }); @@ -164,11 +161,10 @@ describe('FireworksProvider', () => { const provider = createFireworks({ baseURL: customBaseURL }); const modelId = 'accounts/fireworks/models/flux-1-dev-fp8'; - const model = provider.image(modelId); + provider.image(modelId); expect(FireworksImageModel).toHaveBeenCalledWith( modelId, - expect.any(Object), expect.objectContaining({ baseURL: customBaseURL, }), diff --git a/packages/fireworks/src/fireworks-provider.ts b/packages/fireworks/src/fireworks-provider.ts index 054a9f86ee14..5a6b0b1dbd80 100644 --- a/packages/fireworks/src/fireworks-provider.ts +++ b/packages/fireworks/src/fireworks-provider.ts @@ -20,10 +20,7 @@ import { FireworksChatModelId } from './fireworks-chat-options'; import { FireworksCompletionModelId } from './fireworks-completion-options'; import { FireworksEmbeddingModelId } from './fireworks-embedding-options'; import { FireworksImageModel } from './fireworks-image-model'; -import { - FireworksImageModelId, - FireworksImageSettings, -} from './fireworks-image-options'; +import { FireworksImageModelId } from './fireworks-image-options'; export type FireworksErrorData = z.infer; @@ -87,19 +84,14 @@ Creates a text embedding model for text generation. /** Creates a model for image generation. +@deprecated Use `imageModel` instead. */ - image( - modelId: FireworksImageModelId, - settings?: FireworksImageSettings, - ): ImageModelV2; + image(modelId: FireworksImageModelId): ImageModelV2; /** Creates a model for image generation. */ - imageModel( - modelId: FireworksImageModelId, - settings?: FireworksImageSettings, - ): ImageModelV2; + imageModel(modelId: FireworksImageModelId): ImageModelV2; } const defaultBaseURL = 'https://api.fireworks.ai/inference/v1'; @@ -150,11 +142,8 @@ export function createFireworks( errorStructure: fireworksErrorStructure, }); - const createImageModel = ( - modelId: FireworksImageModelId, - settings: FireworksImageSettings = {}, - ) => - new FireworksImageModel(modelId, settings, { + const createImageModel = (modelId: FireworksImageModelId) => + new FireworksImageModel(modelId, { ...getCommonModelConfig('image'), baseURL: baseURL ?? defaultBaseURL, }); diff --git a/packages/google-vertex/src/google-vertex-image-model.test.ts b/packages/google-vertex/src/google-vertex-image-model.test.ts index 79935b79ed6a..32025afe8470 100644 --- a/packages/google-vertex/src/google-vertex-image-model.test.ts +++ b/packages/google-vertex/src/google-vertex-image-model.test.ts @@ -3,15 +3,11 @@ import { GoogleVertexImageModel } from './google-vertex-image-model'; const prompt = 'A cute baby sea otter'; -const model = new GoogleVertexImageModel( - 'imagen-3.0-generate-002', - {}, - { - provider: 'google-vertex', - baseURL: 'https://api.example.com', - headers: { 'api-key': 'test-key' }, - }, -); +const model = new GoogleVertexImageModel('imagen-3.0-generate-002', { + provider: 'google-vertex', + baseURL: 'https://api.example.com', + headers: { 'api-key': 'test-key' }, +}); const server = createTestServer({ 'https://api.example.com/models/imagen-3.0-generate-002:predict': {}, @@ -43,7 +39,6 @@ describe('GoogleVertexImageModel', () => { const modelWithHeaders = new GoogleVertexImageModel( 'imagen-3.0-generate-002', - {}, { provider: 'google-vertex', baseURL: 'https://api.example.com', @@ -72,24 +67,9 @@ describe('GoogleVertexImageModel', () => { }); }); - it('should respect maxImagesPerCall setting', () => { - const customModel = new GoogleVertexImageModel( - 'imagen-3.0-generate-002', - { maxImagesPerCall: 2 }, - { - provider: 'google-vertex', - baseURL: 'https://api.example.com', - headers: { 'api-key': 'test-key' }, - }, - ); - - expect(customModel.maxImagesPerCall).toBe(2); - }); - it('should use default maxImagesPerCall when not specified', () => { const defaultModel = new GoogleVertexImageModel( 'imagen-3.0-generate-002', - {}, { provider: 'google-vertex', baseURL: 'https://api.example.com', @@ -239,7 +219,6 @@ describe('GoogleVertexImageModel', () => { const customModel = new GoogleVertexImageModel( 'imagen-3.0-generate-002', - {}, { provider: 'google-vertex', baseURL: 'https://api.example.com', diff --git a/packages/google-vertex/src/google-vertex-image-model.ts b/packages/google-vertex/src/google-vertex-image-model.ts index f14f1330adbd..b491809e2097 100644 --- a/packages/google-vertex/src/google-vertex-image-model.ts +++ b/packages/google-vertex/src/google-vertex-image-model.ts @@ -9,10 +9,7 @@ import { } from '@ai-sdk/provider-utils'; import { z } from 'zod'; import { googleVertexFailedResponseHandler } from './google-vertex-error'; -import { - GoogleVertexImageModelId, - GoogleVertexImageSettings, -} from './google-vertex-image-settings'; +import { GoogleVertexImageModelId } from './google-vertex-image-settings'; interface GoogleVertexImageModelConfig { provider: string; @@ -27,19 +24,15 @@ interface GoogleVertexImageModelConfig { // https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images export class GoogleVertexImageModel implements ImageModelV2 { readonly specificationVersion = 'v2'; + // https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/imagen-api#parameter_list + readonly maxImagesPerCall = 4; get provider(): string { return this.config.provider; } - get maxImagesPerCall(): number { - // https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/imagen-api#parameter_list - return this.settings.maxImagesPerCall ?? 4; - } - constructor( readonly modelId: GoogleVertexImageModelId, - readonly settings: GoogleVertexImageSettings, private config: GoogleVertexImageModelConfig, ) {} diff --git a/packages/google-vertex/src/google-vertex-image-settings.ts b/packages/google-vertex/src/google-vertex-image-settings.ts index 640c4bd9db20..d157e409a2fe 100644 --- a/packages/google-vertex/src/google-vertex-image-settings.ts +++ b/packages/google-vertex/src/google-vertex-image-settings.ts @@ -3,10 +3,3 @@ export type GoogleVertexImageModelId = | 'imagen-3.0-generate-002' | 'imagen-3.0-fast-generate-001' | (string & {}); - -export interface GoogleVertexImageSettings { - /** -Override the maximum number of images per call (default 4) - */ - maxImagesPerCall?: number; -} diff --git a/packages/google-vertex/src/google-vertex-provider.test.ts b/packages/google-vertex/src/google-vertex-provider.test.ts index 3226da4c379b..834f310cd5a4 100644 --- a/packages/google-vertex/src/google-vertex-provider.test.ts +++ b/packages/google-vertex/src/google-vertex-provider.test.ts @@ -146,7 +146,6 @@ describe('google-vertex-provider', () => { expect(GoogleVertexImageModel).toHaveBeenCalledWith( 'imagen-3.0-generate-002', - {}, expect.objectContaining({ provider: 'google.vertex.image', baseURL: @@ -155,26 +154,4 @@ describe('google-vertex-provider', () => { }), ); }); - - it('should create an image model with custom maxImagesPerCall', () => { - const provider = createVertex({ - project: 'test-project', - location: 'test-location', - }); - const imageSettings = { - maxImagesPerCall: 4, - }; - provider.image('imagen-3.0-generate-002', imageSettings); - - expect(GoogleVertexImageModel).toHaveBeenCalledWith( - 'imagen-3.0-generate-002', - imageSettings, - expect.objectContaining({ - provider: 'google.vertex.image', - headers: expect.any(Object), - baseURL: - 'https://test-location-aiplatform.googleapis.com/v1/projects/test-project/locations/test-location/publishers/google', - }), - ); - }); }); diff --git a/packages/google-vertex/src/google-vertex-provider.ts b/packages/google-vertex/src/google-vertex-provider.ts index 0f0d837423e8..b78c01aaa9fa 100644 --- a/packages/google-vertex/src/google-vertex-provider.ts +++ b/packages/google-vertex/src/google-vertex-provider.ts @@ -11,10 +11,7 @@ import { GoogleVertexConfig } from './google-vertex-config'; import { GoogleVertexEmbeddingModel } from './google-vertex-embedding-model'; import { GoogleVertexEmbeddingModelId } from './google-vertex-embedding-options'; import { GoogleVertexImageModel } from './google-vertex-image-model'; -import { - GoogleVertexImageModelId, - GoogleVertexImageSettings, -} from './google-vertex-image-settings'; +import { GoogleVertexImageModelId } from './google-vertex-image-settings'; import { GoogleVertexModelId } from './google-vertex-options'; export interface GoogleVertexProvider extends ProviderV2 { @@ -28,18 +25,12 @@ Creates a model for text generation. /** * Creates a model for image generation. */ - image( - modelId: GoogleVertexImageModelId, - settings?: GoogleVertexImageSettings, - ): ImageModelV2; + image(modelId: GoogleVertexImageModelId): ImageModelV2; /** Creates a model for image generation. */ - imageModel( - modelId: GoogleVertexImageModelId, - settings?: GoogleVertexImageSettings, - ): ImageModelV2; + imageModel(modelId: GoogleVertexImageModelId): ImageModelV2; } export interface GoogleVertexProviderSettings { @@ -135,10 +126,8 @@ export function createVertex( const createEmbeddingModel = (modelId: GoogleVertexEmbeddingModelId) => new GoogleVertexEmbeddingModel(modelId, createConfig('embedding')); - const createImageModel = ( - modelId: GoogleVertexImageModelId, - settings: GoogleVertexImageSettings = {}, - ) => new GoogleVertexImageModel(modelId, settings, createConfig('image')); + const createImageModel = (modelId: GoogleVertexImageModelId) => + new GoogleVertexImageModel(modelId, createConfig('image')); const provider = function (modelId: GoogleVertexModelId) { if (new.target) { diff --git a/packages/luma/src/luma-image-model.test.ts b/packages/luma/src/luma-image-model.test.ts index e492055b446d..4599816c33bb 100644 --- a/packages/luma/src/luma-image-model.test.ts +++ b/packages/luma/src/luma-image-model.test.ts @@ -10,14 +10,12 @@ function createBasicModel({ headers, fetch, currentDate, - settings, }: { headers?: () => Record; fetch?: FetchFunction; currentDate?: () => Date; - settings?: any; } = {}) { - return new LumaImageModel('test-model', settings ?? {}, { + return new LumaImageModel('test-model', { provider: 'luma', baseURL: 'https://api.example.com', headers: headers ?? (() => ({ 'api-key': 'test-key' })), @@ -148,6 +146,32 @@ describe('LumaImageModel', () => { }); }); + it.only('should not pass providerOptions.{pollIntervalMillis,maxPollAttempts}', async () => { + const model = createBasicModel(); + + await model.doGenerate({ + prompt, + n: 1, + size: undefined, + aspectRatio: '16:9', + seed: undefined, + providerOptions: { + luma: { + pollIntervalMillis: 1000, + maxPollAttempts: 5, + additional_param: 'value', + }, + }, + }); + + expect(await server.calls[0].requestBody).toStrictEqual({ + prompt, + aspect_ratio: '16:9', + model: 'test-model', + additional_param: 'value', + }); + }); + it('should handle API errors', async () => { server.urls[ 'https://api.example.com/dream-machine/v1/generations/image' @@ -273,22 +297,6 @@ describe('LumaImageModel', () => { expect(model.specificationVersion).toBe('v2'); expect(model.maxImagesPerCall).toBe(1); }); - - it('should use maxImagesPerCall from settings', () => { - const model = createBasicModel({ - settings: { - maxImagesPerCall: 4, - }, - }); - - expect(model.maxImagesPerCall).toBe(4); - }); - - it('should default maxImagesPerCall to 1 when not specified', () => { - const model = createBasicModel(); - - expect(model.maxImagesPerCall).toBe(1); - }); }); describe('response schema validation', () => { diff --git a/packages/luma/src/luma-image-model.ts b/packages/luma/src/luma-image-model.ts index 11d8de087927..8046fd5aba95 100644 --- a/packages/luma/src/luma-image-model.ts +++ b/packages/luma/src/luma-image-model.ts @@ -32,28 +32,18 @@ interface LumaImageModelConfig { export class LumaImageModel implements ImageModelV2 { readonly specificationVersion = 'v2'; - - private readonly pollIntervalMillis: number; - private readonly maxPollAttempts: number; + readonly maxImagesPerCall = 1; + readonly pollIntervalMillis = DEFAULT_POLL_INTERVAL_MILLIS; + readonly maxPollAttempts = DEFAULT_MAX_POLL_ATTEMPTS; get provider(): string { return this.config.provider; } - get maxImagesPerCall(): number { - return this.settings.maxImagesPerCall ?? 1; - } - constructor( readonly modelId: string, - private readonly settings: LumaImageSettings, private readonly config: LumaImageModelConfig, - ) { - this.pollIntervalMillis = - settings.pollIntervalMillis ?? DEFAULT_POLL_INTERVAL_MILLIS; - this.maxPollAttempts = - settings.maxPollAttempts ?? DEFAULT_MAX_POLL_ATTEMPTS; - } + ) {} async doGenerate({ prompt, @@ -86,6 +76,10 @@ export class LumaImageModel implements ImageModelV2 { }); } + // remove non-request options from providerOptions + const { pollIntervalMillis, maxPollAttempts, ...providerRequestOptions } = + providerOptions.luma ?? {}; + const currentDate = this.config._internal?.currentDate?.() ?? new Date(); const fullHeaders = combineHeaders(this.config.headers(), headers); const { value: generationResponse, responseHeaders } = await postJsonToApi({ @@ -95,7 +89,7 @@ export class LumaImageModel implements ImageModelV2 { prompt, ...(aspectRatio ? { aspect_ratio: aspectRatio } : {}), model: this.modelId, - ...(providerOptions.luma ?? {}), + ...providerRequestOptions, }, abortSignal, fetch: this.config.fetch, @@ -109,6 +103,7 @@ export class LumaImageModel implements ImageModelV2 { generationResponse.id, fullHeaders, abortSignal, + providerOptions.luma, ); const downloadedImage = await this.downloadImage(imageUrl, abortSignal); @@ -128,10 +123,15 @@ export class LumaImageModel implements ImageModelV2 { generationId: string, headers: Record, abortSignal: AbortSignal | undefined, + imageSettings?: LumaImageSettings, ): Promise { let attemptCount = 0; const url = this.getLumaGenerationsUrl(generationId); - for (let i = 0; i < this.maxPollAttempts; i++) { + const maxPollAttempts = + imageSettings?.maxPollAttempts ?? this.maxPollAttempts; + const pollIntervalMillis = + imageSettings?.pollIntervalMillis ?? this.pollIntervalMillis; + for (let i = 0; i < maxPollAttempts; i++) { const { value: statusResponse } = await getFromApi({ url, headers, @@ -158,7 +158,7 @@ export class LumaImageModel implements ImageModelV2 { message: `Image generation failed.`, }); } - await delay(this.pollIntervalMillis); + await delay(pollIntervalMillis); } throw new Error( diff --git a/packages/luma/src/luma-image-settings.ts b/packages/luma/src/luma-image-settings.ts index 6231504cf942..1cd1579f35d7 100644 --- a/packages/luma/src/luma-image-settings.ts +++ b/packages/luma/src/luma-image-settings.ts @@ -9,11 +9,6 @@ settings allow you to tune the polling behavior when waiting for image generation to complete. */ export interface LumaImageSettings { - /** -Override the maximum number of images per call (default 1) - */ - maxImagesPerCall?: number; - /** Override the polling interval in milliseconds (default 500). This controls how frequently the API is checked for completed images while they are being diff --git a/packages/luma/src/luma-provider.test.ts b/packages/luma/src/luma-provider.test.ts index a804e4c26ec6..661d5f7d4dd5 100644 --- a/packages/luma/src/luma-provider.test.ts +++ b/packages/luma/src/luma-provider.test.ts @@ -21,25 +21,6 @@ describe('createLuma', () => { expect(model).toBeInstanceOf(LumaImageModel); expect(LumaImageModel).toHaveBeenCalledWith( modelId, - {}, - expect.objectContaining({ - provider: 'luma.image', - baseURL: 'https://api.lumalabs.ai', - }), - ); - }); - - it('should construct an image model with custom settings', () => { - const provider = createLuma(); - const modelId = 'luma-v1'; - const settings = { maxImagesPerCall: 2 }; - - const model = provider.image(modelId, settings); - - expect(model).toBeInstanceOf(LumaImageModel); - expect(LumaImageModel).toHaveBeenCalledWith( - modelId, - settings, expect.objectContaining({ provider: 'luma.image', baseURL: 'https://api.lumalabs.ai', @@ -64,7 +45,6 @@ describe('createLuma', () => { expect(LumaImageModel).toHaveBeenCalledWith( modelId, - {}, expect.objectContaining({ baseURL: customBaseURL, headers: expect.any(Function), diff --git a/packages/luma/src/luma-provider.ts b/packages/luma/src/luma-provider.ts index 221b29db2a93..cb489db32a12 100644 --- a/packages/luma/src/luma-provider.ts +++ b/packages/luma/src/luma-provider.ts @@ -5,7 +5,7 @@ import { withoutTrailingSlash, } from '@ai-sdk/provider-utils'; import { LumaImageModel } from './luma-image-model'; -import { LumaImageModelId, LumaImageSettings } from './luma-image-settings'; +import { LumaImageModelId } from './luma-image-settings'; export interface LumaProviderSettings { /** @@ -31,16 +31,14 @@ or to provide a custom fetch implementation for e.g. testing. export interface LumaProvider extends ProviderV2 { /** Creates a model for image generation. +@deprecated Use `imageModel` instead. */ - image(modelId: LumaImageModelId, settings?: LumaImageSettings): ImageModelV2; + image(modelId: LumaImageModelId): ImageModelV2; /** Creates a model for image generation. */ - imageModel( - modelId: LumaImageModelId, - settings?: LumaImageSettings, - ): ImageModelV2; + imageModel(modelId: LumaImageModelId): ImageModelV2; } const defaultBaseURL = 'https://api.lumalabs.ai'; @@ -56,11 +54,8 @@ export function createLuma(options: LumaProviderSettings = {}): LumaProvider { ...options.headers, }); - const createImageModel = ( - modelId: LumaImageModelId, - settings: LumaImageSettings = {}, - ) => - new LumaImageModel(modelId, settings, { + const createImageModel = (modelId: LumaImageModelId) => + new LumaImageModel(modelId, { provider: 'luma.image', baseURL: baseURL ?? defaultBaseURL, headers: getHeaders, diff --git a/packages/openai-compatible/src/index.ts b/packages/openai-compatible/src/index.ts index e7a8bb322b45..825575714231 100644 --- a/packages/openai-compatible/src/index.ts +++ b/packages/openai-compatible/src/index.ts @@ -14,7 +14,6 @@ export type { OpenAICompatibleEmbeddingProviderOptions, } from './openai-compatible-embedding-options'; export { OpenAICompatibleImageModel } from './openai-compatible-image-model'; -export type { OpenAICompatibleImageSettings } from './openai-compatible-image-settings'; export type { OpenAICompatibleErrorData, ProviderErrorStructure, diff --git a/packages/openai-compatible/src/openai-compatible-image-model.test.ts b/packages/openai-compatible/src/openai-compatible-image-model.test.ts index 105ffb970a67..9ae98105d256 100644 --- a/packages/openai-compatible/src/openai-compatible-image-model.test.ts +++ b/packages/openai-compatible/src/openai-compatible-image-model.test.ts @@ -11,16 +11,14 @@ function createBasicModel({ headers, fetch, currentDate, - settings, errorStructure, }: { headers?: () => Record; fetch?: FetchFunction; currentDate?: () => Date; - settings?: any; errorStructure?: ProviderErrorStructure; } = {}) { - return new OpenAICompatibleImageModel('dall-e-3', settings ?? {}, { + return new OpenAICompatibleImageModel('dall-e-3', { provider: 'openai-compatible', headers: headers ?? (() => ({ Authorization: 'Bearer test-key' })), url: ({ modelId, path }) => `https://api.example.com/${modelId}${path}`, @@ -74,22 +72,6 @@ describe('OpenAICompatibleImageModel', () => { expect(model.specificationVersion).toBe('v2'); expect(model.maxImagesPerCall).toBe(10); }); - - it('should use maxImagesPerCall from settings', () => { - const model = createBasicModel({ - settings: { - maxImagesPerCall: 5, - }, - }); - - expect(model.maxImagesPerCall).toBe(5); - }); - - it('should default maxImagesPerCall to 10 when not specified', () => { - const model = createBasicModel(); - - expect(model.maxImagesPerCall).toBe(10); - }); }); describe('doGenerate', () => { @@ -259,16 +241,11 @@ describe('OpenAICompatibleImageModel', () => { it('should use real date when no custom date provider is specified', async () => { const beforeDate = new Date(); - const model = new OpenAICompatibleImageModel( - 'dall-e-3', - {}, - { - provider: 'openai-compatible', - headers: () => ({ Authorization: 'Bearer test-key' }), - url: ({ modelId, path }) => - `https://api.example.com/${modelId}${path}`, - }, - ); + const model = new OpenAICompatibleImageModel('dall-e-3', { + provider: 'openai-compatible', + headers: () => ({ Authorization: 'Bearer test-key' }), + url: ({ modelId, path }) => `https://api.example.com/${modelId}${path}`, + }); const result = await model.doGenerate(createDefaultGenerateParams()); @@ -284,13 +261,17 @@ describe('OpenAICompatibleImageModel', () => { }); it('should pass the user setting in the request', async () => { - const model = createBasicModel({ - settings: { - user: 'test-user-id', - }, - }); + const model = createBasicModel(); - await model.doGenerate(createDefaultGenerateParams()); + await model.doGenerate( + createDefaultGenerateParams({ + providerOptions: { + openai: { + user: 'test-user-id', + }, + }, + }), + ); expect(await server.calls[0].requestBody).toStrictEqual({ model: 'dall-e-3', @@ -302,12 +283,16 @@ describe('OpenAICompatibleImageModel', () => { }); }); - it('should not include user field in request when setting is not provided', async () => { - const model = createBasicModel({ - settings: {}, // explicitly empty settings - }); + it('should not include user field in request when not set via provider options', async () => { + const model = createBasicModel(); - await model.doGenerate(createDefaultGenerateParams()); + await model.doGenerate( + createDefaultGenerateParams({ + providerOptions: { + openai: {}, + }, + }), + ); const requestBody = await server.calls[0].requestBody; expect(requestBody).toStrictEqual({ diff --git a/packages/openai-compatible/src/openai-compatible-image-model.ts b/packages/openai-compatible/src/openai-compatible-image-model.ts index c2d16064b585..ddbf0c9c167f 100644 --- a/packages/openai-compatible/src/openai-compatible-image-model.ts +++ b/packages/openai-compatible/src/openai-compatible-image-model.ts @@ -11,10 +11,7 @@ import { defaultOpenAICompatibleErrorStructure, ProviderErrorStructure, } from './openai-compatible-error'; -import { - OpenAICompatibleImageModelId, - OpenAICompatibleImageSettings, -} from './openai-compatible-image-settings'; +import { OpenAICompatibleImageModelId } from './openai-compatible-image-settings'; export type OpenAICompatibleImageModelConfig = { provider: string; @@ -29,10 +26,7 @@ export type OpenAICompatibleImageModelConfig = { export class OpenAICompatibleImageModel implements ImageModelV2 { readonly specificationVersion = 'v2'; - - get maxImagesPerCall(): number { - return this.settings.maxImagesPerCall ?? 10; - } + readonly maxImagesPerCall = 10; get provider(): string { return this.config.provider; @@ -40,7 +34,6 @@ export class OpenAICompatibleImageModel implements ImageModelV2 { constructor( readonly modelId: OpenAICompatibleImageModelId, - private readonly settings: OpenAICompatibleImageSettings, private readonly config: OpenAICompatibleImageModelConfig, ) {} @@ -85,7 +78,6 @@ export class OpenAICompatibleImageModel implements ImageModelV2 { size, ...(providerOptions.openai ?? {}), response_format: 'b64_json', - ...(this.settings.user ? { user: this.settings.user } : {}), }, failedResponseHandler: createJsonErrorResponseHandler( this.config.errorStructure ?? defaultOpenAICompatibleErrorStructure, diff --git a/packages/openai-compatible/src/openai-compatible-image-settings.ts b/packages/openai-compatible/src/openai-compatible-image-settings.ts index cfc00c7ec1d8..463fd565307a 100644 --- a/packages/openai-compatible/src/openai-compatible-image-settings.ts +++ b/packages/openai-compatible/src/openai-compatible-image-settings.ts @@ -1,14 +1 @@ export type OpenAICompatibleImageModelId = string; - -export interface OpenAICompatibleImageSettings { - /** -A unique identifier representing your end-user, which can help the provider to -monitor and detect abuse. - */ - user?: string; - - /** - * The maximum number of images to generate. - */ - maxImagesPerCall?: number; -} diff --git a/packages/openai-compatible/src/openai-compatible-provider.ts b/packages/openai-compatible/src/openai-compatible-provider.ts index b50876b396c3..23b7331b3f53 100644 --- a/packages/openai-compatible/src/openai-compatible-provider.ts +++ b/packages/openai-compatible/src/openai-compatible-provider.ts @@ -8,7 +8,6 @@ import { FetchFunction, withoutTrailingSlash } from '@ai-sdk/provider-utils'; import { OpenAICompatibleChatLanguageModel } from './openai-compatible-chat-language-model'; import { OpenAICompatibleCompletionLanguageModel } from './openai-compatible-completion-language-model'; import { OpenAICompatibleEmbeddingModel } from './openai-compatible-embedding-model'; -import { OpenAICompatibleImageSettings } from './openai-compatible-image-settings'; import { OpenAICompatibleImageModel } from './openai-compatible-image-model'; export interface OpenAICompatibleProvider< @@ -130,15 +129,8 @@ export function createOpenAICompatible< ...getCommonModelConfig('embedding'), }); - const createImageModel = ( - modelId: IMAGE_MODEL_IDS, - settings: OpenAICompatibleImageSettings = {}, - ) => - new OpenAICompatibleImageModel( - modelId, - settings, - getCommonModelConfig('image'), - ); + const createImageModel = (modelId: IMAGE_MODEL_IDS) => + new OpenAICompatibleImageModel(modelId, getCommonModelConfig('image')); const provider = (modelId: CHAT_MODEL_IDS) => createLanguageModel(modelId); diff --git a/packages/openai/src/openai-image-model.test.ts b/packages/openai/src/openai-image-model.test.ts index 38cb15c51ee8..930a449bfbb6 100644 --- a/packages/openai/src/openai-image-model.test.ts +++ b/packages/openai/src/openai-image-model.test.ts @@ -5,7 +5,7 @@ import { createOpenAI } from './openai-provider'; const prompt = 'A cute baby sea otter'; const provider = createOpenAI({ apiKey: 'test-api-key' }); -const model = provider.image('dall-e-3', { maxImagesPerCall: 2 }); +const model = provider.image('dall-e-3'); const server = createTestServer({ 'https://api.openai.com/v1/images/generations': {}, @@ -134,11 +134,6 @@ describe('doGenerate', () => { }); it('should respect maxImagesPerCall setting', async () => { - prepareJsonResponse(); - - const customModel = provider.image('dall-e-2', { maxImagesPerCall: 5 }); - expect(customModel.maxImagesPerCall).toBe(5); - const defaultModel = provider.image('dall-e-2'); expect(defaultModel.maxImagesPerCall).toBe(10); // dall-e-2's default from settings @@ -156,18 +151,14 @@ describe('doGenerate', () => { const testDate = new Date('2024-03-15T12:00:00Z'); - const customModel = new OpenAIImageModel( - 'dall-e-3', - {}, - { - provider: 'test-provider', - url: () => 'https://api.openai.com/v1/images/generations', - headers: () => ({}), - _internal: { - currentDate: () => testDate, - }, + const customModel = new OpenAIImageModel('dall-e-3', { + provider: 'test-provider', + url: () => 'https://api.openai.com/v1/images/generations', + headers: () => ({}), + _internal: { + currentDate: () => testDate, }, - ); + }); const result = await customModel.doGenerate({ prompt, diff --git a/packages/openai/src/openai-image-model.ts b/packages/openai/src/openai-image-model.ts index 1e85e94f69ea..2f38fd5616a1 100644 --- a/packages/openai/src/openai-image-model.ts +++ b/packages/openai/src/openai-image-model.ts @@ -9,7 +9,6 @@ import { OpenAIConfig } from './openai-config'; import { openaiFailedResponseHandler } from './openai-error'; import { OpenAIImageModelId, - OpenAIImageSettings, modelMaxImagesPerCall, hasDefaultResponseFormat, } from './openai-image-settings'; @@ -24,9 +23,7 @@ export class OpenAIImageModel implements ImageModelV2 { readonly specificationVersion = 'v2'; get maxImagesPerCall(): number { - return ( - this.settings.maxImagesPerCall ?? modelMaxImagesPerCall[this.modelId] ?? 1 - ); + return modelMaxImagesPerCall[this.modelId] ?? 1; } get provider(): string { @@ -35,7 +32,6 @@ export class OpenAIImageModel implements ImageModelV2 { constructor( readonly modelId: OpenAIImageModelId, - private readonly settings: OpenAIImageSettings, private readonly config: OpenAIImageModelConfig, ) {} diff --git a/packages/openai/src/openai-image-settings.ts b/packages/openai/src/openai-image-settings.ts index 19fa72632783..8b0db08ecb29 100644 --- a/packages/openai/src/openai-image-settings.ts +++ b/packages/openai/src/openai-image-settings.ts @@ -12,11 +12,3 @@ export const modelMaxImagesPerCall: Record = { }; export const hasDefaultResponseFormat = new Set(['gpt-image-1']); - -export interface OpenAIImageSettings { - /** -Override the maximum number of images per call (default is dependent on the -model, or 1 for an unknown model). - */ - maxImagesPerCall?: number; -} diff --git a/packages/openai/src/openai-provider.ts b/packages/openai/src/openai-provider.ts index 43037f889a22..b38ed4e1ae6b 100644 --- a/packages/openai/src/openai-provider.ts +++ b/packages/openai/src/openai-provider.ts @@ -18,10 +18,7 @@ import { OpenAICompletionModelId } from './openai-completion-options'; import { OpenAIEmbeddingModel } from './openai-embedding-model'; import { OpenAIEmbeddingModelId } from './openai-embedding-options'; import { OpenAIImageModel } from './openai-image-model'; -import { - OpenAIImageModelId, - OpenAIImageSettings, -} from './openai-image-settings'; +import { OpenAIImageModelId } from './openai-image-settings'; import { openaiTools } from './openai-tools'; import { OpenAITranscriptionModel } from './openai-transcription-model'; import { OpenAITranscriptionModelId } from './openai-transcription-options'; @@ -76,19 +73,14 @@ Creates a model for text embeddings. /** Creates a model for image generation. +@deprecated Use `imageModel` instead. */ - image( - modelId: OpenAIImageModelId, - settings?: OpenAIImageSettings, - ): ImageModelV2; + image(modelId: OpenAIImageModelId): ImageModelV2; /** Creates a model for image generation. */ - imageModel( - modelId: OpenAIImageModelId, - settings?: OpenAIImageSettings, - ): ImageModelV2; + imageModel(modelId: OpenAIImageModelId): ImageModelV2; /** Creates a model for transcription. @@ -202,11 +194,8 @@ export function createOpenAI( fetch: options.fetch, }); - const createImageModel = ( - modelId: OpenAIImageModelId, - settings: OpenAIImageSettings = {}, - ) => - new OpenAIImageModel(modelId, settings, { + const createImageModel = (modelId: OpenAIImageModelId) => + new OpenAIImageModel(modelId, { provider: `${providerName}.image`, url: ({ path }) => `${baseURL}${path}`, headers: getHeaders, diff --git a/packages/replicate/src/replicate-image-model.test.ts b/packages/replicate/src/replicate-image-model.test.ts index ddc41bbba506..155fb3645d0a 100644 --- a/packages/replicate/src/replicate-image-model.test.ts +++ b/packages/replicate/src/replicate-image-model.test.ts @@ -180,7 +180,6 @@ describe('doGenerate', () => { it('should return response metadata', async () => { const modelWithTimestamp = new ReplicateImageModel( 'black-forest-labs/flux-schnell', - {}, { provider: 'replicate', baseURL: 'https://api.replicate.com', @@ -210,7 +209,6 @@ describe('doGenerate', () => { it('should include response headers in metadata', async () => { const modelWithTimestamp = new ReplicateImageModel( 'black-forest-labs/flux-schnell', - {}, { provider: 'replicate', baseURL: 'https://api.replicate.com', diff --git a/packages/replicate/src/replicate-image-model.ts b/packages/replicate/src/replicate-image-model.ts index 0c2cc3eaea45..26919565fcf0 100644 --- a/packages/replicate/src/replicate-image-model.ts +++ b/packages/replicate/src/replicate-image-model.ts @@ -11,10 +11,7 @@ import { } from '@ai-sdk/provider-utils'; import { z } from 'zod'; import { replicateFailedResponseHandler } from './replicate-error'; -import { - ReplicateImageModelId, - ReplicateImageSettings, -} from './replicate-image-settings'; +import { ReplicateImageModelId } from './replicate-image-settings'; interface ReplicateImageModelConfig { provider: string; @@ -28,18 +25,14 @@ interface ReplicateImageModelConfig { export class ReplicateImageModel implements ImageModelV2 { readonly specificationVersion = 'v2'; + readonly maxImagesPerCall = 1; get provider(): string { return this.config.provider; } - get maxImagesPerCall(): number { - return this.settings.maxImagesPerCall ?? 1; - } - constructor( readonly modelId: ReplicateImageModelId, - private readonly settings: ReplicateImageSettings, private readonly config: ReplicateImageModelConfig, ) {} @@ -60,6 +53,7 @@ export class ReplicateImageModel implements ImageModelV2 { const [modelId, version] = this.modelId.split(':'); const currentDate = this.config._internal?.currentDate?.() ?? new Date(); + const { value: { output }, responseHeaders, diff --git a/packages/replicate/src/replicate-image-settings.ts b/packages/replicate/src/replicate-image-settings.ts index 259880f59d7e..eb56d5e538cc 100644 --- a/packages/replicate/src/replicate-image-settings.ts +++ b/packages/replicate/src/replicate-image-settings.ts @@ -27,10 +27,3 @@ export type ReplicateImageModelId = | 'stability-ai/stable-diffusion-3.5-medium' | 'tstramer/material-diffusion' | (string & {}); - -export interface ReplicateImageSettings { - /** -Override the maximum number of images per call (default 1) - */ - maxImagesPerCall?: number; -} diff --git a/packages/replicate/src/replicate-provider.ts b/packages/replicate/src/replicate-provider.ts index 3271dd8a4161..7c8c790fa87d 100644 --- a/packages/replicate/src/replicate-provider.ts +++ b/packages/replicate/src/replicate-provider.ts @@ -2,10 +2,7 @@ import { NoSuchModelError, ProviderV2 } from '@ai-sdk/provider'; import type { FetchFunction } from '@ai-sdk/provider-utils'; import { loadApiKey } from '@ai-sdk/provider-utils'; import { ReplicateImageModel } from './replicate-image-model'; -import { - ReplicateImageModelId, - ReplicateImageSettings, -} from './replicate-image-settings'; +import { ReplicateImageModelId } from './replicate-image-settings'; export interface ReplicateProviderSettings { /** @@ -35,19 +32,14 @@ or to provide a custom fetch implementation for e.g. testing. export interface ReplicateProvider extends ProviderV2 { /** * Creates a Replicate image generation model. + * @deprecated Use `imageModel` instead. */ - image( - modelId: ReplicateImageModelId, - settings?: ReplicateImageSettings, - ): ReplicateImageModel; + image(modelId: ReplicateImageModelId): ReplicateImageModel; /** * Creates a Replicate image generation model. */ - imageModel( - modelId: ReplicateImageModelId, - settings?: ReplicateImageSettings, - ): ReplicateImageModel; + imageModel(modelId: ReplicateImageModelId): ReplicateImageModel; } /** @@ -56,11 +48,8 @@ export interface ReplicateProvider extends ProviderV2 { export function createReplicate( options: ReplicateProviderSettings = {}, ): ReplicateProvider { - const createImageModel = ( - modelId: ReplicateImageModelId, - settings?: ReplicateImageSettings, - ) => - new ReplicateImageModel(modelId, settings ?? {}, { + const createImageModel = (modelId: ReplicateImageModelId) => + new ReplicateImageModel(modelId, { provider: 'replicate', baseURL: options.baseURL ?? 'https://api.replicate.com/v1', headers: { diff --git a/packages/togetherai/src/togetherai-image-model.test.ts b/packages/togetherai/src/togetherai-image-model.test.ts index 738da951302f..a50fb87d90bb 100644 --- a/packages/togetherai/src/togetherai-image-model.test.ts +++ b/packages/togetherai/src/togetherai-image-model.test.ts @@ -2,7 +2,6 @@ import { FetchFunction } from '@ai-sdk/provider-utils'; import { createTestServer } from '@ai-sdk/provider-utils/test'; import { describe, expect, it } from 'vitest'; import { TogetherAIImageModel } from './togetherai-image-model'; -import { TogetherAIImageSettings } from './togetherai-image-settings'; const prompt = 'A cute baby sea otter'; @@ -10,26 +9,20 @@ function createBasicModel({ headers, fetch, currentDate, - settings, }: { headers?: () => Record; fetch?: FetchFunction; currentDate?: () => Date; - settings?: TogetherAIImageSettings; } = {}) { - return new TogetherAIImageModel( - 'stabilityai/stable-diffusion-xl', - settings ?? {}, - { - provider: 'togetherai', - baseURL: 'https://api.example.com', - headers: headers ?? (() => ({ 'api-key': 'test-key' })), - fetch, - _internal: { - currentDate, - }, + return new TogetherAIImageModel('stabilityai/stable-diffusion-xl', { + provider: 'togetherai', + baseURL: 'https://api.example.com', + headers: headers ?? (() => ({ 'api-key': 'test-key' })), + fetch, + _internal: { + currentDate, }, - ); + }); } const server = createTestServer({ @@ -248,20 +241,4 @@ describe('constructor', () => { expect(model.specificationVersion).toBe('v2'); expect(model.maxImagesPerCall).toBe(1); }); - - it('should use maxImagesPerCall from settings', () => { - const model = createBasicModel({ - settings: { - maxImagesPerCall: 4, - }, - }); - - expect(model.maxImagesPerCall).toBe(4); - }); - - it('should default maxImagesPerCall to 1 when not specified', () => { - const model = createBasicModel(); - - expect(model.maxImagesPerCall).toBe(1); - }); }); diff --git a/packages/togetherai/src/togetherai-image-model.ts b/packages/togetherai/src/togetherai-image-model.ts index f29768c31ccc..90d78b391657 100644 --- a/packages/togetherai/src/togetherai-image-model.ts +++ b/packages/togetherai/src/togetherai-image-model.ts @@ -6,10 +6,7 @@ import { FetchFunction, postJsonToApi, } from '@ai-sdk/provider-utils'; -import { - TogetherAIImageModelId, - TogetherAIImageSettings, -} from './togetherai-image-settings'; +import { TogetherAIImageModelId } from './togetherai-image-settings'; import { z } from 'zod'; interface TogetherAIImageModelConfig { @@ -24,18 +21,14 @@ interface TogetherAIImageModelConfig { export class TogetherAIImageModel implements ImageModelV2 { readonly specificationVersion = 'v2'; + readonly maxImagesPerCall = 1; get provider(): string { return this.config.provider; } - get maxImagesPerCall(): number { - return this.settings.maxImagesPerCall ?? 1; - } - constructor( readonly modelId: TogetherAIImageModelId, - readonly settings: TogetherAIImageSettings, private config: TogetherAIImageModelConfig, ) {} diff --git a/packages/togetherai/src/togetherai-image-settings.ts b/packages/togetherai/src/togetherai-image-settings.ts index 53e704316b6f..7234f43530b0 100644 --- a/packages/togetherai/src/togetherai-image-settings.ts +++ b/packages/togetherai/src/togetherai-image-settings.ts @@ -11,10 +11,3 @@ export type TogetherAIImageModelId = | 'black-forest-labs/FLUX.1-pro' | 'black-forest-labs/FLUX.1-schnell-Free' | (string & {}); - -export interface TogetherAIImageSettings { - /** -Override the maximum number of images per call (default 1) - */ - maxImagesPerCall?: number; -} diff --git a/packages/togetherai/src/togetherai-provider.test.ts b/packages/togetherai/src/togetherai-provider.test.ts index 7fbbfa117f61..0a2a3af2db6a 100644 --- a/packages/togetherai/src/togetherai-provider.test.ts +++ b/packages/togetherai/src/togetherai-provider.test.ts @@ -131,13 +131,11 @@ describe('TogetherAIProvider', () => { it('should construct an image model with correct configuration', () => { const provider = createTogetherAI(); const modelId = 'stabilityai/stable-diffusion-xl'; - const settings = { maxImagesPerCall: 4 }; - const model = provider.image(modelId, settings); + const model = provider.image(modelId); expect(TogetherAIImageModel).toHaveBeenCalledWith( modelId, - settings, expect.objectContaining({ provider: 'togetherai.image', baseURL: 'https://api.together.xyz/v1/', @@ -156,7 +154,6 @@ describe('TogetherAIProvider', () => { expect(TogetherAIImageModel).toHaveBeenCalledWith( modelId, - expect.any(Object), expect.objectContaining({ baseURL: 'https://custom.url/', }), diff --git a/packages/togetherai/src/togetherai-provider.ts b/packages/togetherai/src/togetherai-provider.ts index 0e7b0c334639..faa831803c4a 100644 --- a/packages/togetherai/src/togetherai-provider.ts +++ b/packages/togetherai/src/togetherai-provider.ts @@ -18,10 +18,7 @@ import { TogetherAIChatModelId } from './togetherai-chat-options'; import { TogetherAIEmbeddingModelId } from './togetherai-embedding-options'; import { TogetherAICompletionModelId } from './togetherai-completion-options'; import { TogetherAIImageModel } from './togetherai-image-model'; -import { - TogetherAIImageModelId, - TogetherAIImageSettings, -} from './togetherai-image-settings'; +import { TogetherAIImageModelId } from './togetherai-image-settings'; export interface TogetherAIProviderSettings { /** @@ -72,16 +69,14 @@ Creates a text embedding model for text generation. ): EmbeddingModelV2; /** - Creates a model for image generation. - */ - image( - modelId: TogetherAIImageModelId, - settings?: TogetherAIImageSettings, - ): ImageModelV2; +Creates a model for image generation. +@deprecated Use `imageModel` instead. +*/ + image(modelId: TogetherAIImageModelId): ImageModelV2; /** - Creates a model for image generation. - */ +Creates a model for image generation. +*/ imageModel(modelId: TogetherAIImageModelId): ImageModelV2; } @@ -133,11 +128,8 @@ export function createTogetherAI( getCommonModelConfig('embedding'), ); - const createImageModel = ( - modelId: TogetherAIImageModelId, - settings: TogetherAIImageSettings = {}, - ) => - new TogetherAIImageModel(modelId, settings, { + const createImageModel = (modelId: TogetherAIImageModelId) => + new TogetherAIImageModel(modelId, { ...getCommonModelConfig('image'), baseURL: baseURL ?? 'https://api.together.xyz/v1/', }); diff --git a/packages/xai/src/xai-image-settings.ts b/packages/xai/src/xai-image-settings.ts index 767bd2e187fa..375ecb31f691 100644 --- a/packages/xai/src/xai-image-settings.ts +++ b/packages/xai/src/xai-image-settings.ts @@ -1,8 +1 @@ export type XaiImageModelId = 'grok-2-image' | (string & {}); - -export interface XaiImageSettings { - /** -Override the maximum number of images per call. Default is 10. - */ - maxImagesPerCall?: number; -} diff --git a/packages/xai/src/xai-provider.test.ts b/packages/xai/src/xai-provider.test.ts index 9ddf4ffafd1d..f508ebd94845 100644 --- a/packages/xai/src/xai-provider.test.ts +++ b/packages/xai/src/xai-provider.test.ts @@ -107,17 +107,15 @@ describe('xAIProvider', () => { it('should construct an image model with correct configuration', () => { const provider = createXai(); const modelId = 'grok-2-image'; - const settings = { maxImagesPerCall: 3 }; - const model = provider.imageModel(modelId, settings); + const model = provider.imageModel(modelId); expect(model).toBeInstanceOf(OpenAICompatibleImageModel); const constructorCall = OpenAICompatibleImageModelMock.mock.calls[0]; expect(constructorCall[0]).toBe(modelId); - expect(constructorCall[1]).toEqual(settings); - const config = constructorCall[2]; + const config = constructorCall[1]; expect(config.provider).toBe('xai.image'); expect(config.url({ path: '/test-path' })).toBe( 'https://api.x.ai/v1/test-path', @@ -132,7 +130,7 @@ describe('xAIProvider', () => { provider.imageModel(modelId); const constructorCall = OpenAICompatibleImageModelMock.mock.calls[0]; - const config = constructorCall[2]; + const config = constructorCall[1]; expect(config.url({ path: '/test-path' })).toBe( `${customBaseURL}/test-path`, ); @@ -145,7 +143,7 @@ describe('xAIProvider', () => { provider.imageModel('grok-2-image'); const constructorCall = OpenAICompatibleImageModelMock.mock.calls[0]; - const config = constructorCall[2]; + const config = constructorCall[1]; const headers = config.headers(); expect(headers).toMatchObject({ diff --git a/packages/xai/src/xai-provider.ts b/packages/xai/src/xai-provider.ts index e20096135d69..1edd6aba0520 100644 --- a/packages/xai/src/xai-provider.ts +++ b/packages/xai/src/xai-provider.ts @@ -16,7 +16,7 @@ import { } from '@ai-sdk/provider-utils'; import { XaiChatModelId, supportsStructuredOutputs } from './xai-chat-options'; import { XaiErrorData, xaiErrorSchema } from './xai-error'; -import { XaiImageModelId, XaiImageSettings } from './xai-image-settings'; +import { XaiImageModelId } from './xai-image-settings'; const xaiErrorStructure: ProviderErrorStructure = { errorSchema: xaiErrorSchema, @@ -41,16 +41,14 @@ Creates an Xai chat model for text generation. /** Creates an Xai image model for image generation. +@deprecated Use `imageModel` instead. */ - image(modelId: XaiImageModelId, settings?: XaiImageSettings): ImageModelV2; + image(modelId: XaiImageModelId): ImageModelV2; /** Creates an Xai image model for image generation. */ - imageModel( - modelId: XaiImageModelId, - settings?: XaiImageSettings, - ): ImageModelV2; + imageModel(modelId: XaiImageModelId): ImageModelV2; } export interface XaiProviderSettings { @@ -102,11 +100,8 @@ export function createXai(options: XaiProviderSettings = {}): XaiProvider { }); }; - const createImageModel = ( - modelId: XaiImageModelId, - settings: XaiImageSettings = {}, - ) => { - return new OpenAICompatibleImageModel(modelId, settings, { + const createImageModel = (modelId: XaiImageModelId) => { + return new OpenAICompatibleImageModel(modelId, { provider: 'xai.image', url: ({ path }) => `${baseURL}${path}`, headers: getHeaders, From 1f569a8b75c1faaa56925ebf690899f7810b4f9c Mon Sep 17 00:00:00 2001 From: Gregor Martynus <39992+gr2m@users.noreply.github.com> Date: Wed, 30 Apr 2025 20:15:31 -0700 Subject: [PATCH 2/5] avoid type errors --- packages/ai/core/generate-image/generate-image.ts | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/packages/ai/core/generate-image/generate-image.ts b/packages/ai/core/generate-image/generate-image.ts index 54fd9a29f125..25c523529d38 100644 --- a/packages/ai/core/generate-image/generate-image.ts +++ b/packages/ai/core/generate-image/generate-image.ts @@ -112,7 +112,7 @@ Only applicable for HTTP-based providers. // extract maxImagesPerCall from providerOptions as it's not meant to be passed to `doGenerate()` const [ - [providerName, { maxImagesPerCall, ...generateProviderOptions } = {}], + [providerName, { maxImagesPerCall, ...generateProviderOptions } = {}] = [], ] = Object.entries(providerOptions ?? {}); // default to 1 if the model has not specified limits on @@ -142,7 +142,9 @@ Only applicable for HTTP-based providers. size, aspectRatio, seed, - providerOptions: { [providerName]: generateProviderOptions }, + providerOptions: providerName + ? { [providerName]: generateProviderOptions } + : {}, }), ), ), From 8f08e9a0fcca99f153b531d18de65a222b0f1a19 Mon Sep 17 00:00:00 2001 From: Gregor Martynus <39992+gr2m@users.noreply.github.com> Date: Wed, 30 Apr 2025 20:20:12 -0700 Subject: [PATCH 3/5] remove `.only` from test debugging --- packages/luma/src/luma-image-model.test.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/luma/src/luma-image-model.test.ts b/packages/luma/src/luma-image-model.test.ts index 4599816c33bb..bcdbc838c83d 100644 --- a/packages/luma/src/luma-image-model.test.ts +++ b/packages/luma/src/luma-image-model.test.ts @@ -146,7 +146,7 @@ describe('LumaImageModel', () => { }); }); - it.only('should not pass providerOptions.{pollIntervalMillis,maxPollAttempts}', async () => { + it('should not pass providerOptions.{pollIntervalMillis,maxPollAttempts}', async () => { const model = createBasicModel(); await model.doGenerate({ From ac970723828684731fe5d6eadbca55b90f092bae Mon Sep 17 00:00:00 2001 From: Gregor Martynus <39992+gr2m@users.noreply.github.com> Date: Wed, 30 Apr 2025 21:01:37 -0700 Subject: [PATCH 4/5] update snapshots --- packages/openai/src/openai-completion-language-model.test.ts | 3 +++ 1 file changed, 3 insertions(+) diff --git a/packages/openai/src/openai-completion-language-model.test.ts b/packages/openai/src/openai-completion-language-model.test.ts index 20a2ccb37367..234e16b88d51 100644 --- a/packages/openai/src/openai-completion-language-model.test.ts +++ b/packages/openai/src/openai-completion-language-model.test.ts @@ -530,6 +530,9 @@ describe('doStream', () => { }, { "finishReason": "error", + "providerMetadata": { + "openai": {}, + }, "type": "finish", "usage": { "inputTokens": undefined, From 3d7bc413b91f53bf430c0b10e4d9b583648fab63 Mon Sep 17 00:00:00 2001 From: Gregor Martynus <39992+gr2m@users.noreply.github.com> Date: Sun, 4 May 2025 09:41:31 -0700 Subject: [PATCH 5/5] test (luma): adapt to changes in `v5` --- packages/luma/src/luma-image-model.test.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/luma/src/luma-image-model.test.ts b/packages/luma/src/luma-image-model.test.ts index 35689fb7a281..3e95aa0eb6df 100644 --- a/packages/luma/src/luma-image-model.test.ts +++ b/packages/luma/src/luma-image-model.test.ts @@ -164,7 +164,7 @@ describe('LumaImageModel', () => { }, }); - expect(await server.calls[0].requestBody).toStrictEqual({ + expect(await server.calls[0].requestBodyJson).toStrictEqual({ prompt, aspect_ratio: '16:9', model: 'test-model',