Skip to content

Commit 65f07c3

Browse files
ai-edge-bot authored and copybara-github committed
Internal changes and clean up.
LiteRT-LM-PiperOrigin-RevId: 890153576
1 parent 79fbcf1 commit 65f07c3

19 files changed

+1770
-613
lines changed

runtime/components/embedding_lookup/BUILD

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -63,6 +63,7 @@ cc_library(
6363
"@litert//litert/cc:litert_options",
6464
"@litert//litert/cc:litert_ranked_tensor_type",
6565
"@litert//litert/cc:litert_tensor_buffer",
66+
"@litert//litert/cc/options:litert_qualcomm_options",
6667
],
6768
}),
6869
)
@@ -127,6 +128,8 @@ cc_library(
127128
"@litert//litert/cc:litert_model",
128129
"@litert//litert/cc:litert_options",
129130
"@litert//litert/cc:litert_tensor_buffer",
131+
"@litert//litert/cc/internal:litert_handle",
132+
"@litert//litert/cc/options:litert_qualcomm_options",
130133
],
131134
}),
132135
)
@@ -158,6 +161,7 @@ cc_library(
158161
"@litert//litert/cc:litert_api_with_dynamic_runtime",
159162
],
160163
"//conditions:default": [
164+
"@litert//litert/cc:litert_environment",
161165
"@litert//litert/cc:litert_model",
162166
"@litert//litert/cc:litert_tensor_buffer",
163167
],

runtime/components/embedding_lookup/embedding_lookup_end_of_multi_modal.cc

Lines changed: 32 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@
2727
#include "absl/status/statusor.h" // from @com_google_absl
2828
#include "absl/strings/str_cat.h" // from @com_google_absl
2929
#include "absl/types/span.h" // from @com_google_absl
30+
#include "litert/cc/internal/litert_handle.h" // from @litert
3031
#include "litert/cc/litert_common.h" // from @litert
3132
#include "litert/cc/litert_compiled_model.h" // from @litert
3233
#include "litert/cc/litert_element_type.h" // from @litert
@@ -36,6 +37,9 @@
3637
#include "litert/cc/litert_options.h" // from @litert
3738
#include "litert/cc/litert_tensor_buffer.h" // from @litert
3839
#include "runtime/util/status_macros.h" //NOLINT
40+
#if defined(__ANDROID__)
41+
#include "litert/cc/options/litert_qualcomm_options.h" // from @litert
42+
#endif
3943

