Skip to content

fix (ai)!: move image model settings into generateImage({providerOptions}) #6083

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
2 changes: 2 additions & 0 deletions packages/ai/core/generate-image/generate-image.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ describe('generateImage', () => {
providerOptions: {
'mock-provider': {
style: 'vivid',
// maxImagesPerCall is not passed to doGenerate
maxImagesPerCall: 3,
},
},
headers: { 'custom-request-header': 'request-header-value' },
Expand Down
21 changes: 15 additions & 6 deletions packages/ai/core/generate-image/generate-image.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,20 +110,27 @@ Only applicable for HTTP-based providers.
}): Promise<GenerateImageResult> {
const { retry } = prepareRetries({ maxRetries: maxRetriesArg });

// extract maxImagesPerCall from providerOptions as it's not meant to be passed to `doGenerate()`
const [
[providerName, { maxImagesPerCall, ...generateProviderOptions } = {}] = [],
] = Object.entries(providerOptions ?? {});
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm surprised that we don't know providerName at the time when generateImage() is called. Can't we somehow reliable extract it from model?

I also don't know if this change breaks a use case that I'm not aware of?


// default to 1 if the model has not specified limits on
// how many images can be generated in a single call
const maxImagesPerCall = model.maxImagesPerCall ?? 1;
const maxImagesPerCallWithDefault =
(maxImagesPerCall as number) ?? model.maxImagesPerCall ?? 1;
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably the main change. Before we had access to model.settings in the maxImagesPerCall getter. Now that settings has been removed from the ImageModelV2 constructor, that is no longer the case.


// parallelize calls to the model:
const callCount = Math.ceil(n / maxImagesPerCall);
const callCount = Math.ceil(n / maxImagesPerCallWithDefault);
const callImageCounts = Array.from({ length: callCount }, (_, i) => {
if (i < callCount - 1) {
return maxImagesPerCall;
return maxImagesPerCallWithDefault;
}

const remainder = n % maxImagesPerCall;
return remainder === 0 ? maxImagesPerCall : remainder;
const remainder = n % maxImagesPerCallWithDefault;
return remainder === 0 ? maxImagesPerCallWithDefault : remainder;
});

const results = await Promise.all(
callImageCounts.map(async callImageCount =>
retry(() =>
Expand All @@ -135,7 +142,9 @@ Only applicable for HTTP-based providers.
size,
aspectRatio,
seed,
providerOptions: providerOptions ?? {},
providerOptions: providerName
? { [providerName]: generateProviderOptions }
: {},
Comment on lines +145 to +147
Copy link
Collaborator Author

@gr2m gr2m May 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a use case where providerOptions has more than one key? If so this change would break it. My current change doesn't pass through provider options because we want to remove the maxImagesPerCall option first, otherwise it ends up in the API request body.

}),
),
),
Expand Down
59 changes: 21 additions & 38 deletions packages/amazon-bedrock/src/bedrock-image-model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,11 @@ describe('doGenerate', () => {
},
});

const model = new BedrockImageModel(
'amazon.nova-canvas-v1:0',
{},
{
baseUrl: () => 'https://bedrock-runtime.us-east-1.amazonaws.com',
headers: mockConfigHeaders,
fetch: fakeFetchWithAuth,
},
);
const model = new BedrockImageModel('amazon.nova-canvas-v1:0', {
baseUrl: () => 'https://bedrock-runtime.us-east-1.amazonaws.com',
headers: mockConfigHeaders,
fetch: fakeFetchWithAuth,
});

