Commit 0517afc

[llm] Fix start_pos not being updated in prefill_chunk()
1 parent 72a095f commit 0517afc

File tree

4 files changed: +367, -5 lines
Lines changed: 361 additions & 0 deletions
@@ -0,0 +1,361 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 * @lint-ignore-every CLANGTIDY facebook-hte-Deprecated
 */

#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <executorch/extension/llm/runner/text_prefiller.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>

using namespace ::testing;
using executorch::extension::llm::TextDecoderRunner;
using executorch::extension::llm::TextPrefiller;
using executorch::runtime::Error;
using executorch::runtime::Result;
using executorch::runtime::testing::TensorFactory;

// Mock class for TextDecoderRunner
class MockTextDecoderRunner : public TextDecoderRunner {
 public:
  MockTextDecoderRunner() : TextDecoderRunner(nullptr, false) {}
  MOCK_METHOD(
      Result<executorch::aten::Tensor>,
      step,
      (executorch::extension::TensorPtr&, executorch::extension::TensorPtr&),
      ());
  MOCK_METHOD(bool, is_method_loaded, (), ());
  MOCK_METHOD(Result<uint64_t>, prefill, (std::vector<uint64_t>&, int64_t), ());
  MOCK_METHOD(::executorch::runtime::Error, load, (), ());
  MOCK_METHOD(uint64_t, logits_to_token, (const executorch::aten::Tensor&), ());
};

// Test fixture for TextPrefiller tests
class TextPrefillerTest : public Test {
 protected:
  void SetUp() override {
    // Set up default behavior for the text decoder runner
    ON_CALL(text_decoder_runner_, is_method_loaded())
        .WillByDefault(Return(true));
    ON_CALL(text_decoder_runner_, step)
        .WillByDefault([&](executorch::extension::TensorPtr&,
                           executorch::extension::TensorPtr&) {
          return Result<executorch::aten::Tensor>(tensor);
        });
    ON_CALL(text_decoder_runner_, logits_to_token)
        .WillByDefault([](const executorch::aten::Tensor&) { return 42; });
  }

  // Helper method to create a TextPrefiller with specific parameters
  std::unique_ptr<TextPrefiller> createTextPrefiller(
      int64_t max_seq_len,
      bool use_kv_cache = true,
      bool enable_parallel_prefill = false) {
    return std::make_unique<TextPrefiller>(
        &text_decoder_runner_,
        use_kv_cache,
        enable_parallel_prefill,
        max_seq_len);
  }

  // Create a mock TextPrefiller that allows us to spy on prefill_chunk calls
  class SpyTextPrefiller : public TextPrefiller {
   public:
    SpyTextPrefiller(
        TextDecoderRunner* text_decoder_runner,
        bool use_kv_cache,
        bool enable_parallel_prefill,
        int64_t max_seq_len)
        : TextPrefiller(
              text_decoder_runner,
              use_kv_cache,
              enable_parallel_prefill,
              max_seq_len) {}

    MOCK_METHOD(
        ::executorch::runtime::Result<uint64_t>,
        prefill_chunk,
        (std::vector<uint64_t>&, int64_t&),
        ());

    // Call the real implementation after recording the call
    ::executorch::runtime::Result<uint64_t> prefill_chunk_impl(
        std::vector<uint64_t>& prompt_tokens,
        int64_t& start_pos) {
      return TextPrefiller::prefill_chunk(prompt_tokens, start_pos);
    }
  };

  // Create a spy TextPrefiller
  std::unique_ptr<SpyTextPrefiller> createSpyTextPrefiller(
      int64_t max_seq_len,
      bool use_kv_cache = true,
      bool enable_parallel_prefill = false) {
    auto prefiller = std::make_unique<SpyTextPrefiller>(
        &text_decoder_runner_,
        use_kv_cache,
        enable_parallel_prefill,
        max_seq_len);

    // Set up the spy to call the real implementation
    ON_CALL(*prefiller, prefill_chunk)
        .WillByDefault(
            // Capture a raw pointer: unique_ptr is move-only, and the spy
            // object itself outlives this default action in the caller.
            [prefiller = prefiller.get()](
                std::vector<uint64_t>& tokens, int64_t& pos) {
              return prefiller->prefill_chunk_impl(tokens, pos);
            });

    return prefiller;
  }

  MockTextDecoderRunner text_decoder_runner_;
  std::vector<float> return_logits_ = {0.1f, 0.2f, 0.3f, 0.4f};
  TensorFactory<executorch::aten::ScalarType::Float> tf;
  executorch::aten::Tensor tensor = tf.make({1, 4}, return_logits_);
};

// Test that prefill() calls prefill_chunk() once when prompt tokens <=
// max_seq_len
TEST_F(TextPrefillerTest, PrefillCallsPrefillChunkOnceWhenPromptFits) {
  // Create a spy TextPrefiller with max_seq_len = 10
  auto prefiller = createSpyTextPrefiller(10);

  // Create prompt tokens with size <= max_seq_len
  std::vector<uint64_t> prompt_tokens = {1, 2, 3, 4, 5};
  int64_t start_pos = 0;

  // Expect prefill_chunk to be called exactly once with the entire prompt
  EXPECT_CALL(*prefiller, prefill_chunk(_, _))
      .Times(1)
      .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
        // Verify the tokens passed to prefill_chunk
        EXPECT_EQ(tokens.size(), prompt_tokens.size());
        for (size_t i = 0; i < tokens.size(); i++) {
          EXPECT_EQ(tokens[i], prompt_tokens[i]);
        }
        // Verify the position
        EXPECT_EQ(pos, start_pos);
        return Result<uint64_t>(42);
      });

  // Call prefill
  auto result = prefiller->prefill(prompt_tokens, start_pos);

  // Verify the result
  EXPECT_EQ(result.error(), Error::Ok);
  EXPECT_EQ(result.get(), 42);
}

// Test that prefill() calls prefill_chunk() multiple times when prompt tokens >
// max_seq_len
TEST_F(
    TextPrefillerTest,
    PrefillCallsPrefillChunkMultipleTimesWhenPromptExceedsMaxLen) {
  // Create a spy TextPrefiller with max_seq_len = 3
  const int64_t max_seq_len = 3;
  auto prefiller = createSpyTextPrefiller(max_seq_len);

  // Create prompt tokens with size > max_seq_len
  std::vector<uint64_t> prompt_tokens = {1, 2, 3, 4, 5, 6, 7, 8};
  int64_t start_pos = 0;

  // Calculate expected number of chunks
  const int expected_chunks =
      (prompt_tokens.size() + max_seq_len - 1) / max_seq_len;

  // Set up expectations for prefill_chunk calls
  {
    InSequence seq; // Ensure calls happen in the expected order

    // First chunk: tokens [1, 2, 3]
    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
          EXPECT_EQ(tokens.size(), 3);
          EXPECT_EQ(tokens[0], 1);
          EXPECT_EQ(tokens[1], 2);
          EXPECT_EQ(tokens[2], 3);
          EXPECT_EQ(pos, 0);
          return Result<uint64_t>(10);
        });

    // Second chunk: tokens [4, 5, 6]
    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
          EXPECT_EQ(tokens.size(), 3);
          EXPECT_EQ(tokens[0], 4);
          EXPECT_EQ(tokens[1], 5);
          EXPECT_EQ(tokens[2], 6);
          EXPECT_EQ(pos, 3);
          return Result<uint64_t>(20);
        });

    // Third chunk: tokens [7, 8]
    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
          EXPECT_EQ(tokens.size(), 2);
          EXPECT_EQ(tokens[0], 7);
          EXPECT_EQ(tokens[1], 8);
          EXPECT_EQ(pos, 6);
          return Result<uint64_t>(30);
        });
  }

  // Call prefill
  auto result = prefiller->prefill(prompt_tokens, start_pos);

  // Verify the result
  EXPECT_EQ(result.error(), Error::Ok);
  EXPECT_EQ(result.get(), 30); // Should return the token from the last chunk

  // Verify that start_pos has been updated correctly
  EXPECT_EQ(start_pos, prompt_tokens.size());
}

// Test that prefill() handles edge cases correctly
TEST_F(TextPrefillerTest, PrefillHandlesEdgeCasesCorrectly) {
  // Create a spy TextPrefiller with max_seq_len = 1
  const int64_t max_seq_len = 1;
  auto prefiller = createSpyTextPrefiller(max_seq_len);

  // Create prompt tokens with size > max_seq_len
  std::vector<uint64_t> prompt_tokens = {1, 2, 3};
  int64_t start_pos = 5; // Non-zero starting position

  // Set up expectations for prefill_chunk calls
  {
    InSequence seq;

    // First chunk: token [1]
    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
          EXPECT_EQ(tokens.size(), 1);
          EXPECT_EQ(tokens[0], 1);
          EXPECT_EQ(pos, 5);
          return Result<uint64_t>(10);
        });

    // Second chunk: token [2]
    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
          EXPECT_EQ(tokens.size(), 1);
          EXPECT_EQ(tokens[0], 2);
          EXPECT_EQ(pos, 6);
          return Result<uint64_t>(20);
        });

    // Third chunk: token [3]
    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
          EXPECT_EQ(tokens.size(), 1);
          EXPECT_EQ(tokens[0], 3);
          EXPECT_EQ(pos, 7);
          return Result<uint64_t>(30);
        });
  }

  // Call prefill
  auto result = prefiller->prefill(prompt_tokens, start_pos);

  // Verify the result
  EXPECT_EQ(result.error(), Error::Ok);
  EXPECT_EQ(result.get(), 30);

  // Verify that start_pos has been updated correctly
  EXPECT_EQ(start_pos, 8); // 5 (initial) + 3 (tokens)
}

// Test that prefill() handles errors from prefill_chunk correctly
TEST_F(TextPrefillerTest, PrefillHandlesPrefillChunkErrorsCorrectly) {
  // Create a spy TextPrefiller with max_seq_len = 3
  const int64_t max_seq_len = 3;
  auto prefiller = createSpyTextPrefiller(max_seq_len);

  // Create prompt tokens with size > max_seq_len
  std::vector<uint64_t> prompt_tokens = {1, 2, 3, 4, 5};
  int64_t start_pos = 0;

  // Set up expectations for prefill_chunk calls
  {
    InSequence seq;

    // First chunk: tokens [1, 2, 3] - succeeds
    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
          return Result<uint64_t>(10);
        });

    // Second chunk: tokens [4, 5] - fails
    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
          return Result<uint64_t>(Error::InvalidArgument);
        });
  }

  // Call prefill
  auto result = prefiller->prefill(prompt_tokens, start_pos);

  // Verify that the error is propagated
  EXPECT_EQ(result.error(), Error::InvalidArgument);
}

// Test that prefill_chunk() works correctly with parallel prefill enabled
TEST_F(TextPrefillerTest, PrefillChunkWorksWithParallelPrefill) {
  // Create a TextPrefiller with parallel prefill enabled
  auto prefiller = createTextPrefiller(10, true, true);

  // Set up expectations for the text decoder runner
  EXPECT_CALL(text_decoder_runner_, step(_, _))
      .Times(1)
      .WillOnce(Return(Result<executorch::aten::Tensor>(tensor)));

  EXPECT_CALL(text_decoder_runner_, logits_to_token(_))
      .Times(1)
      .WillOnce(Return(42));

  // Create prompt tokens
  std::vector<uint64_t> prompt_tokens = {1, 2, 3};
  int64_t start_pos = 0;

  // Call prefill
  auto result = prefiller->prefill(prompt_tokens, start_pos);

  // Verify the result
  EXPECT_EQ(result.error(), Error::Ok);
  EXPECT_EQ(result.get(), 42);

  // Verify that start_pos has been updated correctly
  EXPECT_EQ(start_pos, prompt_tokens.size());
}

// Test that prefill_chunk() works correctly with sequential prefill
TEST_F(TextPrefillerTest, PrefillChunkWorksWithSequentialPrefill) {
  // Create a TextPrefiller with sequential prefill (parallel prefill disabled)
  auto prefiller = createTextPrefiller(10, true, false);

  // Set up expectations for the text decoder runner
  EXPECT_CALL(text_decoder_runner_, step(_, _))
      .Times(3) // Once for each token
      .WillRepeatedly(Return(Result<executorch::aten::Tensor>(tensor)));

  EXPECT_CALL(text_decoder_runner_, logits_to_token(_))
      .Times(1) // Only called once at the end
      .WillOnce(Return(42));

  // Create prompt tokens
  std::vector<uint64_t> prompt_tokens = {1, 2, 3};
  int64_t start_pos = 0;

  // Call prefill
  auto result = prefiller->prefill(prompt_tokens, start_pos);

  // Verify the result
  EXPECT_EQ(result.error(), Error::Ok);
  EXPECT_EQ(result.get(), 42);

  // Verify that start_pos has been updated correctly
  EXPECT_EQ(start_pos, prompt_tokens.size());
}

extension/llm/runner/text_prefiller.cpp

Lines changed: 4 additions & 3 deletions
@@ -56,21 +56,22 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
          prompt_tokens_to_process.begin());
 
       // Process this chunk
-      auto chunk_result = prefillChunk(prompt_tokens_to_process, start_pos);
+      auto chunk_result = prefill_chunk(prompt_tokens_to_process, start_pos);
       ET_CHECK_OK_OR_RETURN_ERROR(chunk_result.error());
       cur_token = chunk_result.get();
 
+      start_pos += num_tokens_to_prefill_with;
       num_tokens_to_process += num_tokens_to_prefill_with;
     }
 
     return cur_token;
   } else {
     // If prompt tokens don't exceed max_seq_len_, process them directly
-    return prefillChunk(prompt_tokens, start_pos);
+    return prefill_chunk(prompt_tokens, start_pos);
   }
 }
 
-::executorch::runtime::Result<uint64_t> TextPrefiller::prefillChunk(
+::executorch::runtime::Result<uint64_t> TextPrefiller::prefill_chunk(
     std::vector<uint64_t>& prompt_tokens,
     int64_t& start_pos) {
   // enable_parallel_prefill_ maybe set even when not using kv cache
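
The added line `start_pos += num_tokens_to_prefill_with;` is the substance of the fix: prefill() splits a long prompt into max_seq_len_-sized chunks, and previously it advanced num_tokens_to_process between chunks but left start_pos untouched, so every chunk after the first was prefilled at a stale KV-cache position. Below is a minimal, self-contained sketch of the corrected chunking loop; prefill_chunk_stub, prefill_in_chunks, and the local variable names are placeholders for illustration, not the ExecuTorch API.

// Sketch only: mirrors the chunked-prefill control flow after this commit,
// with a stub standing in for TextPrefiller::prefill_chunk().
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for prefill_chunk(); just returns a fake "next token".
uint64_t prefill_chunk_stub(const std::vector<uint64_t>& chunk, int64_t start_pos) {
  (void)start_pos; // a real implementation would feed the model at this position
  return chunk.empty() ? 0 : chunk.back();
}

uint64_t prefill_in_chunks(
    const std::vector<uint64_t>& prompt_tokens,
    int64_t& start_pos,
    size_t max_seq_len) {
  uint64_t cur_token = 0;
  size_t num_tokens_processed = 0;
  while (num_tokens_processed < prompt_tokens.size()) {
    // Take at most max_seq_len tokens for this chunk.
    const size_t n =
        std::min(max_seq_len, prompt_tokens.size() - num_tokens_processed);
    std::vector<uint64_t> chunk(
        prompt_tokens.begin() + num_tokens_processed,
        prompt_tokens.begin() + num_tokens_processed + n);
    cur_token = prefill_chunk_stub(chunk, start_pos);
    start_pos += static_cast<int64_t>(n); // the update this commit adds
    num_tokens_processed += n;
  }
  return cur_token;
}

With the update in place, a caller that enters prefill() at position p leaves with start_pos == p + prompt_tokens.size(), which is exactly what the new PrefillHandlesEdgeCasesCorrectly test asserts (5 + 3 == 8).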

extension/llm/runner/text_prefiller.h

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ class ET_EXPERIMENTAL TextPrefiller {
    * Module.
    * @return The next token of the LLM Module after prefilling this chunk.
    */
-  ::executorch::runtime::Result<uint64_t> prefillChunk(
+  ::executorch::runtime::Result<uint64_t> prefill_chunk(
       std::vector<uint64_t>& prompt_tokens,
       int64_t& start_pos);

0 commit comments