4044
namespace litert::lm {
4145

@@ -149,18 +153,42 @@ absl::Status EndOfMultiModalEmbedding::LookupPrefill(
149153

150154
absl::StatusOr<std::unique_ptr<EndOfMultiModalEmbedding>>
151155
EndOfMultiModalEmbedding::Create(const litert::Model* absl_nonnull model,
152-
int special_token) {
153-
LITERT_ASSIGN_OR_RETURN(auto env, ::litert::Environment::Create({}));
154-
auto handler = std::unique_ptr<EndOfMultiModalEmbedding>(
155-
new EndOfMultiModalEmbedding(std::move(env), model, special_token));
156+
int special_token, litert::Environment* env) {
157+
if (env == nullptr) {
158+
LITERT_ASSIGN_OR_RETURN(auto local_env, ::litert::Environment::Create({}));
159+
auto handler =
160+
std::unique_ptr<EndOfMultiModalEmbedding>(new EndOfMultiModalEmbedding(
161+
std::move(local_env), model, special_token));
162+
RETURN_IF_ERROR(handler->Initialize());
163+
return handler;
164+
}
165+
auto handler =
166+
std::unique_ptr<EndOfMultiModalEmbedding>(new EndOfMultiModalEmbedding(
167+
::litert::Environment::WrapCObject(
168+
env->GetHolder(),
169+
::litert::OwnHandle::kNo), // NOLINT(build/include_what_you_use)
170+
model, special_token));
156171
RETURN_IF_ERROR( // IWYU pragma: keep as is included by status_macros.h
157172
handler->Initialize());
158173
return handler;
159174
}
160175

161176
absl::Status EndOfMultiModalEmbedding::Initialize() {
162177
LITERT_ASSIGN_OR_RETURN(auto options, Options::Create());
178+
#if defined(__ANDROID__)
179+
options.SetHardwareAccelerators(litert::HwAccelerators::kNpu |
180+
litert::HwAccelerators::kCpu);
181+
#else
163182
options.SetHardwareAccelerators(litert::HwAccelerators::kCpu);
183+
#endif
184+
#if defined(__ANDROID__)
185+
LITERT_ASSIGN_OR_RETURN(::litert::qualcomm::QualcommOptions & qnn_opts,
186+
options.GetQualcommOptions());
187+
qnn_opts.SetLogLevel(::litert::qualcomm::QualcommOptions::LogLevel::kOff);
188+
qnn_opts.SetHtpPerformanceMode(
189+
::litert::qualcomm::QualcommOptions::HtpPerformanceMode::
190+
kSustainedHighPerformance);
191+
#endif
164192

165193
LITERT_ASSIGN_OR_RETURN(
166194
litert::CompiledModel compiled_model,

runtime/components/embedding_lookup/embedding_lookup_end_of_multi_modal.h

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@
2929
#include "litert/cc/litert_environment.h" // from @litert
3030
#include "litert/cc/litert_layout.h" // from @litert
3131
#include "litert/cc/litert_model.h" // from @litert
32+
#include "litert/cc/litert_options.h" // from @litert
3233
#include "litert/cc/litert_tensor_buffer.h" // from @litert
3334
#include "runtime/components/embedding_lookup/embedding_lookup.h"
3435

@@ -45,7 +46,8 @@ class EndOfMultiModalEmbedding : public EmbeddingLookup {
4546
// multi-modal embedding. If the special token is not found in the tokens,
4647
// the end of multi-modal embedding will not be inserted.
4748
static absl::StatusOr<std::unique_ptr<EndOfMultiModalEmbedding>> Create(
48-
const litert::Model* absl_nonnull model, int special_token);
49+
const litert::Model* absl_nonnull model, int special_token,
50+
litert::Environment* env = nullptr);
4951

5052
// Multimodal embeddings are not supported during decode.
5153
absl::Status LookupDecode(int token,

runtime/components/embedding_lookup/embedding_lookup_manager.cc

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@
2929
#include "absl/status/statusor.h" // from @com_google_absl
3030
#include "absl/strings/string_view.h" // from @com_google_absl
3131
#include "absl/types/span.h" // from @com_google_absl
32+
#include "litert/cc/litert_environment.h" // from @litert
3233
#include "litert/cc/litert_macros.h" // from @litert
3334
#include "litert/cc/litert_model.h" // from @litert
3435
#include "litert/cc/litert_tensor_buffer.h" // from @litert
@@ -45,22 +46,24 @@ EmbeddingLookupManager::Create(
4546
const litert::Model* absl_nonnull text_embedding_model,
4647
absl::flat_hash_map<int, const litert::Model*>&
4748
end_of_multi_modal_embedding_models,
48-
bool fully_supports_multi_modal, std::optional<std::string> signature_key) {
49+
bool fully_supports_multi_modal, std::optional<std::string> signature_key,
50+
litert::Environment* env) {
4951
auto embedding_lookup_manager = std::make_unique<EmbeddingLookupManager>();
5052
RETURN_IF_ERROR(embedding_lookup_manager->Initialize(
5153
text_embedding_model, end_of_multi_modal_embedding_models,
52-
fully_supports_multi_modal, signature_key));
54+
fully_supports_multi_modal, signature_key, env));
5355
return std::move(embedding_lookup_manager);
5456
}
5557

5658
absl::StatusOr<std::unique_ptr<EmbeddingLookupManager>>
5759
EmbeddingLookupManager::Create(
5860
const litert::Model* absl_nonnull text_embedding_model,
59-
bool fully_supports_multi_modal, std::optional<std::string> signature_key) {
61+
bool fully_supports_multi_modal, std::optional<std::string> signature_key,
62+
litert::Environment* env) {
6063
absl::flat_hash_map<int, const litert::Model*>
6164
end_of_multi_modal_embedding_models;
6265
return Create(text_embedding_model, end_of_multi_modal_embedding_models,
63-
fully_supports_multi_modal, signature_key);
66+
fully_supports_multi_modal, signature_key, env);
6467
}
6568

6669
absl::Status EmbeddingLookupManager::UpdateMultiModalEmbeddings(
@@ -239,7 +242,8 @@ absl::Status EmbeddingLookupManager::Initialize(
239242
const litert::Model* absl_nonnull text_embedding_model,
240243
absl::flat_hash_map<int, const litert::Model*>&
241244
end_of_multi_modal_embedding_models,
242-
bool fully_supports_multi_modal, std::optional<std::string> signature_key) {
245+
bool fully_supports_multi_modal, std::optional<std::string> signature_key,
246+
litert::Environment* env) {
243247
if (!fully_supports_multi_modal &&
244248
!end_of_multi_modal_embedding_models.empty()) {
245249
return absl::InvalidArgumentError(
@@ -249,12 +253,12 @@ absl::Status EmbeddingLookupManager::Initialize(
249253
fully_supports_multi_modal_ = fully_supports_multi_modal;
250254
ASSIGN_OR_RETURN(text_embedding_lookup_,
251255
EmbeddingLookupText::Create(std::move(text_embedding_model),
252-
signature_key));
256+
signature_key, env));
253257
for (const auto& [special_token, embedding_model] :
254258
end_of_multi_modal_embedding_models) {
255259
ASSIGN_OR_RETURN(auto end_of_multi_modal_embedding_lookup,
256260
EndOfMultiModalEmbedding::Create(
257-
std::move(embedding_model), special_token));
261+
std::move(embedding_model), special_token, env));
258262
end_of_multi_modal_embedding_lookups_.push_back(
259263
std::move(end_of_multi_modal_embedding_lookup));
260264
}

runtime/components/embedding_lookup/embedding_lookup_manager.h

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -56,12 +56,14 @@ class EmbeddingLookupManager {
5656
absl::flat_hash_map<int, const litert::Model*>&
5757
end_of_multi_modal_embedding_models,
5858
bool fully_supports_multi_modal = true,
59-
std::optional<std::string> signature_key = std::nullopt);
59+
std::optional<std::string> signature_key = std::nullopt,
60+
litert::Environment* env = nullptr);
6061

6162
static absl::StatusOr<std::unique_ptr<EmbeddingLookupManager>> Create(
6263
const litert::Model* absl_nonnull text_embedding_model,
6364
bool fully_supports_multi_modal = true,
64-
std::optional<std::string> signature_key = std::nullopt);
65+
std::optional<std::string> signature_key = std::nullopt,
66+
litert::Environment* env = nullptr);
6567

6668
// Updates the multimodal embeddings for the given ExecutorInputs.
6769
// Intended to be called at the beginning of the prefill pass.
@@ -118,8 +120,8 @@ class EmbeddingLookupManager {
118120
const litert::Model* absl_nonnull text_embedding_model,
119121
absl::flat_hash_map<int, const litert::Model*>&
120122
end_of_multi_modal_embedding_models,
121-
bool fully_supports_multi_modal,
122-
std::optional<std::string> signature_key);
123+
bool fully_supports_multi_modal, std::optional<std::string> signature_key,
124+
litert::Environment* env = nullptr);
123125