it('should pass the model and the settings', async () => {
await model.doGenerate({
Expand Down Expand Up @@ -83,21 +79,17 @@ describe('doGenerate', () => {
'shared-header': 'options-shared',
};

const modelWithHeaders = new BedrockImageModel(
'amazon.nova-canvas-v1:0',
{},
{
baseUrl: () => 'https://bedrock-runtime.us-east-1.amazonaws.com',
headers: {
'model-header': 'model-value',
'shared-header': 'model-shared',
},
fetch: injectFetchHeaders({
'signed-header': 'signed-value',
authorization: 'AWS4-HMAC-SHA256...',
}),
const modelWithHeaders = new BedrockImageModel('amazon.nova-canvas-v1:0', {
baseUrl: () => 'https://bedrock-runtime.us-east-1.amazonaws.com',
headers: {
'model-header': 'model-value',
'shared-header': 'model-shared',
},
);
fetch: injectFetchHeaders({
'signed-header': 'signed-value',
authorization: 'AWS4-HMAC-SHA256...',
}),
});

await modelWithHeaders.doGenerate({
prompt,
Expand All @@ -118,11 +110,6 @@ describe('doGenerate', () => {
});

it('should respect maxImagesPerCall setting', async () => {
const customModel = provider.image('amazon.nova-canvas-v1:0', {
maxImagesPerCall: 2,
});
expect(customModel.maxImagesPerCall).toBe(2);

const defaultModel = provider.image('amazon.nova-canvas-v1:0');
expect(defaultModel.maxImagesPerCall).toBe(5); // 'amazon.nova-canvas-v1:0','s default from settings

Expand Down Expand Up @@ -166,17 +153,13 @@ describe('doGenerate', () => {
it('should include response data with timestamp, modelId and headers', async () => {
const testDate = new Date('2024-03-15T12:00:00Z');

const customModel = new BedrockImageModel(
'amazon.nova-canvas-v1:0',
{},
{
baseUrl: () => 'https://bedrock-runtime.us-east-1.amazonaws.com',
headers: () => ({}),
_internal: {
currentDate: () => testDate,
},
const customModel = new BedrockImageModel('amazon.nova-canvas-v1:0', {
baseUrl: () => 'https://bedrock-runtime.us-east-1.amazonaws.com',
headers: () => ({}),
_internal: {
currentDate: () => testDate,
},
);
});

const result = await customModel.doGenerate({
prompt,
Expand Down
6 changes: 1 addition & 5 deletions packages/amazon-bedrock/src/bedrock-image-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import {
} from '@ai-sdk/provider-utils';
import {
BedrockImageModelId,
BedrockImageSettings,
modelMaxImagesPerCall,
} from './bedrock-image-settings';
import { BedrockErrorSchema } from './bedrock-error';
Expand All @@ -30,9 +29,7 @@ export class BedrockImageModel implements ImageModelV2 {
readonly provider = 'amazon-bedrock';

get maxImagesPerCall(): number {
return (
this.settings.maxImagesPerCall ?? modelMaxImagesPerCall[this.modelId] ?? 1
);
return modelMaxImagesPerCall[this.modelId] ?? 1;
}

private getUrl(modelId: string): string {
Expand All @@ -42,7 +39,6 @@ export class BedrockImageModel implements ImageModelV2 {

constructor(
readonly modelId: BedrockImageModelId,
private readonly settings: BedrockImageSettings,
private readonly config: BedrockImageModelConfig,
) {}

Expand Down
8 changes: 0 additions & 8 deletions packages/amazon-bedrock/src/bedrock-image-settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,3 @@ export type BedrockImageModelId = 'amazon.nova-canvas-v1:0' | (string & {});
export const modelMaxImagesPerCall: Record<BedrockImageModelId, number> = {
'amazon.nova-canvas-v1:0': 5,
};

export interface BedrockImageSettings {
/**
* Override the maximum number of images per call (default is dependent on the
* model, or 1 for an unknown model).
*/
maxImagesPerCall?: number;
}
10 changes: 2 additions & 8 deletions packages/amazon-bedrock/src/bedrock-provider.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -140,27 +140,21 @@ describe('AmazonBedrockProvider', () => {
const provider = createAmazonBedrock();
const modelId = 'amazon.titan-image-generator';

const model = provider.image(modelId, {
maxImagesPerCall: 5,
});
const model = provider.image(modelId);

const constructorCall = BedrockImageModelMock.mock.calls[0];
expect(constructorCall[0]).toBe(modelId);
expect(constructorCall[1]).toEqual({ maxImagesPerCall: 5 });
expect(model).toBeInstanceOf(BedrockImageModel);
});

it('should create an image model via imageModel method', () => {
const provider = createAmazonBedrock();
const modelId = 'amazon.titan-image-generator';

const model = provider.imageModel(modelId, {
maxImagesPerCall: 5,
});
const model = provider.imageModel(modelId);

const constructorCall = BedrockImageModelMock.mock.calls[0];
expect(constructorCall[0]).toBe(modelId);
expect(constructorCall[1]).toEqual({ maxImagesPerCall: 5 });
expect(model).toBeInstanceOf(BedrockImageModel);
});
});
Expand Down
29 changes: 12 additions & 17 deletions packages/amazon-bedrock/src/bedrock-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@ import { BedrockChatModelId } from './bedrock-chat-options';
import { BedrockEmbeddingModel } from './bedrock-embedding-model';
import { BedrockEmbeddingModelId } from './bedrock-embedding-options';
import { BedrockImageModel } from './bedrock-image-model';
import {
BedrockImageModelId,
BedrockImageSettings,
} from './bedrock-image-settings';
import { BedrockImageModelId } from './bedrock-image-settings';
import {
BedrockCredentials,
createSigV4FetchFunction,
Expand Down Expand Up @@ -85,15 +82,16 @@ export interface AmazonBedrockProvider extends ProviderV2 {

embedding(modelId: BedrockEmbeddingModelId): EmbeddingModelV2<string>;

image(
modelId: BedrockImageModelId,
settings?: BedrockImageSettings,
): ImageModelV2;
/**
Creates a model for image generation.
@deprecated Use `imageModel` instead.
*/
image(modelId: BedrockImageModelId): ImageModelV2;
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image is deprecated, right? Is saw it in one of the models and applied it everywhere. Shall we remove it as part of v5 (in another PR?)


imageModel(
modelId: BedrockImageModelId,
settings?: BedrockImageSettings,
): ImageModelV2;
/**
Creates a model for image generation.
*/
imageModel(modelId: BedrockImageModelId): ImageModelV2;
}

/**
Expand Down Expand Up @@ -173,11 +171,8 @@ export function createAmazonBedrock(
fetch: sigv4Fetch,
});

const createImageModel = (
modelId: BedrockImageModelId,
settings: BedrockImageSettings = {},
) =>
new BedrockImageModel(modelId, settings, {
const createImageModel = (modelId: BedrockImageModelId) =>
new BedrockImageModel(modelId, {
baseUrl: getBaseUrl,
headers: options.headers ?? {},
fetch: sigv4Fetch,
Expand Down
15 changes: 4 additions & 11 deletions packages/azure/src/azure-openai-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ import {
OpenAICompletionLanguageModel,
OpenAIEmbeddingModel,
OpenAIImageModel,
OpenAIImageSettings,
OpenAIResponsesLanguageModel,
OpenAITranscriptionModel,
} from '@ai-sdk/openai/internal';
Expand Down Expand Up @@ -48,15 +47,12 @@ Creates an Azure OpenAI completion model for text generation.
* Creates an Azure OpenAI DALL-E model for image generation.
* @deprecated Use `imageModel` instead.
*/
image(deploymentId: string, settings?: OpenAIImageSettings): ImageModelV2;
image(deploymentId: string): ImageModelV2;

/**
* Creates an Azure OpenAI DALL-E model for image generation.
*/
imageModel(
deploymentId: string,
settings?: OpenAIImageSettings,
): ImageModelV2;
imageModel(deploymentId: string): ImageModelV2;

/**
@deprecated Use `textEmbeddingModel` instead.
Expand Down Expand Up @@ -181,11 +177,8 @@ export function createAzure(
fetch: options.fetch,
});

const createImageModel = (
modelId: string,
settings: OpenAIImageSettings = {},
) =>
new OpenAIImageModel(modelId, settings, {
const createImageModel = (modelId: string) =>
new OpenAIImageModel(modelId, {
provider: 'azure.image',
url,
headers: getHeaders,
Expand Down
21 changes: 1 addition & 20 deletions packages/deepinfra/src/deepinfra-image-model.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import { createTestServer } from '@ai-sdk/provider-utils/test';
import { describe, expect, it } from 'vitest';
import { DeepInfraImageModel } from './deepinfra-image-model';
import { DeepInfraImageSettings } from './deepinfra-image-settings';
import { FetchFunction } from '@ai-sdk/provider-utils';

const prompt = 'A cute baby sea otter';
Expand All @@ -10,14 +9,12 @@ function createBasicModel({
headers,
fetch,
currentDate,
settings,
}: {
headers?: () => Record<string, string>;
fetch?: FetchFunction;
currentDate?: () => Date;
settings?: DeepInfraImageSettings;
} = {}) {
return new DeepInfraImageModel('stability-ai/sdxl', settings ?? {}, {
return new DeepInfraImageModel('stability-ai/sdxl', {
provider: 'deepinfra',
baseURL: 'https://api.example.com',
headers: headers ?? (() => ({ 'api-key': 'test-key' })),
Expand Down Expand Up @@ -234,21 +231,5 @@ describe('DeepInfraImageModel', () => {
expect(model.specificationVersion).toBe('v2');
expect(model.maxImagesPerCall).toBe(1);
});

it('should use maxImagesPerCall from settings', () => {
const model = createBasicModel({
settings: {
maxImagesPerCall: 4,
},
});

expect(model.maxImagesPerCall).toBe(4);
});

it('should default maxImagesPerCall to 1 when not specified', () => {
const model = createBasicModel();

expect(model.maxImagesPerCall).toBe(1);
});
});
});
11 changes: 2 additions & 9 deletions packages/deepinfra/src/deepinfra-image-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@ import {
createJsonResponseHandler,
postJsonToApi,
} from '@ai-sdk/provider-utils';
import {
DeepInfraImageModelId,
DeepInfraImageSettings,
} from './deepinfra-image-settings';
import { DeepInfraImageModelId } from './deepinfra-image-settings';
import { z } from 'zod';

interface DeepInfraImageModelConfig {
Expand All @@ -24,18 +21,14 @@ interface DeepInfraImageModelConfig {

export class DeepInfraImageModel implements ImageModelV2 {
readonly specificationVersion = 'v2';
readonly maxImagesPerCall = 1;

get provider(): string {
return this.config.provider;
}

get maxImagesPerCall(): number {
return this.settings.maxImagesPerCall ?? 1;
}

constructor(
readonly modelId: DeepInfraImageModelId,
readonly settings: DeepInfraImageSettings,
private config: DeepInfraImageModelConfig,
) {}

Expand Down
7 changes: 0 additions & 7 deletions packages/deepinfra/src/deepinfra-image-settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,3 @@ export type DeepInfraImageModelId =
| 'stabilityai/sd3.5-medium'
| 'stabilityai/sdxl-turbo'
| (string & {});

export interface DeepInfraImageSettings {
/**
* Override the maximum number of images per call (default 1)
*/
maxImagesPerCall?: number;
}
Loading
Loading