Skip to content

Commit 4a19a45

Browse files
hheydarycopybara-github
authored andcommitted
Plumb max_output_tokens from settings to Conversation::SendMessage.
LiteRT-LM-PiperOrigin-RevId: 894230885
1 parent cbf9cd1 commit 4a19a45

File tree

1 file changed

+24
-8
lines changed

1 file changed

+24
-8
lines changed

runtime/engine/litert_lm_lib.cc

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -448,15 +448,23 @@ absl::Status RunSingleTurnConversation(const std::string& input_prompt,
448448
json content_list = json::array();
449449
RETURN_IF_ERROR(BuildContentList(input_prompt, content_list, settings));
450450
std::stringstream captured_output;
451+
OptionalArgs optional_args;
452+
if (settings.max_output_tokens > 0) {
453+
optional_args.max_output_tokens = settings.max_output_tokens;
454+
}
455+
451456
if (settings.async) {
452457
RETURN_IF_ERROR(conversation->SendMessageAsync(
453458
json::object({{"role", "user"}, {"content", content_list}}),
454-
CreatePrintMessageCallback(captured_output, settings.benchmark)));
459+
CreatePrintMessageCallback(captured_output, settings.benchmark),
460+
std::move(optional_args)));
455461
RETURN_IF_ERROR(engine->WaitUntilDone(kWaitUntilDoneTimeout));
456462
} else {
457-
ASSIGN_OR_RETURN(auto model_message,
458-
conversation->SendMessage(json::object(
459-
{{"role", "user"}, {"content", content_list}})));
463+
ASSIGN_OR_RETURN(
464+
auto model_message,
465+
conversation->SendMessage(
466+
json::object({{"role", "user"}, {"content", content_list}}),
467+
std::move(optional_args)));
460468
RETURN_IF_ERROR(PrintJsonMessage(std::get<JsonMessage>(model_message),
461469
captured_output));
462470
}
@@ -487,15 +495,23 @@ absl::Status RunMultiTurnConversation(const LiteRtLmSettings& settings,
487495
if (content_list.empty()) {
488496
continue;
489497
}
498+
OptionalArgs optional_args;
499+
if (settings.max_output_tokens > 0) {
500+
optional_args.max_output_tokens = settings.max_output_tokens;
501+
}
502+
490503
if (settings.async) {
491504
RETURN_IF_ERROR(conversation->SendMessageAsync(
492505
json::object({{"role", "user"}, {"content", content_list}}),
493-
CreatePrintMessageCallback(captured_output, settings.benchmark)));
506+
CreatePrintMessageCallback(captured_output, settings.benchmark),
507+
std::move(optional_args)));
494508
RETURN_IF_ERROR(engine->WaitUntilDone(kWaitUntilDoneTimeout));
495509
} else {
496-
ASSIGN_OR_RETURN(auto model_message,
497-
conversation->SendMessage(json::object(
498-
{{"role", "user"}, {"content", content_list}})));
510+
ASSIGN_OR_RETURN(
511+
auto model_message,
512+
conversation->SendMessage(
513+
json::object({{"role", "user"}, {"content", content_list}}),
514+
std::move(optional_args)));
499515
RETURN_IF_ERROR(PrintJsonMessage(std::get<JsonMessage>(model_message),
500516
captured_output));
501517
}

0 commit comments

Comments
 (0)