Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions runtime/engine/engine_settings.cc
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,6 @@ absl::Status EngineSettings::MaybeUpdateAndValidate(
}

if (!metadata.has_llm_model_type()) {
const auto& model_assets = main_executor_settings_.GetModelAssets();
auto model_path = model_assets.GetPath();
if (tokenizer != nullptr) {
ASSIGN_OR_RETURN(*metadata.mutable_llm_model_type(),
InferLlmModelType(metadata, tokenizer));
Expand Down
3 changes: 3 additions & 0 deletions runtime/executor/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ cc_library(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"//runtime/util:data_stream",
"//runtime/util:file_util",
"//runtime/util:litert_status_util",
"//runtime/util:memory_mapped_file",
Expand All @@ -96,11 +97,13 @@ cc_library(
cc_test(
name = "executor_settings_base_test",
srcs = ["executor_settings_base_test.cc"],
data = ["//runtime/testdata"],
deps = [
":executor_settings_base",
"@com_google_googletest//:gtest_main",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:string_view",
"//runtime/util:file_data_stream",
"//runtime/util:memory_mapped_file",
"//runtime/util:test_utils",
],
Expand Down
53 changes: 48 additions & 5 deletions runtime/executor/executor_settings_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "absl/strings/match.h" // from @com_google_absl
#include "absl/strings/str_cat.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
#include "runtime/util/data_stream.h"
#include "runtime/util/file_util.h"
#include "runtime/util/memory_mapped_file.h"
#include "runtime/util/scoped_file.h"
Expand Down Expand Up @@ -152,6 +153,12 @@ absl::StatusOr<ModelAssets> ModelAssets::Create(
return ModelAssets(std::move(model_file));
}

// static
absl::StatusOr<ModelAssets> ModelAssets::Create(
std::shared_ptr<litert::lm::DataStream> data_stream) {
return ModelAssets(std::move(data_stream));
}

// static
absl::StatusOr<ModelAssets> ModelAssets::Create(
std::shared_ptr<litert::lm::ScopedFile> model_file,
Expand Down Expand Up @@ -182,6 +189,9 @@ ModelAssets::ModelAssets(
absl::string_view model_path)
: path_(model_path), memory_mapped_file_(std::move(model_file)) {}

ModelAssets::ModelAssets(std::shared_ptr<litert::lm::DataStream> data_stream)
: data_stream_(std::move(data_stream)) {}

absl::StatusOr<absl::string_view> ModelAssets::GetPath() const {
if (!path_.empty()) {
return path_;
Expand All @@ -206,6 +216,14 @@ ModelAssets::GetMemoryMappedFile() const {
return memory_mapped_file_;
}

absl::StatusOr<std::shared_ptr<DataStream>> ModelAssets::GetDataStream() const {
if (!HasDataStream()) {
return absl::InvalidArgumentError(
"Assets were not created with a data stream.");
}
return data_stream_;
}

absl::StatusOr<std::shared_ptr<ScopedFile>> ModelAssets::GetOrCreateScopedFile()
const {
if (HasScopedFile()) {
Expand All @@ -222,13 +240,38 @@ absl::StatusOr<std::shared_ptr<ScopedFile>> ModelAssets::GetOrCreateScopedFile()

std::ostream& operator<<(std::ostream& os, const ModelAssets& model_assets) {
if (model_assets.HasScopedFile()) {
os << "model_file file descriptor ID: "
<< model_assets.GetScopedFile().value()->file() << "\n";
auto scoped_file = model_assets.GetScopedFile();
if (scoped_file.ok()) {
os << "model_file file descriptor ID: " << scoped_file.value()->file()
<< "\n";
} else {
os << "model_file error getting ScopedFile: " << scoped_file.status()
<< "\n";
}
} else if (model_assets.HasMemoryMappedFile()) {
os << "model_file memory mapped file: "
<< model_assets.GetMemoryMappedFile().value()->data() << "\n";
auto memory_mapped_file = model_assets.GetMemoryMappedFile();
if (memory_mapped_file.ok()) {
os << "model_file memory mapped file: "
<< memory_mapped_file.value()->data() << "\n";
} else {
os << "model_file error getting MemoryMappedFile: "
<< memory_mapped_file.status() << "\n";
}
} else if (model_assets.HasDataStream()) {
auto data_stream = model_assets.GetDataStream();
if (data_stream.ok()) {
os << "model_file data stream: " << data_stream.value() << "\n";
} else {
os << "model_file error getting DataStream: " << data_stream.status()
<< "\n";
}
} else {
os << "model_path: " << model_assets.GetPath().value() << "\n";
auto model_path = model_assets.GetPath();
if (model_path.ok()) {
os << "model_path: " << model_path.value() << "\n";
} else {
os << "model_path error: " << model_path.status() << "\n";
}
}
os << "fake_weights_mode: " << model_assets.fake_weights_mode() << "\n";
return os;
Expand Down
7 changes: 7 additions & 0 deletions runtime/executor/executor_settings_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "absl/status/status.h" // from @com_google_absl
#include "absl/status/statusor.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
#include "runtime/util/data_stream.h"
#include "runtime/util/memory_mapped_file.h"
#include "runtime/util/scoped_file.h"

Expand Down Expand Up @@ -114,6 +115,8 @@ class ModelAssets {
static absl::StatusOr<ModelAssets> Create(
std::shared_ptr<MemoryMappedFile> model_file,
absl::string_view model_path);
static absl::StatusOr<ModelAssets> Create(
std::shared_ptr<DataStream> data_stream);

// Convenience factory function to create a ModelAssets with both a model
// path and file. Will use the scoped file if both are provided.
Expand All @@ -122,12 +125,14 @@ class ModelAssets {

bool HasScopedFile() const { return scoped_file_ != nullptr; }
bool HasMemoryMappedFile() const { return memory_mapped_file_ != nullptr; }
bool HasDataStream() const { return data_stream_ != nullptr; }

// Returns the model file if it was created with the respective variant,
// otherwise returns an error.
absl::StatusOr<absl::string_view> GetPath() const;
absl::StatusOr<std::shared_ptr<ScopedFile>> GetScopedFile() const;
absl::StatusOr<std::shared_ptr<MemoryMappedFile>> GetMemoryMappedFile() const;
absl::StatusOr<std::shared_ptr<DataStream>> GetDataStream() const;

// Convenience method to get a read-only scoped file to the model file
// regardless of whether this instance was created from a path or scoped file.
Expand All @@ -146,12 +151,14 @@ class ModelAssets {
explicit ModelAssets(std::shared_ptr<MemoryMappedFile> model_file);
explicit ModelAssets(std::shared_ptr<MemoryMappedFile> model_file,
absl::string_view model_path);
explicit ModelAssets(std::shared_ptr<DataStream> data_stream);

// TODO: b/417814685 - Consider supporting multiple model files if the need
// case arises.
std::string path_;
std::shared_ptr<ScopedFile> scoped_file_;
std::shared_ptr<MemoryMappedFile> memory_mapped_file_;
std::shared_ptr<DataStream> data_stream_;

FakeWeightsMode fake_weights_mode_ = FakeWeightsMode::FAKE_WEIGHTS_NONE;
};
Expand Down
25 changes: 25 additions & 0 deletions runtime/executor/executor_settings_base_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "runtime/executor/executor_settings_base.h"

#include <filesystem> // NOLINT
#include <memory>
#include <sstream>
#include <string>
Expand All @@ -24,12 +25,20 @@
#include <gtest/gtest.h>
#include "absl/status/status.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
#include "runtime/util/file_data_stream.h"
#include "runtime/util/memory_mapped_file.h"
#include "runtime/util/test_utils.h" // NOLINT

namespace litert::lm {
namespace {

std::string GetTestModelPath() {
const auto model_path =
std::filesystem::path(::testing::SrcDir()) /
"litert_lm/runtime/testdata/test_lm.litertlm";
return model_path.string();
}

TEST(LlmExecutorConfigTest, Backend) {
Backend backend;
std::stringstream oss;
Expand Down Expand Up @@ -175,6 +184,22 @@ TEST(LlmExecutorConfigTest, ModelAssetsMemoryMapped) {
EXPECT_THAT(oss.str(), testing::HasSubstr("FAKE_WEIGHTS_NONE"));
}

TEST(LlmExecutorConfigTest, ModelAssetsDataStream) {
ASSERT_OK_AND_ASSIGN(auto data_stream,
FileDataStream::Create(GetTestModelPath()));

auto model_assets = ModelAssets::Create(data_stream);
ASSERT_OK(model_assets);
EXPECT_TRUE(model_assets->HasDataStream());
ASSERT_OK_AND_ASSIGN(auto retrieved_stream, model_assets->GetDataStream());
EXPECT_EQ(retrieved_stream, data_stream);

std::stringstream oss;
oss << *model_assets;
EXPECT_THAT(oss.str(), testing::HasSubstr("model_file data stream"));
EXPECT_THAT(oss.str(), testing::HasSubstr("FAKE_WEIGHTS_NONE"));
}

class TestExecutorSettings : public ExecutorSettingsBase {
public:
explicit TestExecutorSettings(ModelAssets model_assets)
Expand Down
89 changes: 89 additions & 0 deletions runtime/util/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,32 @@ cc_test(
],
)

cc_library(
name = "data_stream",
srcs = ["data_stream.cc"],
hdrs = ["data_stream.h"],
deps = [
":litert_status_util",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
)

cc_test(
name = "data_stream_test",
srcs = ["data_stream_test.cc"],
data = ["//runtime/testdata"],
deps = [
":data_stream",
":file_data_stream",
":test_utils",
"@com_google_googletest//:gtest_main",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
)

cc_library(
name = "log_tensor_buffer",
srcs = ["log_tensor_buffer.cc"],
Expand Down Expand Up @@ -154,6 +180,32 @@ cc_test(
],
)

cc_library(
name = "file_data_stream",
srcs = ["file_data_stream.cc"],
hdrs = ["file_data_stream.h"],
deps = [
":data_stream",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
)

cc_test(
name = "file_data_stream_test",
srcs = ["file_data_stream_test.cc"],
data = ["//runtime/testdata"],
deps = [
":file_data_stream",
":test_utils",
"@com_google_googletest//:gtest_main",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:status_matchers",
"@com_google_absl//absl/status:statusor",
],
)

cc_library(
name = "file_format_util",
srcs = ["file_format_util.cc"],
Expand Down Expand Up @@ -472,6 +524,43 @@ cc_library(
],
)

cc_library(
name = "litert_lm_streaming_loader",
srcs = ["litert_lm_streaming_loader.cc"],
hdrs = ["litert_lm_streaming_loader.h"],
deps = [
":data_stream",
":litert_lm_loader",
":litert_status_util",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@litert//litert/cc:litert_buffer_ref",
"//runtime/components:model_resources",
"//schema/core:litertlm_header_schema",
"//schema/core:litertlm_read",
],
)

cc_test(
name = "litert_lm_streaming_loader_test",
srcs = ["litert_lm_streaming_loader_test.cc"],
data = ["//runtime/testdata"],
deps = [
":file_data_stream",
":litert_lm_streaming_loader",
":test_utils",
"@com_google_googletest//:gtest_main",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:status_matchers",
"//runtime/util:data_stream",
"//schema/core:litertlm_header_schema",
"//schema/core:litertlm_read",
],
)

cc_library(
name = "metadata_util",
srcs = ["metadata_util.cc"],
Expand Down
Loading
Loading