Skip to content

Commit a10856e

Browse files
matthewchan-g and copybara-github
authored and committed
Add support for configuring model response channels in LiteRT-LM Kotlin API ConversationConfig.
PiperOrigin-RevId: 888850632
1 parent 225f7a8 commit a10856e

File tree

5 files changed

+67
-5
lines changed

5 files changed

+67
-5
lines changed

kotlin/java/com/google/ai/edge/litertlm/Benchmark.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ fun benchmark(
7878
null, // SamplerConfig
7979
"[]", // messagesJsonString
8080
"[]", // toolsDescriptionJsonString
81+
null, // channelsJsonString
8182
false, // enableConversationConstrainedDecoding
8283
)
8384

kotlin/java/com/google/ai/edge/litertlm/Config.kt

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,25 @@
1515
*/
1616
package com.google.ai.edge.litertlm
1717

18+
import com.google.gson.JsonObject
19+
20+
/**
21+
* Definition of a channel for responses, e.g. thinking channel.
22+
*
23+
* @property channelName The channel name. Text from this channel will be written to
24+
* [Message.channels] with the [channelName] as the key.
25+
* @property start A string that marks the start of the channel.
26+
* @property end A string that marks the end of the channel.
27+
*/
28+
data class Channel(val channelName: String, val start: String, val end: String) {
29+
internal fun toJson() =
30+
JsonObject().apply {
31+
addProperty("channel_name", channelName)
32+
addProperty("start", start)
33+
addProperty("end", end)
34+
}
35+
}
36+
1837
/**
1938
* Backend for the LiteRT-LM engine.
2039
*
@@ -96,6 +115,7 @@ data class ConversationConfig(
96115
val tools: List<ToolProvider> = listOf(),
97116
val samplerConfig: SamplerConfig? = null,
98117
val automaticToolCalling: Boolean = true,
118+
val channels: List<Channel>? = null,
99119
)
100120

101121
/**

kotlin/java/com/google/ai/edge/litertlm/Engine.kt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,27 @@ class Engine(val engineConfig: EngineConfig) : AutoCloseable {
126126
}
127127
}
128128

129+
// Convert the channels to a JSON array, if provided.
130+
// If `channels` is null, the `Conversation` uses the default from the
131+
// `LlmMetadata` or the model type.
132+
// If channels is empty, channels will be disabled.
133+
val channelsJson: JsonArray? =
134+
conversationConfig.channels?.let { channels ->
135+
JsonArray().apply {
136+
for (channel in channels) {
137+
this.add(channel.toJson())
138+
}
139+
}
140+
}
141+
129142
@OptIn(ExperimentalApi::class) // opt-in experimental flags
130143
return Conversation(
131144
LiteRtLmJni.nativeCreateConversation(
132145
handle!!, // Using !! is okay. Checked initialization already.
133146
conversationConfig.samplerConfig,
134147
messagesJson.toString(),
135148
toolManager.getToolsDescription().toString(),
149+
channelsJson?.toString(),
136150
ExperimentalFlags.enableConversationConstrainedDecoding,
137151
),
138152
toolManager,

kotlin/java/com/google/ai/edge/litertlm/LiteRtLmJni.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,8 @@ internal object LiteRtLmJni {
186186
* @param systemMessageJsonString The system instruction to be used in the conversation.
187187
* @param toolsDescriptionJsonString A json string of a list of tool definitions (Open API json).
188188
* could be used.
189+
* @param channelsJsonString A json string of a list of channel definitions. If null, use the
190+
* default from the model or engine. If empty, channels will be disabled.
189191
* @param enableConversationConstrainedDecoding Whether to enable conversation constrained
190192
* decoding.
191193
* @return A pointer to the native conversation instance.
@@ -195,6 +197,7 @@ internal object LiteRtLmJni {
195197
samplerConfig: SamplerConfig?,
196198
messageJsonString: String,
197199
toolsDescriptionJsonString: String,
200+
channelsJsonString: String?,
198201
enableConversationConstrainedDecoding: Boolean,
199202
): Long
200203

kotlin/java/com/google/ai/edge/litertlm/jni/litertlm.cc

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -829,7 +829,7 @@ JNI_METHOD(nativeConversationGetBenchmarkInfo)(JNIEnv* env, jclass thiz,
829829
LITERTLM_JNIEXPORT jlong JNICALL JNI_METHOD(nativeCreateConversation)(
830830
JNIEnv* env, jclass thiz, jlong engine_pointer, jobject sampler_config_obj,
831831
jstring messages_json_string, jstring tools_description_json_string,
832-
jboolean enable_constrained_decoding) {
832+
jstring channels_json_string, jboolean enable_constrained_decoding) {
833833
Engine* engine = reinterpret_cast<Engine*>(engine_pointer);
834834

835835
// Create a native SessionConfig
@@ -864,13 +864,37 @@ LITERTLM_JNIEXPORT jlong JNICALL JNI_METHOD(nativeCreateConversation)(
864864
return 0;
865865
}
866866

867-
// Create the conversation
868-
auto conversation_config =
867+
// Create a ConversationConfig::Builder
868+
auto conversation_config_builder =
869869
ConversationConfig::Builder()
870870
.SetSessionConfig(session_config)
871871
.SetPreface(json_preface)
872-
.SetEnableConstrainedDecoding(enable_constrained_decoding)
873-
.Build(*engine);
872+
.SetEnableConstrainedDecoding(enable_constrained_decoding);
873+
874+
// Set the channels, if provided.
875+
// If channels is nullptr, the Conversation will use the channels defined in
876+
// the LlmMetadata or the default channels for the model type.
877+
// If channels is an empty array, channels will be disabled.
878+
if (channels_json_string != nullptr) {
879+
const char* channels_chars =
880+
env->GetStringUTFChars(channels_json_string, nullptr);
881+
std::string channels_json_str(channels_chars);
882+
env->ReleaseStringUTFChars(channels_json_string, channels_chars);
883+
auto channels_json = nlohmann::ordered_json::parse(channels_json_str);
884+
885+
std::vector<litert::lm::Channel> channels;
886+
if (channels_json.is_array()) {
887+
for (const auto& channel_item : channels_json) {
888+
channels.push_back({channel_item["channel_name"].get<std::string>(),
889+
channel_item["start"].get<std::string>(),
890+
channel_item["end"].get<std::string>()});
891+
}
892+
}
893+
conversation_config_builder.SetChannels(channels);
894+
}
895+
896+
// Build the conversation
897+
auto conversation_config = conversation_config_builder.Build(*engine);
874898

875899
if (!conversation_config.ok()) {
876900
ThrowLiteRtLmJniException(env, "Failed to create conversation config: " +

0 commit comments

Comments
 (0)