Skip to content

Commit 4119bed

Browse files
matthewchan-g authored and copybara-github committed
Change Message from variant to a class that derives from nlohmann::ordered_json.
LiteRT-LM — PiperOrigin-RevId: 893152404
1 parent ffdb0fc commit 4119bed

24 files changed

+340
-498
lines changed

c/engine.cc

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -75,17 +75,12 @@ CreateConversationCallback(LiteRtLmStreamCallback callback, void* user_data) {
7575
callback(user_data, nullptr, true, const_cast<char*>(error_str.c_str()));
7676
return;
7777
}
78-
if (auto* json_msg = std::get_if<litert::lm::JsonMessage>(&*message)) {
79-
if (json_msg->is_null()) { // End of stream marker
80-
callback(user_data, nullptr, true, nullptr);
81-
} else {
82-
std::string json_str = json_msg->dump();
83-
callback(user_data, const_cast<char*>(json_str.c_str()), false,
84-
nullptr);
85-
}
78+
if (message->empty()) { // End of stream marker
79+
callback(user_data, nullptr, true, nullptr);
8680
} else {
87-
std::string error_str = "Unsupported message type";
88-
callback(user_data, nullptr, true, const_cast<char*>(error_str.c_str()));
81+
std::string json_str = message->dump();
82+
callback(user_data, const_cast<char*>(json_str.c_str()), false,
83+
nullptr);
8984
}
9085
};
9186
}
@@ -110,7 +105,7 @@ using ::litert::lm::Engine;
110105
using ::litert::lm::EngineFactory;
111106
using ::litert::lm::EngineSettings;
112107
using ::litert::lm::InputText;
113-
using ::litert::lm::JsonMessage;
108+
114109
using ::litert::lm::Message;
115110
using ::litert::lm::ModelAssets;
116111
using ::litert::lm::Responses;
@@ -712,13 +707,8 @@ LiteRtLmJsonResponse* litert_lm_conversation_send_message(
712707
ABSL_LOG(ERROR) << "Failed to send message: " << response.status();
713708
return nullptr;
714709
}
715-
auto* json_response = std::get_if<JsonMessage>(&*response);
716-
if (!json_response) {
717-
ABSL_LOG(ERROR) << "Response is not a JSON message.";
718-
return nullptr;
719-
}
720710
auto* c_response = new LiteRtLmJsonResponse;
721-
c_response->json_string = json_response->dump();
711+
c_response->json_string = response->dump();
722712
return c_response;
723713
}
724714

docs/api/cpp/conversation.md

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ CHECK_OK(conversation);
5858

