Skip to content

Commit 0c66507

Browse files
ai-edge-bot authored and copybara-github committed
Add a streaming-aware version of LitertLmLoader
LitertLmStreamingLoader provides utility functions for reading a .litertlm model from a DataStream.

LiteRT-LM-PiperOrigin-RevId: 888363584
1 parent f434519 commit 0c66507

File tree

9 files changed

+499
-75
lines changed

9 files changed

+499
-75
lines changed

runtime/util/BUILD

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ cc_test(
8787
":test_utils",
8888
"@com_google_googletest//:gtest_main",
8989
"@com_google_absl//absl/status",
90-
"@com_google_absl//absl/status:statusor",
9190
],
9291
)
9392

@@ -524,6 +523,43 @@ cc_library(
524523
],
525524
)
526525

526+
cc_library(
527+
name = "litert_lm_streaming_loader",
528+
srcs = ["litert_lm_streaming_loader.cc"],
529+
hdrs = ["litert_lm_streaming_loader.h"],
530+
deps = [
531+
":data_stream",
532+
":litert_lm_loader",
533+
":litert_status_util",
534+
"@com_google_absl//absl/log:absl_check",
535+
"@com_google_absl//absl/log:absl_log",
536+
"@com_google_absl//absl/status",
537+
"@com_google_absl//absl/status:statusor",
538+
"@com_google_absl//absl/strings",
539+
"@litert//litert/cc:litert_buffer_ref",
540+
"//runtime/components:model_resources",
541+
"//schema/core:litertlm_header_schema",
542+
"//schema/core:litertlm_read",
543+
],
544+
)
545+
546+
cc_test(
547+
name = "litert_lm_streaming_loader_test",
548+
srcs = ["litert_lm_streaming_loader_test.cc"],
549+
data = ["//runtime/testdata"],
550+
deps = [
551+
":file_data_stream",
552+
":litert_lm_streaming_loader",
553+
":test_utils",
554+
"@com_google_googletest//:gtest_main",
555+
"@com_google_absl//absl/status",
556+
"@com_google_absl//absl/status:status_matchers",
557+
"//runtime/util:data_stream",
558+
"//schema/core:litertlm_header_schema",
559+
"//schema/core:litertlm_read",
560+
],
561+
)
562+
527563
cc_library(
528564
name = "metadata_util",
529565
srcs = ["metadata_util.cc"],

runtime/util/data_stream.cc

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,36 +27,36 @@
2727
namespace litert::lm {
2828

2929
SubStream::~SubStream() {
30-
if (auto parent = parent_.lock()) {
30+
if (parent_) {
3131
// Note: We ignore errors here as we are in a destructor.
32-
(void)parent->Discard(offset_, size_);
32+
(void)parent_->Discard(offset_, size_);
3333
}
3434
}
3535

3636
absl::Status SubStream::ReadAndDiscard(void* buffer, uint64_t offset,
3737
uint64_t size) {
3838
RETURN_IF_ERROR(CheckBounds(offset, size));
39-
if (auto parent = parent_.lock()) {
40-
return parent->ReadAndDiscard(buffer, offset_ + offset, size);
39+
if (parent_) {
40+
return parent_->ReadAndDiscard(buffer, offset_ + offset, size);
4141
}
42-
return absl::FailedPreconditionError("Parent stream is expired");
42+
return absl::FailedPreconditionError("Parent stream is null");
4343
}
4444

4545
absl::Status SubStream::ReadAndPreserve(void* buffer, uint64_t offset,
4646
uint64_t size) {
4747
RETURN_IF_ERROR(CheckBounds(offset, size));
48-
if (auto parent = parent_.lock()) {
49-
return parent->ReadAndPreserve(buffer, offset_ + offset, size);
48+
if (parent_) {
49+
return parent_->ReadAndPreserve(buffer, offset_ + offset, size);
5050
}
51-
return absl::FailedPreconditionError("Parent stream is expired");
51+
return absl::FailedPreconditionError("Parent stream is null");
5252
}
5353

5454
absl::Status SubStream::Discard(uint64_t offset, uint64_t size) {
5555
RETURN_IF_ERROR(CheckBounds(offset, size));
56-
if (auto parent = parent_.lock()) {
57-
return parent->Discard(offset_ + offset, size);
56+
if (parent_) {
57+
return parent_->Discard(offset_ + offset, size);
5858
}
59-
return absl::FailedPreconditionError("Parent stream is expired");
59+
return absl::FailedPreconditionError("Parent stream is null");
6060
}
6161

6262
absl::Status SubStream::CheckBounds(uint64_t offset, uint64_t size) const {
@@ -69,7 +69,7 @@ absl::Status SubStream::CheckBounds(uint64_t offset, uint64_t size) const {
6969
return absl::OkStatus();
7070
}
7171

72-
absl::StatusOr<std::shared_ptr<DataStream>> SubStream::OpenSubStream(
72+
absl::StatusOr<std::unique_ptr<DataStream>> SubStream::OpenSubStream(
7373
uint64_t offset, uint64_t size) {
7474
// Check if the requested substream fits within this SubStream's bounds.
7575
// Note that the parent DataStream::OpenSubStream method doesn't do this for
@@ -80,7 +80,7 @@ absl::StatusOr<std::shared_ptr<DataStream>> SubStream::OpenSubStream(
8080
return DataStream::OpenSubStream(offset, size);
8181
}
8282

83-
absl::StatusOr<std::shared_ptr<DataStream>> DataStream::OpenSubStream(
83+
absl::StatusOr<std::unique_ptr<DataStream>> DataStream::OpenSubStream(
8484
uint64_t offset, uint64_t size) {
8585
for (const auto& region : locked_regions_) {
8686
// Check for overlap: Is [offset, offset + size) overlapping with
@@ -94,7 +94,7 @@ absl::StatusOr<std::shared_ptr<DataStream>> DataStream::OpenSubStream(
9494
}
9595
}
9696
locked_regions_.emplace_back(offset, size);
97-
return std::make_shared<SubStream>(shared_from_this(), offset, size);
97+
return std::make_unique<SubStream>(this, offset, size);
9898
}
9999

100100
} // namespace litert::lm

runtime/util/data_stream.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ namespace litert::lm {
3232
// A `SubStream` is a view into a parent `DataStream`. It holds a non-owning raw pointer
3333
// to its parent, so it does not keep the parent alive. The user must ensure
3434
// that the parent `DataStream` outlives all of its `SubStream`s.
35-
class DataStream : public std::enable_shared_from_this<DataStream> {
35+
class DataStream {
3636
public:
3737
virtual ~DataStream() = default;
3838

@@ -55,7 +55,7 @@ class DataStream : public std::enable_shared_from_this<DataStream> {
5555
// `offset` + `size`).
5656
// Note that substreams cannot overlap regions in the parent stream, even if
5757
// the first substream is destroyed.
58-
virtual absl::StatusOr<std::shared_ptr<DataStream>> OpenSubStream(
58+
virtual absl::StatusOr<std::unique_ptr<DataStream>> OpenSubStream(
5959
uint64_t offset, uint64_t size);
6060

6161
private:
@@ -68,7 +68,7 @@ class DataStream : public std::enable_shared_from_this<DataStream> {
6868
// that the parent `DataStream` outlives all of its `SubStream`s.
6969
class SubStream : public DataStream {
7070
public:
71-
SubStream(std::shared_ptr<DataStream> parent, uint64_t offset, uint64_t size)
71+
SubStream(DataStream* parent, uint64_t offset, uint64_t size)
7272
: parent_(parent), offset_(offset), size_(size) {}
7373

7474
~SubStream() override;
@@ -81,11 +81,11 @@ class SubStream : public DataStream {
8181

8282
absl::Status Discard(uint64_t offset, uint64_t size) override;
8383

84-
absl::StatusOr<std::shared_ptr<DataStream>> OpenSubStream(
84+
absl::StatusOr<std::unique_ptr<DataStream>> OpenSubStream(
8585
uint64_t offset, uint64_t size) override;
8686

8787
private:
88-
std::weak_ptr<DataStream> parent_;
88+
DataStream* parent_;
8989
uint64_t offset_;
9090
uint64_t size_;
9191

runtime/util/data_stream_test.cc

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -104,20 +104,6 @@ TEST(DataStreamTest, SubStreamOfSubStream) {
104104
EXPECT_EQ(magic_number, "ITERTLM");
105105
}
106106

107-
TEST(DataStreamTest, SubStreamReadAfterParentDestroyed) {
108-
std::shared_ptr<DataStream> sub_stream;
109-
{
110-
ASSERT_OK_AND_ASSIGN(auto stream,
111-
FileDataStream::Create(GetTestModelPath()));
112-
ASSERT_OK_AND_ASSIGN(sub_stream, stream->OpenSubStream(0, 8));
113-
}
114-
115-
std::vector<char> buffer(1);
116-
absl::Status status = sub_stream->ReadAndPreserve(buffer.data(), 0, 1);
117-
EXPECT_EQ(status.code(), absl::StatusCode::kFailedPrecondition);
118-
EXPECT_THAT(status.message(), HasSubstr("Parent stream is expired"));
119-
}
120-
121107
TEST(DataStreamTest, SubStreamCannotOverlap) {
122108
ASSERT_OK_AND_ASSIGN(auto stream, FileDataStream::Create(GetTestModelPath()));
123109
ASSERT_OK_AND_ASSIGN(auto sub_stream1, stream->OpenSubStream(0, 10));

runtime/util/litert_lm_loader.cc

Lines changed: 48 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,48 @@ absl::StatusOr<std::unique_ptr<MemoryMappedFile>> CreateMemoryMapFromScopedFile(
5959
"whole");
6060
}
6161

62-
constexpr uint64_t kLitertLmHeaderMaxSize = 16 * 1024;
63-
6462
} // namespace
6563

64+
// Derives the BufferKey for `section` and, when present, its backend
// constraint string.
//
// The key always starts out as BufferKey(section->data_type()). Only for
// TFLiteModel and TFLiteWeights sections are the section items scanned for the
// "model_type" and "backend_constraint" keys (keys are compared after
// lowercasing, so the match is case-insensitive). For every other data type
// the returned constraint is std::nullopt.
//
// Returns the (buffer_key, backend_constraint) pair. An error is returned only
// via the ASSIGN_OR_RETURN on StringToModelType below (presumably when the
// "model_type" string is not a recognized model type - confirm against
// StringToModelType).
absl::StatusOr<std::pair<BufferKey, std::optional<std::string>>>
65+
ExtractBufferKeyAndBackendConstraint(const schema::SectionObject* section) {
66+
// NOTE(review): `items` is dereferenced in the loop below without a null
// check; confirm the schema guarantees that a section always has items.
auto items = section->items();
67+
BufferKey buffer_key(section->data_type());
68+
std::optional<std::string> backend_constraint;
69+
// Extract the specific model type from the section items KeyValuePairs.
70+
if (section->data_type() == schema::AnySectionDataType_TFLiteModel ||
71+
section->data_type() == schema::AnySectionDataType_TFLiteWeights) {
72+
bool found_model_type = false;
73+
std::string model_type;
74+
// The loop never breaks, so if a key appears more than once the last
// matching item wins.
for (size_t j = 0; j < items->size(); ++j) {
75+
auto item = items->Get(j);
76+
if (item->key() &&
77+
absl::AsciiStrToLower(item->key()->str()) == "model_type" &&
78+
item->value()) {
79+
found_model_type = true;
80+
// NOTE(review): only item->value() is null-checked; the results of
// value_as_StringValue() and its value() are dereferenced unchecked.
// Presumably this relies on the value always being a StringValue - verify.
model_type = *(item->value_as_StringValue()->value());
81+
}
82+
if (item->key() &&
83+
absl::AsciiStrToLower(item->key()->str()) == "backend_constraint" &&
84+
item->value()) {
85+
// NOTE(review): same unchecked StringValue dereference as for "model_type".
backend_constraint = *(item->value_as_StringValue()->value());
86+
}
87+
}
88+
if (found_model_type) {
89+
ABSL_LOG(INFO) << "model_type: " << model_type;
90+
// Convert the string to the ModelType enum and use it as the secondary key.
ASSIGN_OR_RETURN(ModelType model_type_enum,
91+
StringToModelType(model_type));
92+
buffer_key = BufferKey(section->data_type(), model_type_enum);
93+
} else {
94+
ABSL_LOG(WARNING) << "model_type not found, use kTfLitePrefillDecode";
95+
// For backward compatibility, we will use the default model type if
96+
// model_type is not found.
97+
buffer_key =
98+
BufferKey(section->data_type(), ModelType::kTfLitePrefillDecode);
99+
}
100+
}
101+
return std::make_pair(buffer_key, backend_constraint);
102+
}
103+
66104
absl::Status LitertLmLoader::MapSection(BufferKey buffer_key,
67105
uint64_t begin_offset,
68106
uint64_t end_offset) {
@@ -149,7 +187,6 @@ absl::Status LitertLmLoader::Initialize() {
149187
ABSL_LOG(INFO) << "mmap_status is ok";
150188

151189
// Read the header information.
152-
schema::LitertlmHeader header;
153190
absl::Status status =
154191
ReadHeaderFromLiteRTLM(header_data, header_size, &header_);
155192
ABSL_LOG(INFO) << "status: " << status;
@@ -164,44 +201,14 @@ absl::Status LitertLmLoader::Initialize() {
164201
auto sections = header_.metadata->section_metadata()->objects();
165202
for (size_t i = 0; i < sections->size(); ++i) {
166203
const schema::SectionObject* section = sections->Get(i);
167-
auto items = section->items();
168-
BufferKey buffer_key(section->data_type());
169-
// Extract the specific model type from the section items KeyValuePairs.
170-
if (section->data_type() == schema::AnySectionDataType_TFLiteModel ||
171-
section->data_type() == schema::AnySectionDataType_TFLiteWeights) {
172-
bool found_model_type = false;
173-
std::string model_type;
174-
std::string backend_constraint;
175-
for (size_t j = 0; j < items->size(); ++j) {
176-
auto item = items->Get(j);
177-
if (item->key() &&
178-
absl::AsciiStrToLower(item->key()->str()) == "model_type" &&
179-
item->value()) {
180-
found_model_type = true;
181-
model_type = *(item->value_as_StringValue()->value());
182-
}
183-
if (item->key() &&
184-
absl::AsciiStrToLower(item->key()->str()) == "backend_constraint" &&
185-
item->value()) {
186-
backend_constraint = *(item->value_as_StringValue()->value());
187-
}
188-
}
189-
if (found_model_type) {
190-
ABSL_LOG(INFO) << "model_type: " << model_type;
191-
ASSIGN_OR_RETURN(ModelType model_type_enum,
192-
StringToModelType(model_type));
193-
buffer_key = BufferKey(section->data_type(), model_type_enum);
194-
} else {
195-
ABSL_LOG(WARNING) << "model_type not found, use kTfLitePrefillDecode";
196-
// For backward compatibility, we will use the default model type if
197-
// model_type is not found.
198-
buffer_key =
199-
BufferKey(section->data_type(), ModelType::kTfLitePrefillDecode);
200-
}
201-
if (!backend_constraint.empty()) {
202-
section_backend_constraint_[buffer_key] = backend_constraint;
203-
ABSL_LOG(INFO) << "section_backend_constraint: " << backend_constraint;
204-
}
204+
ASSIGN_OR_RETURN(auto key_and_constraint,
205+
ExtractBufferKeyAndBackendConstraint(section));
206+
BufferKey buffer_key = key_and_constraint.first;
207+
if (key_and_constraint.second.has_value() &&
208+
!key_and_constraint.second->empty()) {
209+
section_backend_constraint_[buffer_key] = *key_and_constraint.second;
210+
ABSL_LOG(INFO) << "section_backend_constraint: "
211+
<< *key_and_constraint.second;
205212
}
206213
section_locations_[buffer_key] =
207214
std::make_pair(section->begin_offset(), section->end_offset());

runtime/util/litert_lm_loader.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040

4141
namespace litert::lm {
4242

43+
inline constexpr uint64_t kLitertLmHeaderMaxSize = 16 * 1024;
44+
4345
// Each buffer is keyed by the data type as the major key and the model type
4446
// as the optional secondary key when the data type is TFLiteModel or
4547
// TFLiteWeights.
@@ -68,6 +70,10 @@ struct BufferKey {
6870
}
6971
};
7072

73+
// Extracts the BufferKey and backend constraint from the section metadata.
74+
absl::StatusOr<std::pair<BufferKey, std::optional<std::string>>>
75+
ExtractBufferKeyAndBackendConstraint(const schema::SectionObject* section);
76+
7177
// Hash function for BufferKey
7278
struct BufferKeyHash {
7379
size_t operator()(const BufferKey& k) const {

0 commit comments

Comments
 (0)