Skip to content

Commit 24db1e8

Browse files
ai-edge-botcopybara-github
authored andcommitted
Construct ExecutorSettings from a DataStream
LiteRT-LM-PiperOrigin-RevId: 891964124
1 parent 2c96d32 commit 24db1e8

16 files changed

+1219
-46
lines changed

runtime/executor/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ cc_library(
8686
"@com_google_absl//absl/status",
8787
"@com_google_absl//absl/status:statusor",
8888
"@com_google_absl//absl/strings",
89+
"//runtime/util:data_stream",
8990
"//runtime/util:file_util",
9091
"//runtime/util:litert_status_util",
9192
"//runtime/util:memory_mapped_file",
@@ -96,11 +97,13 @@ cc_library(
9697
cc_test(
9798
name = "executor_settings_base_test",
9899
srcs = ["executor_settings_base_test.cc"],
100+
data = ["//runtime/testdata"],
99101
deps = [
100102
":executor_settings_base",
101103
"@com_google_googletest//:gtest_main",
102104
"@com_google_absl//absl/status",
103105
"@com_google_absl//absl/strings:string_view",
106+
"//runtime/util:file_data_stream",
104107
"//runtime/util:memory_mapped_file",
105108
"//runtime/util:test_utils",
106109
],

runtime/executor/executor_settings_base.cc

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "absl/strings/match.h" // from @com_google_absl
2626
#include "absl/strings/str_cat.h" // from @com_google_absl
2727
#include "absl/strings/string_view.h" // from @com_google_absl
28+
#include "runtime/util/data_stream.h"
2829
#include "runtime/util/file_util.h"
2930
#include "runtime/util/memory_mapped_file.h"
3031
#include "runtime/util/scoped_file.h"
@@ -152,6 +153,12 @@ absl::StatusOr<ModelAssets> ModelAssets::Create(
152153
return ModelAssets(std::move(model_file));
153154
}
154155

156+
// static
157+
absl::StatusOr<ModelAssets> ModelAssets::Create(
158+
std::shared_ptr<litert::lm::DataStream> data_stream) {
159+
return ModelAssets(std::move(data_stream));
160+
}
161+
155162
// static
156163
absl::StatusOr<ModelAssets> ModelAssets::Create(
157164
std::shared_ptr<litert::lm::ScopedFile> model_file,
@@ -182,6 +189,9 @@ ModelAssets::ModelAssets(
182189
absl::string_view model_path)
183190
: path_(model_path), memory_mapped_file_(std::move(model_file)) {}
184191

192+
ModelAssets::ModelAssets(std::shared_ptr<litert::lm::DataStream> data_stream)
193+
: data_stream_(std::move(data_stream)) {}
194+
185195
absl::StatusOr<absl::string_view> ModelAssets::GetPath() const {
186196
if (!path_.empty()) {
187197
return path_;
@@ -206,6 +216,14 @@ ModelAssets::GetMemoryMappedFile() const {
206216
return memory_mapped_file_;
207217
}
208218

219+
absl::StatusOr<std::shared_ptr<DataStream>> ModelAssets::GetDataStream() const {
220+
if (!HasDataStream()) {
221+
return absl::InvalidArgumentError(
222+
"Assets were not created with a data stream.");
223+
}
224+
return data_stream_;
225+
}
226+
209227
absl::StatusOr<std::shared_ptr<ScopedFile>> ModelAssets::GetOrCreateScopedFile()
210228
const {
211229
if (HasScopedFile()) {
@@ -222,13 +240,38 @@ absl::StatusOr<std::shared_ptr<ScopedFile>> ModelAssets::GetOrCreateScopedFile()
222240

223241
std::ostream& operator<<(std::ostream& os, const ModelAssets& model_assets) {
224242
if (model_assets.HasScopedFile()) {
225-
os << "model_file file descriptor ID: "
226-
<< model_assets.GetScopedFile().value()->file() << "\n";
243+
auto scoped_file = model_assets.GetScopedFile();
244+
if (scoped_file.ok()) {
245+
os << "model_file file descriptor ID: " << scoped_file.value()->file()
246+
<< "\n";
247+
} else {
248+
os << "model_file error getting ScopedFile: " << scoped_file.status()
249+
<< "\n";
250+
}
227251
} else if (model_assets.HasMemoryMappedFile()) {
228-
os << "model_file memory mapped file: "
229-
<< model_assets.GetMemoryMappedFile().value()->data() << "\n";
252+
auto memory_mapped_file = model_assets.GetMemoryMappedFile();
253+
if (memory_mapped_file.ok()) {
254+
os << "model_file memory mapped file: "
255+
<< memory_mapped_file.value()->data() << "\n";
256+
} else {
257+
os << "model_file error getting MemoryMappedFile: "
258+
<< memory_mapped_file.status() << "\n";
259+
}
260+
} else if (model_assets.HasDataStream()) {
261+
auto data_stream = model_assets.GetDataStream();
262+
if (data_stream.ok()) {
263+
os << "model_file data stream: " << data_stream.value() << "\n";
264+
} else {
265+
os << "model_file error getting DataStream: " << data_stream.status()
266+
<< "\n";
267+
}
230268
} else {
231-
os << "model_path: " << model_assets.GetPath().value() << "\n";
269+
auto model_path = model_assets.GetPath();
270+
if (model_path.ok()) {
271+
os << "model_path: " << model_path.value() << "\n";
272+
} else {
273+
os << "model_path error: " << model_path.status() << "\n";
274+
}
232275
}
233276
os << "fake_weights_mode: " << model_assets.fake_weights_mode() << "\n";
234277
return os;

runtime/executor/executor_settings_base.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "absl/status/status.h" // from @com_google_absl
2525
#include "absl/status/statusor.h" // from @com_google_absl
2626
#include "absl/strings/string_view.h" // from @com_google_absl
27+
#include "runtime/util/data_stream.h"
2728
#include "runtime/util/memory_mapped_file.h"
2829
#include "runtime/util/scoped_file.h"
2930

@@ -114,6 +115,8 @@ class ModelAssets {
114115
static absl::StatusOr<ModelAssets> Create(
115116
std::shared_ptr<MemoryMappedFile> model_file,
116117
absl::string_view model_path);
118+
static absl::StatusOr<ModelAssets> Create(
119+
std::shared_ptr<DataStream> data_stream);
117120

118121
// Convenience factory function to create a ModelAssets with both a model
119122
// path and file. Will use the scoped file if both are provided.
@@ -122,12 +125,14 @@ class ModelAssets {
122125

123126
bool HasScopedFile() const { return scoped_file_ != nullptr; }
124127
bool HasMemoryMappedFile() const { return memory_mapped_file_ != nullptr; }
128+
bool HasDataStream() const { return data_stream_ != nullptr; }
125129

126130
// Returns the model file if it was created with the respective variant,
127131
// otherwise returns an error.
128132
absl::StatusOr<absl::string_view> GetPath() const;
129133
absl::StatusOr<std::shared_ptr<ScopedFile>> GetScopedFile() const;
130134
absl::StatusOr<std::shared_ptr<MemoryMappedFile>> GetMemoryMappedFile() const;
135+
absl::StatusOr<std::shared_ptr<DataStream>> GetDataStream() const;
131136

132137
// Convenience method to get a read-only scoped file to the model file
133138
// regardless of whether this instance was created from a path or scoped file.
@@ -146,12 +151,14 @@ class ModelAssets {
146151
explicit ModelAssets(std::shared_ptr<MemoryMappedFile> model_file);
147152
explicit ModelAssets(std::shared_ptr<MemoryMappedFile> model_file,
148153
absl::string_view model_path);
154+
explicit ModelAssets(std::shared_ptr<DataStream> data_stream);
149155

150156
// TODO: b/417814685 - Consider supporting multiple model files if the need
151157
// case arises.
152158
std::string path_;
153159
std::shared_ptr<ScopedFile> scoped_file_;
154160
std::shared_ptr<MemoryMappedFile> memory_mapped_file_;
161+
std::shared_ptr<DataStream> data_stream_;
155162

156163
FakeWeightsMode fake_weights_mode_ = FakeWeightsMode::FAKE_WEIGHTS_NONE;
157164
};

runtime/executor/executor_settings_base_test.cc

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "runtime/executor/executor_settings_base.h"
1616

17+
#include <filesystem> // NOLINT
1718
#include <memory>
1819
#include <sstream>
1920
#include <string>
@@ -24,12 +25,20 @@
2425
#include <gtest/gtest.h>
2526
#include "absl/status/status.h" // from @com_google_absl
2627
#include "absl/strings/string_view.h" // from @com_google_absl
28+
#include "runtime/util/file_data_stream.h"
2729
#include "runtime/util/memory_mapped_file.h"
2830
#include "runtime/util/test_utils.h" // NOLINT
2931

3032
namespace litert::lm {
3133
namespace {
3234

35+
std::string GetTestModelPath() {
36+
const auto model_path =
37+
std::filesystem::path(::testing::SrcDir()) /
38+
"litert_lm/runtime/testdata/test_lm.litertlm";
39+
return model_path.string();
40+
}
41+
3342
TEST(LlmExecutorConfigTest, Backend) {
3443
Backend backend;
3544
std::stringstream oss;
@@ -175,6 +184,22 @@ TEST(LlmExecutorConfigTest, ModelAssetsMemoryMapped) {
175184
EXPECT_THAT(oss.str(), testing::HasSubstr("FAKE_WEIGHTS_NONE"));
176185
}
177186

187+
TEST(LlmExecutorConfigTest, ModelAssetsDataStream) {
188+
ASSERT_OK_AND_ASSIGN(auto data_stream,
189+
FileDataStream::Create(GetTestModelPath()));
190+
191+
auto model_assets = ModelAssets::Create(data_stream);
192+
ASSERT_OK(model_assets);
193+
EXPECT_TRUE(model_assets->HasDataStream());
194+
ASSERT_OK_AND_ASSIGN(auto retrieved_stream, model_assets->GetDataStream());
195+
EXPECT_EQ(retrieved_stream, data_stream);
196+
197+
std::stringstream oss;
198+
oss << *model_assets;
199+
EXPECT_THAT(oss.str(), testing::HasSubstr("model_file data stream"));
200+
EXPECT_THAT(oss.str(), testing::HasSubstr("FAKE_WEIGHTS_NONE"));
201+
}
202+
178203
class TestExecutorSettings : public ExecutorSettingsBase {
179204
public:
180205
explicit TestExecutorSettings(ModelAssets model_assets)

runtime/util/BUILD

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,32 @@ cc_test(
6565
],
6666
)
6767

68+
cc_library(
69+
name = "data_stream",
70+
srcs = ["data_stream.cc"],
71+
hdrs = ["data_stream.h"],
72+
deps = [
73+
":litert_status_util",
74+
"@com_google_absl//absl/status",
75+
"@com_google_absl//absl/status:statusor",
76+
"@com_google_absl//absl/strings",
77+
],
78+
)
79+
80+
cc_test(
81+
name = "data_stream_test",
82+
srcs = ["data_stream_test.cc"],
83+
data = ["//runtime/testdata"],
84+
deps = [
85+
":data_stream",
86+
":file_data_stream",
87+
":test_utils",
88+
"@com_google_googletest//:gtest_main",
89+
"@com_google_absl//absl/status",
90+
"@com_google_absl//absl/status:statusor",
91+
],
92+
)
93+
6894
cc_library(
6995
name = "log_tensor_buffer",
7096
srcs = ["log_tensor_buffer.cc"],
@@ -154,6 +180,32 @@ cc_test(
154180
],
155181
)
156182

183+
cc_library(
184+
name = "file_data_stream",
185+
srcs = ["file_data_stream.cc"],
186+
hdrs = ["file_data_stream.h"],
187+
deps = [
188+
":data_stream",
189+
"@com_google_absl//absl/status",
190+
"@com_google_absl//absl/status:statusor",
191+
"@com_google_absl//absl/strings",
192+
],
193+
)
194+
195+
cc_test(
196+
name = "file_data_stream_test",
197+
srcs = ["file_data_stream_test.cc"],
198+
data = ["//runtime/testdata"],
199+
deps = [
200+
":file_data_stream",
201+
":test_utils",
202+
"@com_google_googletest//:gtest_main",
203+
"@com_google_absl//absl/status",
204+
"@com_google_absl//absl/status:status_matchers",
205+
"@com_google_absl//absl/status:statusor",
206+
],
207+
)
208+
157209
cc_library(
158210
name = "file_format_util",
159211
srcs = ["file_format_util.cc"],
@@ -472,6 +524,43 @@ cc_library(
472524
],
473525
)
474526

527+
cc_library(
528+
name = "litert_lm_streaming_loader",
529+
srcs = ["litert_lm_streaming_loader.cc"],
530+
hdrs = ["litert_lm_streaming_loader.h"],
531+
deps = [
532+
":data_stream",
533+
":litert_lm_loader",
534+
":litert_status_util",
535+
"@com_google_absl//absl/log:absl_check",
536+
"@com_google_absl//absl/log:absl_log",
537+
"@com_google_absl//absl/status",
538+
"@com_google_absl//absl/status:statusor",
539+
"@com_google_absl//absl/strings",
540+
"@litert//litert/cc:litert_buffer_ref",
541+
"//runtime/components:model_resources",
542+
"//schema/core:litertlm_header_schema",
543+
"//schema/core:litertlm_read",
544+
],
545+
)
546+
547+
cc_test(
548+
name = "litert_lm_streaming_loader_test",
549+
srcs = ["litert_lm_streaming_loader_test.cc"],
550+
data = ["//runtime/testdata"],
551+
deps = [
552+
":file_data_stream",
553+
":litert_lm_streaming_loader",
554+
":test_utils",
555+
"@com_google_googletest//:gtest_main",
556+
"@com_google_absl//absl/status",
557+
"@com_google_absl//absl/status:status_matchers",
558+
"//runtime/util:data_stream",
559+
"//schema/core:litertlm_header_schema",
560+
"//schema/core:litertlm_read",
561+
],
562+
)
563+
475564
cc_library(
476565
name = "metadata_util",
477566
srcs = ["metadata_util.cc"],

0 commit comments

Comments
 (0)