124126
std::unique_ptr<EmbeddingLookupText> text_embedding_lookup_;
125127
std::vector<std::unique_ptr<EmbeddingLookupMultiModal>>

runtime/components/embedding_lookup/embedding_lookup_text.cc

Lines changed: 34 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,9 @@
4040
#include "litert/cc/litert_options.h" // from @litert
4141
#include "litert/cc/litert_tensor_buffer.h" // from @litert
4242
#include "runtime/util/status_macros.h" //NOLINT
43+
#if defined(__ANDROID__)
44+
#include "litert/cc/options/litert_qualcomm_options.h" // from @litert
45+
#endif
4346

4447
namespace litert::lm {
4548

@@ -243,17 +246,39 @@ absl::Status EmbeddingLookupText::LookupPrefill(absl::Span<const int> tokens,
243246

244247
absl::StatusOr<std::unique_ptr<EmbeddingLookupText>>
245248
EmbeddingLookupText::Create(const litert::Model* absl_nonnull model,
246-
std::optional<std::string> signature_key) {
247-
LITERT_ASSIGN_OR_RETURN(auto env, ::litert::Environment::Create({}));
248-
auto handler = std::unique_ptr<EmbeddingLookupText>(
249-
new EmbeddingLookupText(std::move(env), model, signature_key));
249+
std::optional<std::string> signature_key,
250+
litert::Environment* env) {
251+
if (env == nullptr) {
252+
LITERT_ASSIGN_OR_RETURN(auto local_env, ::litert::Environment::Create({}));
253+
auto handler = std::unique_ptr<EmbeddingLookupText>(
254+
new EmbeddingLookupText(std::move(local_env), model, signature_key));
255+
RETURN_IF_ERROR(handler->Initialize());
256+
return handler;
257+
}
258+
auto handler = std::unique_ptr<EmbeddingLookupText>(new EmbeddingLookupText(
259+
::litert::Environment::WrapCObject(env->GetHolder(),
260+
::litert::OwnHandle::kNo), // NOLINT
261+
model, signature_key));
250262
RETURN_IF_ERROR(handler->Initialize());
251263
return handler;
252264
}
253265

254266
absl::Status EmbeddingLookupText::Initialize() {
255267
LITERT_ASSIGN_OR_RETURN(auto options, Options::Create());
268+
#if defined(__ANDROID__)
269+
options.SetHardwareAccelerators(litert::HwAccelerators::kNpu |
270+
litert::HwAccelerators::kCpu);
271+
#else
256272
options.SetHardwareAccelerators(litert::HwAccelerators::kCpu);
273+
#endif
274+
#if defined(__ANDROID__)
275+
LITERT_ASSIGN_OR_RETURN(::litert::qualcomm::QualcommOptions & qnn_opts,
276+
options.GetQualcommOptions());
277+
qnn_opts.SetLogLevel(::litert::qualcomm::QualcommOptions::LogLevel::kOff);
278+
qnn_opts.SetHtpPerformanceMode(
279+
::litert::qualcomm::QualcommOptions::HtpPerformanceMode::
280+
kSustainedHighPerformance);
281+
#endif
257282

258283
LITERT_ASSIGN_OR_RETURN(compiled_model_, litert::CompiledModel::Create(
259284
env_, model_.Get(), options));
@@ -323,6 +348,11 @@ absl::Status EmbeddingLookupText::Initialize() {
323348
floats_per_token_output_ *= output_buffer_layout.Dimensions()[i];
324349
}
325350

351+
ABSL_LOG(INFO) << "EmbeddingLookupText initialized: "
352+
<< "signature=" << signature_key_.value_or("default")
353+
<< ", rank=" << output_buffer_layout.Rank()
354+
<< ", floats_per_token=" << floats_per_token_output_;
355+
326356
// Initialize the default embedding vector to be the embedding of token 0.
327357
default_embedding_vector_.resize(floats_per_token_output_);
328358
RETURN_IF_ERROR(LookupInternal(

runtime/components/embedding_lookup/embedding_lookup_text.h

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@
3232
#include "litert/cc/litert_compiled_model.h" // from @litert
3333
#include "litert/cc/litert_environment.h" // from @litert
3434
#include "litert/cc/litert_model.h" // from @litert
35+
#include "litert/cc/litert_options.h" // from @litert
3536
#include "litert/cc/litert_ranked_tensor_type.h" // from @litert
3637
#include "litert/cc/litert_tensor_buffer.h" // from @litert
3738
#include "runtime/components/embedding_lookup/embedding_lookup.h"
@@ -55,7 +56,8 @@ class EmbeddingLookupText : public EmbeddingLookup {
5556
// signature_key is not provided, the first signature will be used by default.
5657
static absl::StatusOr<std::unique_ptr<EmbeddingLookupText>> Create(
5758
const litert::Model* absl_nonnull model,
58-
std::optional<std::string> signature_key = std::nullopt);
59+
std::optional<std::string> signature_key = std::nullopt,
60+
litert::Environment* env = nullptr);
5961

6062
// For a given token, looks up the embedding and stores it in the
6163
// provided vector. The caller is responsible for ensuring that the vector is

runtime/engine/litert_lm_advanced_main.cc

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -30,13 +30,10 @@
3030
#include <string>
3131
#include <vector>
3232

33-
#include "absl/base/log_severity.h" // from @com_google_absl
3433
#include "absl/flags/flag.h" // from @com_google_absl
35-
#include "absl/flags/marshalling.h" // from @com_google_absl
3634
#include "absl/flags/parse.h" // from @com_google_absl
3735
#include "absl/log/absl_check.h" // from @com_google_absl
3836
#include "absl/log/absl_log.h" // from @com_google_absl
39-
#include "absl/log/globals.h" // from @com_google_absl
4037
#include "absl/status/status.h" // from @com_google_absl
4138
#include "absl/status/statusor.h" // from @com_google_absl
4239
#include "absl/strings/numbers.h" // from @com_google_absl
@@ -257,6 +254,7 @@ absl::Status MainHelper(int argc, char** argv) {
257254
settings.enable_speculative_decoding =
258255
absl::GetFlag(FLAGS_enable_speculative_decoding);
259256

257+
260258
// Adjust max_num_tokens and prefill_batch_size if not set on benchmark mode.
261259
if (settings.benchmark && settings.benchmark_prefill_tokens > 0) {
262260
if (settings.max_num_tokens == 0 && settings.benchmark_decode_tokens > 0) {

runtime/engine/litert_lm_lib.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@
4545
#include "absl/strings/str_format.h" // from @com_google_absl
4646
#include "absl/strings/string_view.h" // from @com_google_absl
4747
#include "absl/time/time.h" // from @com_google_absl
48+
#include "runtime/util/status_macros.h" // NOLINT
4849
#include "nlohmann/json.hpp" // from @nlohmann_json
4950
#include "litert/cc/internal/scoped_file.h" // from @litert
5051
#include "runtime/components/constrained_decoding/constraint.h"
@@ -61,7 +62,6 @@
6162
#include "runtime/executor/llm_executor_settings.h"
6263
#include "runtime/proto/sampler_params.pb.h"
6364
#include "runtime/util/scoped_file.h"
64-
#include "runtime/util/status_macros.h" // IWYU pragma: keep
6565
#include "re2/re2.h" // from @com_googlesource_code_re2
6666
#include "tflite/profiling/memory_info.h" // from @litert
6767
#include "tflite/profiling/memory_usage_monitor.h" // from @litert

0 commit comments

Comments (0)