5959
// 4. Send message to the LLM with blocking call.
6060
absl::StatusOr<Message> model_message = (*conversation)->SendMessage(
61-
JsonMessage{
61+
Message{
6262
{"role", "user"},
6363
{"content", "What is the tallest building in the world?"}
6464
});
@@ -72,7 +72,7 @@ std::cout << *model_message << std::endl;
7272
// process the message once a chunk of message output is received.
7373
std::stringstream captured_output;
7474
(*conversation)->SendMessageAsync(
75-
JsonMessage{
75+
Message{
7676
{"role", "user"},
7777
{"content", "What is the tallest building in the world?"}
7878
},
@@ -97,18 +97,16 @@ absl::AnyInvocable<void(absl::StatusOr<Message>)> CreatePrintMessageCallback(
9797
std::cout << message.status().message() << std::endl;
9898
return;
9999
}
100-
if (auto json_message = std::get_if<JsonMessage>(&(*message))) {
101-
if (json_message->is_null()) {
102-
std::cout << std::endl << std::flush;
103-
return;
104-
}
105-
ABSL_CHECK_OK(PrintJsonMessage(*json_message, captured_output,
106-
/*streaming=*/true));
100+
if (message->empty()) {
101+
std::cout << std::endl << std::flush;
102+
return;
107103
}
104+
ABSL_CHECK_OK(PrintMessage(*message, captured_output,
105+
/*streaming=*/true));
108106
};
109107
}
110108
111-
absl::Status PrintJsonMessage(const JsonMessage& message,
109+
absl::Status PrintMessage(const Message& message,
112110
std::stringstream& captured_output,
113111
bool streaming = false) {
114112
if (message["content"].is_array()) {
@@ -162,7 +160,7 @@ auto engine_settings = EngineSettings::CreateDefault(
162160

163161
// Send message to the LLM with image data.
164162
absl::StatusOr<Message> model_message = (*conversation)->SendMessage(
165-
JsonMessage{
163+
Message{
166164
{"role", "user"},
167165
{"content", { // Now content must be an array.
168166
{{"type", "text"}, {"text", "Describe the following image: "}},
@@ -176,7 +174,7 @@ std::cout << *model_message << std::endl;
176174

177175
// Send message to the LLM with audio data.
178176
model_message = (*conversation)->SendMessage(
179-
JsonMessage{
177+
Message{
180178
{"role", "user"},
181179
{"content", { // Now content must be an array.
182180
{{"type", "text"}, {"text", "Transcribe the audio: "}},
@@ -190,7 +188,7 @@ std::cout << *model_message << std::endl;
190188

191189
// The content can include multiple image or audio data.
192190
model_message = (*conversation)->SendMessage(
193-
JsonMessage{
191+
Message{
194192
{"role", "user"},
195193
{"content", { // Now content must be an array.
196194
{{"type", "text"}, {"text", "First briefly describe the two images "}},
@@ -223,7 +221,7 @@ data to Session.
223221
224222
The core input and output format for the Conversation API is
225223
[`Message`][Message]. Currently, this is implemented as
226-
[`JsonMessage`][JsonMessage], which is a type alias for
224+
a class that derives from
227225
[`ordered_json`][ordered_json], a flexible nested key-value data structure.
228226
229227
The [`Conversation`][Conversation] API operates on a message-in-message-out
@@ -492,7 +490,7 @@ This function is triggered under the following conditions:
492490
* When a new chunk of the [`Message`][Message] is received from the Model.
493491
* If an error occurs during LiteRT-LM's message processing.
494492
* Upon completion of the LLM's inference, the callback is triggered with an
495-
empty [`Message`][Message] (e.g., `JsonMessage()`) to signal the end of the
493+
empty [`Message`][Message] (e.g., `Message()`) to signal the end of the
496494
response.
497495

498496
Refer to the [Step 6 asynchronous call](#text-only-content) for an example
@@ -584,7 +582,7 @@ the asynchronous call is complete.
584582
[Jinja]: https://jinja.palletsprojects.com/en/stable/ "jinja prompt template"
585583
[PromptTemplate]: https://github.com/google-ai-edge/LiteRT-LM/blob/main/runtime/components/prompt_template.h "litert::lm::PromptTemplate"
586584
[message]: https://github.com/google-ai-edge/LiteRT-LM/blob/63f7dec93ac85560e64194a00b5d7c407de40846/runtime/conversation/io_types.h#L28 "litert::lm::Message"
587-
[jsonmessage]: https://github.com/google-ai-edge/LiteRT-LM/blob/63f7dec93ac85560e64194a00b5d7c407de40846/runtime/conversation/io_types.h#L25 "litert::lm::JsonMessage"
585+
588586
[ordered_json]: https://json.nlohmann.me/api/ordered_json/ "ordered_json"
589587
[preface]: https://github.com/google-ai-edge/LiteRT-LM/blob/63f7dec93ac85560e64194a00b5d7c407de40846/runtime/conversation/io_types.h#L48 "litert::lm::Preface"
590588
[ConversationConfig]: https://github.com/google-ai-edge/LiteRT-LM/blob/63f7dec93ac85560e64194a00b5d7c407de40846/runtime/conversation/conversation.h#L44 "litert::lm::ConversationConfig"

docs/api/cpp/tool-use.md

Lines changed: 37 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ Example:
143143
144144
```c++
145145
// Construct the user message as a JSON object.
146-
JsonMessage user_message = JsonMessage::parse(R"({
146+
Message user_message = Message::parse(R"({
147147
"role": "user",
148148
"content": {
149149
"type": "text",
@@ -234,7 +234,7 @@ know the result. Pass the tool result as a message with the `role` set to
234234

235235
```c++
236236
// Construct the tool message containing the result.
237-
JsonMessage tool_message = {{"role", "tool"}, {"content", weather_report}};
237+
Message tool_message = {{"role", "tool"}, {"content", weather_report}};
238238

239239
// Send the tool message to the model.
240240
ASSIGN_OR_RETURN(model_message, conversation->SendMessage(tool_message));
@@ -362,7 +362,7 @@ while (true) {
362362
}
363363
364364
// Construct the user message.
365-
JsonMessage input_message = {
365+
Message input_message = {
366366
{"role", "user"},
367367
{"content", {{{"type", "text"}, {"text", input_prompt}}}}};
368368
@@ -372,41 +372,35 @@ while (true) {
372372
ASSIGN_OR_RETURN(Message message,
373373
conversation->SendMessage(input_message));
374374
375-
// Get the JSON message from the model's response.
376-
if (std::holds_alternative<json>(message)) {
377-
JsonMessage message_json =
378-
std::get<nlohmann::ordered_json>(message);
379-
380-
// Check for tool calls.
381-
if (message_json.contains("tool_calls") &&
382-
message_json["tool_calls"].is_array() &&
383-
!message_json["tool_calls"].empty()) {
384-
// This JSON array will hold the tool response messages.
385-
nlohmann::ordered_json tool_messages = nlohmann::ordered_json::array();
386-
387-
// For each tool call, call the tool and add the response.
388-
for (const auto& tool_call : message_json["tool_calls"]) {
389-
JsonMessage tool_message = {{"role", "tool"},
390-
{"content", {}}};
391-
const nlohmann::ordered_json& function = tool_call["function"];
392-
tool_message["content"] =
393-
tools.CallTool(function["name"], function["arguments"]);
394-
tool_messages.push_back(tool_message);
395-
}
375+
// Check for tool calls.
376+
if (message.contains("tool_calls") &&
377+
message["tool_calls"].is_array() &&
378+
!message["tool_calls"].empty()) {
379+
// This JSON array will hold the tool response messages.
380+
nlohmann::ordered_json tool_messages = nlohmann::ordered_json::array();
381+
382+
// For each tool call, call the tool and add the response.
383+
for (const auto& tool_call : message["tool_calls"]) {
384+
Message tool_message = {{"role", "tool"},
385+
{"content", {}}};
386+
const nlohmann::ordered_json& function = tool_call["function"];
387+
tool_message["content"] =
388+
tools.CallTool(function["name"], function["arguments"]);
389+
tool_messages.push_back(tool_message);
390+
}
396391
397-
// The next input message is the tool response.
398-
input_message = tool_messages;
399-
} else {
400-
// If there are no tool calls, print the model's response and exit the
401-
// tool calling loop.
402-
for (const auto& item : message_json["content"]) {
403-
if (item.contains("type") && item["type"] == "text") {
404-
std::cout << item["text"].get<std::string>() << std::endl;
405-
}
392+
// The next input message is the tool response.
393+
input_message = tool_messages;
394+
} else {
395+
// If there are no tool calls, print the model's response and exit the
396+
// tool calling loop.
397+
for (const auto& item : message["content"]) {
398+
if (item.contains("type") && item["type"] == "text") {
399+
std::cout << item["text"].get<std::string>() << std::endl;
406400
}
407-
408-
break;
409401
}
402+
403+
break;
410404
}
411405
}
412406
}
@@ -449,35 +443,28 @@ while (true) {
449443
return;
450444
}
451445

452-
if (!std::holds_alternative<nlohmann::json>(*message)) {
453-
return;
454-
}
455-
456-
// Get JSON from the message.
457-
JsonMessage message_json = std::get<JsonMessage>(*message);
458-
459446
// An empty message indicates the model is done generating.
460-
if (message_json.is_null()) {
447+
if (message->empty()) {
461448
std::cout << std::endl << std::flush;
462449
done.Notify();
463450
return;
464451
}
465452

466453
// Print any text content.
467-
if (message_json.contains("content") &&
468-
message_json["content"].is_array()) {
469-
for (const auto& item : message_json["content"]) {
454+
if (message->contains("content") &&
455+
(*message)["content"].is_array()) {
456+
for (const auto& item : (*message)["content"]) {
470457
if (item.contains("text")) {
471458
std::cout << item["text"] << std::endl << std::flush;
472459
}
473460
}
474461
}
475462

476463
// Collect any tool calls, if present.
477-
if (message_json.contains("tool_calls") &&
478-
message_json["tool_calls"].is_array() &&
479-
!message_json["tool_calls"].empty()) {
480-
for (const auto& tool_call : message_json["tool_calls"]) {
464+
if (message->contains("tool_calls") &&
465+
(*message)["tool_calls"].is_array() &&
466+
!(*message)["tool_calls"].empty()) {
467+
for (const auto& tool_call : (*message)["tool_calls"]) {
481468
tool_calls.push_back(tool_call);
482469
}
483470
}

kotlin/java/com/google/ai/edge/litertlm/jni/litertlm.cc

Lines changed: 17 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ using litert::lm::InputAudio;
6969
using litert::lm::InputData;
7070
using litert::lm::InputImage;
7171
using litert::lm::InputText;
72-
using litert::lm::JsonMessage;
72+
7373
using litert::lm::JsonPreface;
7474
using litert::lm::Message;
7575
using litert::lm::ModelAssets;
@@ -939,8 +939,8 @@ LITERTLM_JNIEXPORT void JNICALL JNI_METHOD(nativeSendMessageAsync)(
939939
reinterpret_cast<Conversation*>(conversation_pointer);
940940

941941
const char* json_chars = env->GetStringUTFChars(messageJSONString, nullptr);
942-
litert::lm::JsonMessage json_message =
943-
nlohmann::ordered_json::parse(json_chars);
942+
litert::lm::Message message =
943+
Message(nlohmann::ordered_json::parse(json_chars));
944944
env->ReleaseStringUTFChars(messageJSONString, json_chars);
945945

946946
litert::lm::OptionalArgs optional_args;
@@ -996,27 +996,16 @@ LITERTLM_JNIEXPORT void JNICALL JNI_METHOD(nativeSendMessageAsync)(
996996
};
997997

998998
if (message.ok()) {
999-
if (!std::holds_alternative<litert::lm::JsonMessage>(*message)) {
1000-
ABSL_LOG(WARNING) << "Receive callback OnError: Not a JsonMessage";
1001-
jstring err_message =
1002-
env->NewStringUTF("Response is not a JsonMessage");
1003-
env->CallVoidMethod(callback_global, on_error_mid,
1004-
(jint)absl::StatusCode::kInternal, err_message);
1005-
env->DeleteLocalRef(err_message);
999+
if (message->empty()) {
1000+
// Null/empty message indicates completion.
1001+
env->CallVoidMethod(callback_global, on_complete_mid);
10061002
on_done_fn();
10071003
} else {
1008-
auto json_message = std::get<litert::lm::JsonMessage>(*message);
1009-
if (json_message.is_null()) {
1010-
// Null message indicates completion.
1011-
env->CallVoidMethod(callback_global, on_complete_mid);
1012-
on_done_fn();
1013-
} else {
1014-
std::string message_str = json_message.dump();
1015-
jstring message_jstr = NewStringStandardUTF(env, message_str);
1016-
env->CallVoidMethod(callback_global, on_message_mid,
1017-
message_jstr);
1018-
env->DeleteLocalRef(message_jstr);
1019-
}
1004+
std::string message_str = message->dump();
1005+
jstring message_jstr = NewStringStandardUTF(env, message_str);
1006+
env->CallVoidMethod(callback_global, on_message_mid,
1007+
message_jstr);
1008+
env->DeleteLocalRef(message_jstr);
10201009
}
10211010
} else {
10221011
ABSL_LOG(WARNING) << "Receive callback OnError: " << message.status();
@@ -1033,8 +1022,8 @@ LITERTLM_JNIEXPORT void JNICALL JNI_METHOD(nativeSendMessageAsync)(
10331022
}
10341023
};
10351024

1036-
auto status = conversation->SendMessageAsync(
1037-
json_message, std::move(callback_fn), std::move(optional_args));
1025+
auto status = conversation->SendMessageAsync(message, std::move(callback_fn),
1026+
std::move(optional_args));
10381027

10391028
if (!status.ok()) {
10401029
ThrowLiteRtLmJniException(
@@ -1049,8 +1038,8 @@ LITERTLM_JNIEXPORT jstring JNICALL JNI_METHOD(nativeSendMessage)(
10491038
reinterpret_cast<Conversation*>(conversation_pointer);
10501039

10511040
const char* json_chars = env->GetStringUTFChars(messageJSONString, nullptr);
1052-
litert::lm::JsonMessage json_message =
1053-
nlohmann::ordered_json::parse(json_chars);
1041+
litert::lm::Message message =
1042+
Message(nlohmann::ordered_json::parse(json_chars));
10541043
env->ReleaseStringUTFChars(messageJSONString, json_chars);
10551044

10561045
litert::lm::OptionalArgs optional_args;
@@ -1060,22 +1049,14 @@ LITERTLM_JNIEXPORT jstring JNICALL JNI_METHOD(nativeSendMessage)(
10601049
optional_args.extra_context = extra_context;
10611050
}
10621051

1063-
auto response =
1064-
conversation->SendMessage(json_message, std::move(optional_args));
1052+
auto response = conversation->SendMessage(message, std::move(optional_args));
10651053
if (!response.ok()) {
10661054
ThrowLiteRtLmJniException(env, "Failed to call nativeSendMessage: " +
10671055
response.status().ToString());
10681056
return nullptr;
10691057
}
10701058

1071-
if (!std::holds_alternative<litert::lm::JsonMessage>(*response)) {
1072-
ThrowLiteRtLmJniException(
1073-
env, "Failed to call nativeSendMessage: Response is not a JsonMessage");
1074-
return nullptr;
1075-
}
1076-
1077-
auto json_response = std::get<litert::lm::JsonMessage>(*response);
1078-
return NewStringStandardUTF(env, json_response.dump());
1059+
return NewStringStandardUTF(env, response->dump());
10791060
}
10801061

10811062
LITERTLM_JNIEXPORT void JNICALL JNI_METHOD(nativeConversationCancelProcess)(

0 commit comments

Comments
 (0)