[llm] Add generate_from_pos API to LLM runner #11570

Merged · 2 commits · Jun 17, 2025
17 changes: 17 additions & 0 deletions extension/llm/runner/irunner.h
@@ -121,6 +121,23 @@ class ET_EXPERIMENTAL IRunner {
std::function<void(const std::string&)> token_callback,
std::function<void(const Stats&)> stats_callback) = 0;

/**
* Generate text based on the provided prompt and generation config, starting
* from a given position in the KV cache.
*
* @param prompt The input prompt to generate from
* @param start_pos The starting position in the KV cache for the input prompt
* @param config Generation configuration parameters
* @param token_callback Callback function called for each generated token
* @param stats_callback Callback function for generation statistics
* @return Error::Ok if successful, an error otherwise
*/
virtual runtime::Error generate_from_pos(
const std::string& prompt,
int64_t start_pos,
const GenerationConfig& config,
std::function<void(const std::string&)> token_callback,
std::function<void(const Stats&)> stats_callback) = 0;
/**
* Stop the generation process.
*/
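For callers that only hold the abstract interface, the new overload is what makes it possible to reuse a KV cache that is already partially filled. The helper below is a minimal sketch of such a call site and is not part of this PR: the function name continue_from, the 128-token budget, the include path, and the namespace alias are illustrative assumptions, and the caller is assumed to already know how many tokens are materialized in the KV cache (start_pos).

// Hypothetical call site -- not part of this PR. It drives any IRunner
// implementation through the new position-aware entry point.
#include <executorch/extension/llm/runner/irunner.h>  // assumed include path

#include <functional>
#include <iostream>
#include <string>

namespace llm = ::executorch::extension::llm;  // assumed namespace spelling

executorch::runtime::Error continue_from(
    llm::IRunner& runner,       // already loaded by the caller
    const std::string& prompt,  // follow-up prompt for this turn
    int64_t start_pos) {        // tokens already present in the KV cache
  llm::GenerationConfig config;
  config.max_new_tokens = 128;  // illustrative budget
  config.echo = false;          // don't echo the prompt back

  // Stream each decoded piece to stdout; ignore the perf stats here.
  return runner.generate_from_pos(
      prompt,
      start_pos,
      config,
      [](const std::string& piece) { std::cout << piece; },
      [](const llm::Stats&) {});
}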
55 changes: 55 additions & 0 deletions extension/llm/runner/test/test_text_llm_runner.cpp
@@ -322,3 +322,58 @@ TEST_F(RunnerTest, IsLoadedReturnsTrueWhenComponentsInitialized) {
// Verify is_loaded returns true
EXPECT_TRUE(runner.is_loaded());
}

// Test that generate_from_pos() errors out when the resolved max_new_tokens is negative
TEST_F(RunnerTest, GenerateFromPosErrorsWithNegativeMaxNewTokens) {
// Create mock instances using helper functions
auto tokenizer = createMockTokenizer();
auto text_decoder_runner = createMockTextDecoderRunner();
auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get());

// Set up expectations for the tokenizer encode method
EXPECT_CALL(*tokenizer, encode(_, _, _))
.WillOnce(Return(::tokenizers::Result<std::vector<uint64_t>>(
std::vector<uint64_t>{1, 2, 3})));

// Set up expectations for load methods
EXPECT_CALL(*text_prefiller, is_loaded()).WillRepeatedly(Return(true));

std::unique_ptr<executorch::llm::Stats> stats =
std::make_unique<executorch::llm::Stats>();
// Create a real TextTokenGenerator
auto text_token_generator = createTextTokenGenerator(
tokenizer.get(), text_decoder_runner.get(), stats.get());

// Create a Runner with our mocked components
TextLLMRunner runner(
{
{"enable_dynamic_shape", false},
{"get_max_seq_len", 10},
{"get_max_context_len", 10},
{"use_kv_cache", true},
},
std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release()),
std::make_unique<MockModule>(),
std::move(text_decoder_runner),
std::unique_ptr<::executorch::extension::llm::TextPrefiller>(
text_prefiller.release()),
std::move(text_token_generator),
std::move(stats));

// Load
runner.load();

// Set up the generation config; max_new_tokens is positive here, but the
// value resolved against the remaining context will be negative
GenerationConfig config;
config.max_new_tokens = 5;
config.echo = false;

// num_prompt_tokens = 3, max_context_len = 10, start_pos = 8.
// The remaining context is 10 - 8 = 2, which cannot even hold the 3 prompt
// tokens, so the resolved max_new_tokens is negative and the call must fail
// even though config.max_new_tokens = 5.
Error err = runner.generate_from_pos("test prompt", 8, config);

// Verify that an InvalidArgument error is returned
EXPECT_EQ(err, Error::InvalidArgument);
}
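The arithmetic this test relies on fits in a few lines. The snippet below is a simplified stand-in for GenerationConfig::resolve_max_new_tokens, not the actual implementation; it only models the "cap the request by whatever context is left after the prompt" behaviour that the test depends on, with the helper name and main() harness being illustrative.

// Simplified sketch of the token-budget check exercised by the test above.
#include <algorithm>
#include <cassert>
#include <cstdint>

int64_t resolve_max_new_tokens_sketch(
    int64_t max_context_len,   // already reduced by start_pos
    int64_t num_prompt_tokens, // tokens the prompt will occupy
    int64_t requested_max_new_tokens) {
  // Budget left after the prompt is prefilled into the remaining context.
  const int64_t remaining = max_context_len - num_prompt_tokens;
  return std::min(requested_max_new_tokens, remaining);
}

int main() {
  // Test values: context length 10, start_pos 8 -> only 2 slots left, but the
  // prompt alone needs 3 tokens, so the budget goes negative.
  const int64_t resolved = resolve_max_new_tokens_sketch(10 - 8, 3, 5);
  assert(resolved < 0);
  // generate_from_pos() therefore returns Error::InvalidArgument.
  return 0;
}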
46 changes: 34 additions & 12 deletions extension/llm/runner/text_llm_runner.cpp
@@ -73,8 +73,9 @@ Error TextLLMRunner::load() {
ET_LOG(Info, format, __VA_ARGS__); \
}

-Error TextLLMRunner::generate(
+Error TextLLMRunner::generate_from_pos(
const std::string& prompt,
int64_t start_pos,
const GenerationConfig& config,
std::function<void(const std::string&)> token_callback,
std::function<void(const Stats&)> stats_callback) {
@@ -125,20 +126,34 @@ Error TextLLMRunner::generate(
std::vector<uint64_t> prompt_tokens = encode_res.get();
int num_prompt_tokens = prompt_tokens.size();

// Reduce max_context_len by start_pos
int64_t max_context_len = metadata_.at(kMaxContextLen) - start_pos;
ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
ET_CHECK_MSG(
-    num_prompt_tokens < metadata_.at(kMaxContextLen),
-    "num_prompt_tokens %d >= max_seq_len_ %" PRId64
+    num_prompt_tokens < max_context_len,
+    "num_prompt_tokens %d >= max_context_len %" PRId64
     ", Max seq length exceeded - please increase max seq len value in your export script",
     num_prompt_tokens,
-    metadata_.at(kMaxContextLen));
-
-// Determine max_new_tokens using the GenerationConfig's resolve method
-int max_new_tokens = config.resolve_max_new_tokens(
-    metadata_.at(kMaxContextLen), num_prompt_tokens);
-
-ET_LOG(Info, "Max new tokens resolved: %d", max_new_tokens);
+    max_context_len);
+
+// Determine max_new_tokens using the GenerationConfig's resolve method,
+// then subtract start_pos for max_new_tokens.
+int max_new_tokens =
+    config.resolve_max_new_tokens(max_context_len, num_prompt_tokens);
+
+ET_LOG(
+    Info,
+    "Max new tokens resolved: %d, given start_pos %" PRId64
+    ", num_prompt_tokens %zu, max_context_len %" PRId64,
+    max_new_tokens,
+    start_pos,
+    prompt_tokens.size(),
+    max_context_len);
+ET_CHECK_OR_RETURN_ERROR(
+    max_new_tokens > 0,
+    InvalidArgument,
+    "Max new tokens %d is less than or equal to 0",
+    max_new_tokens);
// Prefill first
// Here feed all tokens to the model and get the next predicted token
// after the prompt. After that we will enter generate loop.
@@ -147,7 +162,7 @@
if (config.echo) {
wrapped_callback(prompt);
}
-int64_t pos = 0;
+int64_t pos = start_pos;
auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
uint64_t cur_token = prefill_res.get();
@@ -201,6 +216,13 @@ Error TextLLMRunner::generate(

return Error::Ok;
}
Error TextLLMRunner::generate(
const std::string& prompt,
const GenerationConfig& config,
std::function<void(const std::string&)> token_callback,
std::function<void(const Stats&)> stats_callback) {
return generate_from_pos(prompt, 0, config, token_callback, stats_callback);
}

Error TextLLMRunner::warmup(const std::string& prompt, int32_t max_new_tokens) {
// Create a GenerationConfig for warmup
30 changes: 28 additions & 2 deletions extension/llm/runner/text_llm_runner.h
@@ -78,8 +78,9 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner {
* @brief Generates text based on the provided prompt
*
* This method performs text generation using the loaded model. It processes
-* the input prompt, runs the model in prefill and decode phases, and returns
-* generated text through callbacks.
+* the input prompt, runs the model in prefill and decode phases until max
+* tokens to generate is reached or eos token is generated, then returns
+* generated text and perf stats through callbacks.
*
* @param prompt The input text to generate from
* @param config Configuration parameters for text generation (e.g.,
@@ -94,6 +95,31 @@
const GenerationConfig& config,
std::function<void(const std::string&)> token_callback = {},
std::function<void(const Stats&)> stats_callback = {}) override;

/**
* @brief Generates text based on the provided prompt and start position
*
* This method performs text generation using the loaded model. It processes
* the input prompt, then runs the model in prefill and decode phases starting
* from the given KV cache position until the maximum number of new tokens is
* reached or an EOS token is generated, and returns the generated text and
* performance stats through callbacks.
*
* @param prompt The input text to generate from
* @param start_pos The starting position in the KV cache for the input prompt
* @param config Configuration parameters for text generation (e.g.,
* max_new_tokens, temperature)
* @param token_callback Function called for each generated token with the
* decoded text
* @param stats_callback Function called with performance statistics
* @return ::executorch::runtime::Error Success or error status
*/
::executorch::runtime::Error generate_from_pos(
const std::string& prompt,
int64_t start_pos,
const GenerationConfig& config,
std::function<void(const std::string&)> token_callback = {},
std::function<void(const Stats&)> stats_callback = {}) override;

/**
* @brief Warms up the model with a sample prompt
*
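Putting the two entry points together, a caller can run the first turn with generate(), keep the KV cache warm, and continue the conversation with generate_from_pos(). The sketch below is illustrative only and not part of this PR: the Stats field names num_prompt_tokens and num_generated_tokens are assumptions about what the stats callback exposes, and the include path, namespace alias, prompts, and token budget are likewise made up for the example.

// Hypothetical two-turn driver -- not part of this PR.
#include <executorch/extension/llm/runner/text_llm_runner.h>  // assumed include path

#include <cstdint>
#include <iostream>
#include <string>

namespace llm = ::executorch::extension::llm;  // assumed namespace spelling

void two_turn_chat(llm::TextLLMRunner& runner) {
  int64_t tokens_in_cache = 0;  // how far the KV cache has been filled so far

  llm::GenerationConfig config;
  config.max_new_tokens = 64;  // illustrative budget per turn

  auto print_token = [](const std::string& piece) { std::cout << piece; };
  auto track_position = [&tokens_in_cache](const llm::Stats& stats) {
    // Assumed Stats fields: prompt tokens prefilled plus tokens generated
    // this turn; substitute the real counters if they are named differently.
    tokens_in_cache += stats.num_prompt_tokens + stats.num_generated_tokens;
  };

  // Turn 1: equivalent to generate_from_pos(prompt, /*start_pos=*/0, ...).
  runner.generate("What does ExecuTorch do?", config, print_token, track_position);

  // Turn 2: reuse the existing KV cache and continue right after turn 1.
  runner.generate_from_pos(
      "Summarize that in one sentence.",
      /*start_pos=*/tokens_in_cache,
      config,
      print_token,
      track_position);
}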