
Commit a8c8bd5

samdenty and lgrammel authored
feat(embed-many): respect supportsParallelCalls & concurrency (vercel#6108)
## Background

We didn't actually read the `supportsParallelCalls` field at all, and we did everything serially in the embedding model.

## Summary

This makes the embedding model actually respect `supportsParallelCalls`.

Co-authored-by: Lars Grammel <[email protected]>
1 parent 41fa418 commit a8c8bd5
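
For orientation, here is a minimal usage sketch of the new option; the provider import and model id are illustrative assumptions, not part of this commit:

```ts
import { embedMany } from 'ai';
import { openai } from '@ai-sdk/openai'; // illustrative provider choice

const { embeddings } = await embedMany({
  model: openai.embedding('text-embedding-3-small'),
  values: ['sunny day at the beach', 'rainy afternoon in the city'],
  // New in this commit: cap on concurrent doEmbed requests. Parallelism only
  // kicks in when the model reports supportsParallelCalls: true; otherwise
  // chunks are processed one at a time.
  maxParallelCalls: 2,
});
```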

5 files changed: +268 -104 lines changed

.changeset/hungry-hotels-hunt.md

Lines changed: 7 additions & 0 deletions
```diff
@@ -0,0 +1,7 @@
+---
+'@ai-sdk/amazon-bedrock': patch
+'@ai-sdk/provider': patch
+'ai': patch
+---
+
+feat(embed-many): respect supportsParallelCalls & concurrency
```

packages/ai/core/embed/embed-many.test.ts

Lines changed: 148 additions & 0 deletions
```diff
@@ -5,6 +5,7 @@ import {
 } from '../test/mock-embedding-model-v2';
 import { MockTracer } from '../test/mock-tracer';
 import { embedMany } from './embed-many';
+import { createResolvablePromise } from '../../util/create-resolvable-promise';
 
 const dummyEmbeddings = [
   [0.1, 0.2, 0.3],
@@ -18,6 +19,153 @@ const testValues = [
   'snowy night in the mountains',
 ];
 
+describe('model.supportsParallelCalls', () => {
+  it('should not parallelize when false', async () => {
+    const events: string[] = [];
+    let callCount = 0;
+
+    const resolvables = [
+      createResolvablePromise<void>(),
+      createResolvablePromise<void>(),
+      createResolvablePromise<void>(),
+    ];
+
+    const embedManyPromise = embedMany({
+      model: new MockEmbeddingModelV2({
+        supportsParallelCalls: false,
+        maxEmbeddingsPerCall: 1,
+        doEmbed: async () => {
+          const index = callCount++;
+          events.push(`start-${index}`);
+
+          await resolvables[index].promise;
+          events.push(`end-${index}`);
+
+          return {
+            embeddings: [dummyEmbeddings[index]],
+            response: { headers: {}, body: {} },
+          };
+        },
+      }),
+      values: testValues,
+    });
+
+    resolvables.forEach(resolvable => {
+      resolvable.resolve();
+    });
+
+    const { embeddings } = await embedManyPromise;
+
+    expect(events).toStrictEqual([
+      'start-0',
+      'end-0',
+      'start-1',
+      'end-1',
+      'start-2',
+      'end-2',
+    ]);
+
+    expect(embeddings).toStrictEqual(dummyEmbeddings);
+  });
+
+  it('should parallelize when true', async () => {
+    const events: string[] = [];
+    let callCount = 0;
+
+    const resolvables = [
+      createResolvablePromise<void>(),
+      createResolvablePromise<void>(),
+      createResolvablePromise<void>(),
+    ];
+
+    const embedManyPromise = embedMany({
+      model: new MockEmbeddingModelV2({
+        supportsParallelCalls: true,
+        maxEmbeddingsPerCall: 1,
+        doEmbed: async () => {
+          const index = callCount++;
+          events.push(`start-${index}`);
+
+          await resolvables[index].promise;
+          events.push(`end-${index}`);
+
+          return {
+            embeddings: [dummyEmbeddings[index]],
+            response: { headers: {}, body: {} },
+          };
+        },
+      }),
+      values: testValues,
+    });
+
+    resolvables.forEach(resolvable => {
+      resolvable.resolve();
+    });
+
+    const { embeddings } = await embedManyPromise;
+
+    expect(events).toStrictEqual([
+      'start-0',
+      'start-1',
+      'start-2',
+      'end-0',
+      'end-1',
+      'end-2',
+    ]);
+
+    expect(embeddings).toStrictEqual(dummyEmbeddings);
+  });
+
+  it('should support maxParallelCalls', async () => {
+    const events: string[] = [];
+    let callCount = 0;
+
+    const resolvables = [
+      createResolvablePromise<void>(),
+      createResolvablePromise<void>(),
+      createResolvablePromise<void>(),
+    ];
+
+    const embedManyPromise = embedMany({
+      maxParallelCalls: 2,
+      model: new MockEmbeddingModelV2({
+        supportsParallelCalls: true,
+        maxEmbeddingsPerCall: 1,
+        doEmbed: async () => {
+          const index = callCount++;
+          events.push(`start-${index}`);
+
+          await resolvables[index].promise;
+          events.push(`end-${index}`);
+
+          return {
+            embeddings: [dummyEmbeddings[index]],
+            response: { headers: {}, body: {} },
+          };
+        },
+      }),
+      values: testValues,
+    });
+
+    resolvables.forEach(resolvable => {
+      resolvable.resolve();
+    });
+
+    const { embeddings } = await embedManyPromise;
+
+    expect(events).toStrictEqual([
+      'start-0',
+      'start-1',
+      'end-0',
+      'end-1',
+      'start-2',
+      'end-2',
+    ]);
+
+    expect(embeddings).toStrictEqual(dummyEmbeddings);
+  });
+});
+
 describe('result.embedding', () => {
   it('should generate embeddings', async () => {
     const result = await embedMany({
```
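
These tests control timing by gating every `doEmbed` call on an externally resolvable promise. `createResolvablePromise` itself is not shown in this diff; judging from how the tests use it (a `.promise` field plus `.resolve()`), a deferred-promise helper of roughly this shape would suffice:

```ts
// Sketch of a deferred-promise helper; an assumption based on test usage,
// not necessarily the SDK's actual implementation.
function createResolvablePromise<T = void>(): {
  promise: Promise<T>;
  resolve: (value: T | PromiseLike<T>) => void;
  reject: (reason?: unknown) => void;
} {
  let resolve!: (value: T | PromiseLike<T>) => void;
  let reject!: (reason?: unknown) => void;
  const promise = new Promise<T>((res, rej) => {
    resolve = res;
    reject = rej;
  });
  return { promise, resolve, reject };
}
```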

packages/ai/core/embed/embed-many.ts

Lines changed: 74 additions & 54 deletions
```diff
@@ -28,6 +28,7 @@ has a limit on how many embeddings can be generated in a single call.
 export async function embedMany<VALUE>({
   model,
   values,
+  maxParallelCalls = Infinity,
   maxRetries: maxRetriesArg,
   abortSignal,
   headers,
@@ -73,6 +74,13 @@ Only applicable for HTTP-based providers.
 functionality that can be fully encapsulated in the provider.
    */
   providerOptions?: ProviderOptions;
+
+  /**
+   * Maximum number of concurrent requests.
+   *
+   * @default Infinity
+   */
+  maxParallelCalls?: number;
 }): Promise<EmbedManyResult<VALUE>> {
   const { maxRetries, retry } = prepareRetries({ maxRetries: maxRetriesArg });
 
@@ -100,7 +108,10 @@ Only applicable for HTTP-based providers.
     }),
     tracer,
     fn: async span => {
-      const maxEmbeddingsPerCall = await model.maxEmbeddingsPerCall;
+      const [maxEmbeddingsPerCall, supportsParallelCalls] = await Promise.all([
+        model.maxEmbeddingsPerCall,
+        model.supportsParallelCalls,
+      ]);
 
       // the model has not specified limits on
       // how many embeddings can be generated in a single call
```
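
The switch from a single `await` to `Promise.all` matters because these model fields may be plain values or promises (this typing is an assumption inferred from the fact that they are awaited). `Promise.all` passes non-promise inputs through unchanged, so one `await` covers both cases:

```ts
// Hypothetical shape of the two model fields; plain values and promises
// both resolve the same way through Promise.all.
async function readModelLimits(model: {
  maxEmbeddingsPerCall: number | undefined | PromiseLike<number | undefined>;
  supportsParallelCalls: boolean | PromiseLike<boolean>;
}) {
  const [maxEmbeddingsPerCall, supportsParallelCalls] = await Promise.all([
    model.maxEmbeddingsPerCall,
    model.supportsParallelCalls,
  ]);
  return { maxEmbeddingsPerCall, supportsParallelCalls };
}
```

The final hunk of this file rewrites the request loop itself: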
```diff
@@ -192,66 +203,75 @@ Only applicable for HTTP-based providers.
       > = [];
       let tokens = 0;
 
-      for (const chunk of valueChunks) {
-        const {
-          embeddings: responseEmbeddings,
-          usage,
-          response,
-        } = await retry(() => {
-          // nested spans to align with the embedMany telemetry data:
-          return recordSpan({
-            name: 'ai.embedMany.doEmbed',
-            attributes: selectTelemetryAttributes({
-              telemetry,
-              attributes: {
-                ...assembleOperationName({
-                  operationId: 'ai.embedMany.doEmbed',
-                  telemetry,
-                }),
-                ...baseTelemetryAttributes,
-                // specific settings that only make sense on the outer level:
-                'ai.values': {
-                  input: () => chunk.map(value => JSON.stringify(value)),
-                },
-              },
-            }),
-            tracer,
-            fn: async doEmbedSpan => {
-              const modelResponse = await model.doEmbed({
-                values: chunk,
-                abortSignal,
-                headers,
-                providerOptions,
-              });
-
-              const embeddings = modelResponse.embeddings;
-              const usage = modelResponse.usage ?? { tokens: NaN };
+      const parallelChunks = splitArray(
+        valueChunks,
+        supportsParallelCalls ? maxParallelCalls : 1,
+      );
 
-              doEmbedSpan.setAttributes(
-                selectTelemetryAttributes({
+      for (const parallelChunk of parallelChunks) {
+        const results = await Promise.all(
+          parallelChunk.map(chunk => {
+            return retry(() => {
+              // nested spans to align with the embedMany telemetry data:
+              return recordSpan({
+                name: 'ai.embedMany.doEmbed',
+                attributes: selectTelemetryAttributes({
                   telemetry,
                   attributes: {
-                    'ai.embeddings': {
-                      output: () =>
-                        embeddings.map(embedding => JSON.stringify(embedding)),
+                    ...assembleOperationName({
+                      operationId: 'ai.embedMany.doEmbed',
+                      telemetry,
+                    }),
+                    ...baseTelemetryAttributes,
+                    // specific settings that only make sense on the outer level:
+                    'ai.values': {
+                      input: () => chunk.map(value => JSON.stringify(value)),
                     },
-                    'ai.usage.tokens': usage.tokens,
                   },
                 }),
-              );
-
-              return {
-                embeddings,
-                usage,
-                response: modelResponse.response,
-              };
-            },
-          });
-        });
+                tracer,
+                fn: async doEmbedSpan => {
+                  const modelResponse = await model.doEmbed({
+                    values: chunk,
+                    abortSignal,
+                    headers,
+                    providerOptions,
+                  });
+
+                  const embeddings = modelResponse.embeddings;
+                  const usage = modelResponse.usage ?? { tokens: NaN };
+
+                  doEmbedSpan.setAttributes(
+                    selectTelemetryAttributes({
+                      telemetry,
+                      attributes: {
+                        'ai.embeddings': {
+                          output: () =>
+                            embeddings.map(embedding =>
+                              JSON.stringify(embedding),
+                            ),
+                        },
+                        'ai.usage.tokens': usage.tokens,
+                      },
+                    }),
+                  );
+
+                  return {
+                    embeddings,
+                    usage,
+                    response: modelResponse.response,
+                  };
+                },
+              });
+            });
+          }),
+        );
 
-        embeddings.push(...responseEmbeddings);
-        responses.push(response);
-        tokens += usage.tokens;
+        for (const result of results) {
+          embeddings.push(...result.embeddings);
+          responses.push(result.response);
+          tokens += result.usage.tokens;
+        }
       }
 
       span.setAttributes(
```
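
The concurrency scheme is batch-based: `valueChunks` is split into groups of at most `maxParallelCalls` (or groups of 1 when the model does not support parallel calls), each group runs under `Promise.all`, and the next group starts only once the whole group has settled. `splitArray` is not shown in this diff, so the helper below is a sketch of the pattern rather than the SDK's internal implementation:

```ts
// Chunk an array into groups of at most `size` elements (assumes size >= 1).
// A non-finite size such as Infinity yields a single group.
function splitArray<T>(array: T[], size: number): T[][] {
  const groups: T[][] = [];
  for (let i = 0; i < array.length; i += size) {
    groups.push(array.slice(i, i + size));
  }
  return groups;
}

// Batch-based limiting: groups run sequentially, group members in parallel.
async function mapInBatches<T, R>(
  items: T[],
  limit: number,
  fn: (item: T) => Promise<R>,
): Promise<R[]> {
  const results: R[] = [];
  for (const batch of splitArray(items, limit)) {
    results.push(...(await Promise.all(batch.map(fn))));
  }
  return results;
}
```

This matches the third test above: with `maxParallelCalls: 2` and three single-value chunks, calls 0 and 1 start together and call 2 starts only after both complete (`start-0, start-1, end-0, end-1, start-2, end-2`). The tradeoff of batching over a sliding-window limiter is that each batch waits for its slowest call before the next batch begins.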

packages/amazon-bedrock/src/bedrock-embedding-model.test.ts

Lines changed: 0 additions & 8 deletions
```diff
@@ -86,14 +86,6 @@ describe('doEmbed', () => {
     expect(usage?.tokens).toStrictEqual(8);
   });
 
-  it('should handle multiple input values and extract usage', async () => {
-    const { usage } = await model.doEmbed({
-      values: testValues,
-    });
-
-    expect(usage?.tokens).toStrictEqual(16);
-  });
-
   it('should properly combine headers from all sources', async () => {
     const optionsHeaders = {
       'options-header': 'options-value',
```
