Skip to content

Commit 9739359

Browse files
matthewchan-g and copybara-github
authored and committed
Add extraContext to Conversation sendMessage methods.
LiteRT-LM
PiperOrigin-RevId: 879915886
1 parent 28f1827 commit 9739359

13 files changed

Lines changed: 334 additions & 39 deletions

File tree

docs/api/kotlin/getting_started.md

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -166,15 +166,15 @@ engine.createConversation(conversationConfig).use { conversation ->
166166

167167
There are three ways to send messages:
168168

169-
- **`sendMessage(contents): Message`**: Synchronous call that blocks until the
170-
model returns a complete response. This is simpler for basic
171-
request/response interactions.
172-
- **`sendMessageAsync(contents, callback)`**: Asynchronous call for streaming
173-
responses. This is better for long-running requests or when you want to
174-
display the response as it's being generated.
175-
- **`sendMessageAsync(contents): Flow<Message>`**: Asynchronous call that
176-
returns a Kotlin Flow for streaming responses. This is the recommended
177-
approach for Coroutine users.
169+
- **`sendMessage(contents, extraContext): Message`**: Synchronous call that
170+
blocks until the model returns a complete response. This is simpler for
171+
basic request/response interactions.
172+
- **`sendMessageAsync(contents, callback, extraContext)`**: Asynchronous call
173+
for streaming responses. This is better for long-running requests or when
174+
you want to display the response as it's being generated.
175+
- **`sendMessageAsync(contents, extraContext): Flow<Message>`**: Asynchronous
176+
call that returns a Kotlin Flow for streaming responses. This is the
177+
recommended approach for Coroutine users.
178178

179179
**Synchronous Example:**
180180

@@ -456,6 +456,33 @@ To try out tool use, clone the repo and run with
456456
bazel run -c opt //kotlin/java/com/google/ai/edge/litertlm/example:tool -- <abs_model_path>
457457
```
458458

459+
### 7. Extra Template Context Variables
460+
461+
You can pass extra context variables to the prompt template for rendering.
462+
This allows you to customize the model's behavior based on dynamic values.
463+
464+
`extraContext` is an optional `Map<String, Any>` that can be passed to
465+
`sendMessage` and `sendMessageAsync`. These variables are merged with the extra
466+
context provided in the `Preface` (if any), with keys in the message-level
467+
context overwriting those in the `Preface`.
468+
469+
```kotlin
470+
val extraContext = mapOf(
471+
"user_name" to "Alice",
472+
"enable_thinking" to true
473+
)
474+
475+
// Synchronous
476+
val response = conversation.sendMessage("Hello!", extraContext = extraContext)
477+
478+
// Asynchronous with Flow
479+
conversation.sendMessageAsync("Hello!", extraContext = extraContext)
480+
.collect { ... }
481+
```
482+
483+
These variables are used within the Jinja-style prompt templates, e.g.,
484+
`{{ user_name }}` or `{% if enable_thinking %}`.
485+
459486
## Error Handling
460487

461488
API methods can throw `LiteRtLmJniException` for errors from the native layer or

kotlin/java/com/google/ai/edge/litertlm/Conversation.kt

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -86,18 +86,21 @@ class Conversation(
8686
* [RECURRING_TOOL_CALL_LIMIT] times.
8787
*
8888
* @param message The message to send to the model.
89+
* @param extraContext Optional context used for prompt template rendering.
8990
* @return The model's response message.
9091
* @throws IllegalStateException if the conversation is not alive, if the native layer returns an
9192
* invalid response, or if the tool call limit is exceeded.
9293
* @throws LiteRtLmJniException if an error occurs during the native call.
9394
*/
94-
fun sendMessage(message: Message): Message {
95+
fun sendMessage(message: Message, extraContext: Map<String, Any> = emptyMap()): Message {
9596
checkIsAlive()
9697

9798
var currentMessageJson = message.toJson()
99+
var extraContextJsonString = extraContext.toJsonObject().toString()
98100

99101
for (i in 0..<RECURRING_TOOL_CALL_LIMIT) {
100-
val responseJsonString = LiteRtLmJni.nativeSendMessage(handle, currentMessageJson.toString())
102+
val responseJsonString =
103+
LiteRtLmJni.nativeSendMessage(handle, currentMessageJson.toString(), extraContextJsonString)
101104
val responseJsonObject = JsonParser.parseString(responseJsonString).asJsonObject
102105

103106
if (responseJsonObject.has("tool_calls")) {
@@ -124,13 +127,14 @@ class Conversation(
124127
* [RECURRING_TOOL_CALL_LIMIT] times.
125128
*
126129
* @param contents The list of contents to send to the model.
130+
* @param extraContext Optional context used for prompt template rendering.
127131
* @return The model's response message.
128132
* @throws IllegalStateException if the conversation is not alive, if the native layer returns an
129133
* invalid response, or if the tool call limit is exceeded.
130134
* @throws LiteRtLmJniException if an error occurs during the native call.
131135
*/
132-
fun sendMessage(contents: Contents): Message {
133-
return sendMessage(Message.user(contents))
136+
fun sendMessage(contents: Contents, extraContext: Map<String, Any> = emptyMap()): Message {
137+
return sendMessage(Message.user(contents), extraContext)
134138
}
135139

136140
/**
@@ -142,12 +146,14 @@ class Conversation(
142146
* [RECURRING_TOOL_CALL_LIMIT] times.
143147
*
144148
* @param text The text to send to the model.
149+
* @param extraContext Optional context used for prompt template rendering.
145150
* @return The model's response message.
146151
* @throws IllegalStateException if the conversation is not alive, if the native layer returns an
147152
* invalid response, or if the tool call limit is exceeded.
148153
* @throws LiteRtLmJniException if an error occurs during the native call.
149154
*/
150-
fun sendMessage(text: String): Message = sendMessage(Contents.of(text))
155+
fun sendMessage(text: String, extraContext: Map<String, Any> = emptyMap()): Message =
156+
sendMessage(Contents.of(text), extraContext)
151157

152158
/**
153159
* Send a message to the model and returns the response async with a callback.
@@ -159,14 +165,26 @@ class Conversation(
159165
*
160166
* @param message The message to send to the model.
161167
* @param callback The callback to receive the streaming responses.
168+
* @param extraContext Optional context used for prompt template rendering.
162169
* @throws IllegalStateException if the conversation has already been closed or the content is
163170
* empty.
164171
*/
165-
fun sendMessageAsync(message: Message, callback: MessageCallback) {
172+
fun sendMessageAsync(
173+
message: Message,
174+
callback: MessageCallback,
175+
extraContext: Map<String, Any> = emptyMap(),
176+
) {
166177
checkIsAlive()
167178

179+
val extraContextJsonString = extraContext.toJsonObject().toString()
180+
168181
val jniCallback = JniMessageCallbackImpl(callback)
169-
LiteRtLmJni.nativeSendMessageAsync(handle, message.toJson().toString(), jniCallback)
182+
LiteRtLmJni.nativeSendMessageAsync(
183+
handle,
184+
message.toJson().toString(),
185+
extraContextJsonString,
186+
jniCallback,
187+
)
170188
}
171189

172190
/**
@@ -179,11 +197,15 @@ class Conversation(
179197
*
180198
* @param contents The list of contents to send to the model.
181199
* @param callback The callback to receive the streaming responses.
200+
* @param extraContext Optional context used for prompt template rendering.
182201
* @throws IllegalStateException if the conversation has already been closed or the content is
183202
* empty.
184203
*/
185-
fun sendMessageAsync(contents: Contents, callback: MessageCallback) =
186-
sendMessageAsync(Message.user(contents), callback)
204+
fun sendMessageAsync(
205+
contents: Contents,
206+
callback: MessageCallback,
207+
extraContext: Map<String, Any> = emptyMap(),
208+
) = sendMessageAsync(Message.user(contents), callback, extraContext)
187209

188210
/**
189211
* Send a text to the model and returns the response async with a callback.
@@ -195,11 +217,15 @@ class Conversation(
195217
*
196218
* @param text The text to send to the model.
197219
* @param callback The callback to receive the streaming responses.
220+
* @param extraContext Optional context used for prompt template rendering.
198221
* @throws IllegalStateException if the conversation has already been closed or the content is
199222
* empty.
200223
*/
201-
fun sendMessageAsync(text: String, callback: MessageCallback) =
202-
sendMessageAsync(Contents.of(text), callback)
224+
fun sendMessageAsync(
225+
text: String,
226+
callback: MessageCallback,
227+
extraContext: Map<String, Any> = emptyMap(),
228+
) = sendMessageAsync(Contents.of(text), callback, extraContext)
203229

204230
/**
205231
* Sends a message to the model and returns the response async as a [Flow].
@@ -210,11 +236,15 @@ class Conversation(
210236
* [RECURRING_TOOL_CALL_LIMIT] times.
211237
*
212238
* @param message The message to send to the model.
239+
* @param extraContext Optional context used for prompt template rendering.
213240
* @return A Flow of messages representing the model's response.
214241
* @throws IllegalStateException if the conversation has already been closed or the content is
215242
* empty.
216243
*/
217-
fun sendMessageAsync(message: Message): Flow<Message> = callbackFlow {
244+
fun sendMessageAsync(
245+
message: Message,
246+
extraContext: Map<String, Any> = emptyMap(),
247+
): Flow<Message> = callbackFlow {
218248
sendMessageAsync(
219249
message,
220250
object : MessageCallback {
@@ -230,6 +260,7 @@ class Conversation(
230260
close(throwable)
231261
}
232262
},
263+
extraContext,
233264
)
234265
awaitClose {}
235266
}
@@ -243,11 +274,15 @@ class Conversation(
243274
* [RECURRING_TOOL_CALL_LIMIT] times.
244275
*
245276
* @param contents The list of contents to send to the model.
277+
* @param extraContext Optional context used for prompt template rendering.
246278
* @return A Flow of messages representing the model's response.
247279
* @throws IllegalStateException if the conversation has already been closed or the content is
248280
* empty.
249281
*/
250-
fun sendMessageAsync(contents: Contents): Flow<Message> = sendMessageAsync(Message.user(contents))
282+
fun sendMessageAsync(
283+
contents: Contents,
284+
extraContext: Map<String, Any> = emptyMap(),
285+
): Flow<Message> = sendMessageAsync(Message.user(contents), extraContext)
251286

252287
/**
253288
* Sends a text to the model and returns the response async as a [Flow].
@@ -258,11 +293,13 @@ class Conversation(
258293
* [RECURRING_TOOL_CALL_LIMIT] times.
259294
*
260295
* @param text The text to send to the model.
296+
* @param extraContext Optional context used for prompt template rendering.
261297
* @return A Flow of messages representing the model's response.
262298
* @throws IllegalStateException if the conversation has already been closed or the content is
263299
* empty.
264300
*/
265-
fun sendMessageAsync(text: String): Flow<Message> = sendMessageAsync(Contents.of(text))
301+
fun sendMessageAsync(text: String, extraContext: Map<String, Any> = emptyMap()): Flow<Message> =
302+
sendMessageAsync(Contents.of(text), extraContext)
266303

267304
private fun handleToolCalls(toolCallsJsonObject: JsonObject): JsonObject {
268305
val toolCallsJSONArray = toolCallsJsonObject.getAsJsonArray("tool_calls")
@@ -328,6 +365,7 @@ class Conversation(
328365
LiteRtLmJni.nativeSendMessageAsync(
329366
handle,
330367
localToolResponse.toString(),
368+
"{}",
331369
this@JniMessageCallbackImpl,
332370
)
333371
pendingToolResponseJSONMessage = null // Clear after sending

kotlin/java/com/google/ai/edge/litertlm/LiteRtLmJni.kt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ internal object LiteRtLmJni {
215215
external fun nativeSendMessageAsync(
216216
conversationPointer: Long,
217217
messageJsonString: String,
218+
extraContextJsonString: String,
218219
callback: JniMessageCallback,
219220
)
220221

@@ -225,7 +226,11 @@ internal object LiteRtLmJni {
225226
* @param messageJsonString The message to be processed by the native conversation instance.
226227
* @return The response message in JSON string format.
227228
*/
228-
external fun nativeSendMessage(conversationPointer: Long, messageJsonString: String): String
229+
external fun nativeSendMessage(
230+
conversationPointer: Long,
231+
messageJsonString: String,
232+
extraContextJsonString: String,
233+
): String
229234

230235
/**
231236
* Cancels the ongoing conversation process.

kotlin/java/com/google/ai/edge/litertlm/jni/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ cc_binary(
2525
linkshared = 1,
2626
deps = [
2727
"@com_google_absl//absl/base:log_severity",
28+
"@com_google_absl//absl/container:flat_hash_map",
2829
"@com_google_absl//absl/functional:any_invocable",
2930
"@com_google_absl//absl/log:absl_log",
3031
"@com_google_absl//absl/log:globals",

kotlin/java/com/google/ai/edge/litertlm/jni/litertlm.cc

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,19 @@ SamplerParameters CreateSamplerParamsFromJni(JNIEnv* env,
295295

296296
return sampler_params;
297297
}
298+
299+
nlohmann::ordered_json GetExtraContextJson(JNIEnv* env,
300+
jstring extra_context_json_string) {
301+
const char* extra_context_chars =
302+
env->GetStringUTFChars(extra_context_json_string, nullptr);
303+
nlohmann::ordered_json extra_context_json;
304+
if (extra_context_chars != nullptr) {
305+
extra_context_json = nlohmann::ordered_json::parse(extra_context_chars);
306+
}
307+
env->ReleaseStringUTFChars(extra_context_json_string, extra_context_chars);
308+
return extra_context_json;
309+
}
310+
298311
} // namespace
299312

300313
extern "C" {
@@ -851,7 +864,8 @@ LITERTLM_JNIEXPORT void JNICALL JNI_METHOD(nativeDeleteConversation)(
851864

852865
LITERTLM_JNIEXPORT void JNICALL JNI_METHOD(nativeSendMessageAsync)(
853866
JNIEnv* env, jclass thiz, jlong conversation_pointer,
854-
jstring messageJSONString, jobject callback) {
867+
jstring messageJSONString, jstring extraContextJsonString,
868+
jobject callback) {
855869
JavaVM* jvm = nullptr;
856870
if (env->GetJavaVM(&jvm) != JNI_OK) {
857871
ThrowLiteRtLmJniException(env, "Failed to get JavaVM");
@@ -866,6 +880,13 @@ LITERTLM_JNIEXPORT void JNICALL JNI_METHOD(nativeSendMessageAsync)(
866880
nlohmann::ordered_json::parse(json_chars);
867881
env->ReleaseStringUTFChars(messageJSONString, json_chars);
868882

883+
litert::lm::OptionalArgs optional_args;
884+
nlohmann::ordered_json extra_context =
885+
GetExtraContextJson(env, extraContextJsonString);
886+
if (!extra_context.is_null() && !extra_context.empty()) {
887+
optional_args.extra_context = extra_context;
888+
}
889+
869890
jobject callback_global = env->NewGlobalRef(callback);
870891
jclass callback_class = env->GetObjectClass(callback_global);
871892
jmethodID on_message_mid =
@@ -932,8 +953,8 @@ LITERTLM_JNIEXPORT void JNICALL JNI_METHOD(nativeSendMessageAsync)(
932953
}
933954
};
934955

935-
auto status =
936-
conversation->SendMessageAsync(json_message, std::move(callback_fn));
956+
auto status = conversation->SendMessageAsync(
957+
json_message, std::move(callback_fn), std::move(optional_args));
937958

938959
if (!status.ok()) {
939960
ThrowLiteRtLmJniException(
@@ -943,7 +964,7 @@ LITERTLM_JNIEXPORT void JNICALL JNI_METHOD(nativeSendMessageAsync)(
943964

944965
LITERTLM_JNIEXPORT jstring JNICALL JNI_METHOD(nativeSendMessage)(
945966
JNIEnv* env, jclass thiz, jlong conversation_pointer,
946-
jstring messageJSONString) {
967+
jstring messageJSONString, jstring extraContextJsonString) {
947968
Conversation* conversation =
948969
reinterpret_cast<Conversation*>(conversation_pointer);
949970

@@ -952,7 +973,15 @@ LITERTLM_JNIEXPORT jstring JNICALL JNI_METHOD(nativeSendMessage)(
952973
nlohmann::ordered_json::parse(json_chars);
953974
env->ReleaseStringUTFChars(messageJSONString, json_chars);
954975

955-
auto response = conversation->SendMessage(json_message);
976+
litert::lm::OptionalArgs optional_args;
977+
nlohmann::ordered_json extra_context =
978+
GetExtraContextJson(env, extraContextJsonString);
979+
if (!extra_context.is_null() && !extra_context.empty()) {
980+
optional_args.extra_context = extra_context;
981+
}
982+
983+
auto response =
984+
conversation->SendMessage(json_message, std::move(optional_args));
956985
if (!response.ok()) {
957986
ThrowLiteRtLmJniException(env, "Failed to call nativeSendMessage: " +
958987
response.status().ToString());

runtime/conversation/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,6 @@ cc_library(
106106
"@com_google_absl//absl/time",
107107
"@nlohmann_json//:json",
108108
"//runtime/components:prompt_template",
109-
"//runtime/components:tokenizer",
110109
"//runtime/components/constrained_decoding:constraint",
111110
"//runtime/components/constrained_decoding:constraint_provider",
112111
"//runtime/components/constrained_decoding:constraint_provider_config",
@@ -137,6 +136,7 @@ cc_test(
137136
":conversation",
138137
":io_types",
139138
"@com_google_googletest//:gtest_main",
139+
"@com_google_absl//absl/container:flat_hash_map",
140140
"@com_google_absl//absl/functional:any_invocable",
141141
"@com_google_absl//absl/status",
142142
"@com_google_absl//absl/status:statusor",

0 commit comments

Comments
 (0)