diff --git a/extension/testing_util/targets.bzl b/extension/testing_util/targets.bzl index a04ffb90c9f..2b12480dfff 100644 --- a/extension/testing_util/targets.bzl +++ b/extension/testing_util/targets.bzl @@ -16,5 +16,6 @@ def define_common_targets(): "//executorch/extension/testing_util/test/...", "//executorch/extension/fb/ptez/decompression_methods/test/...", "//executorch/extension/fb/ptez/test/...", + "//executorch/runtime/executor/test/...", ], ) diff --git a/runtime/executor/program.cpp b/runtime/executor/program.cpp index 67f1edd4df3..404542151f2 100644 --- a/runtime/executor/program.cpp +++ b/runtime/executor/program.cpp @@ -150,6 +150,22 @@ Result get_execution_plan( const executorch_flatbuffer::Program* flatbuffer_program = executorch_flatbuffer::GetProgram(program_data->data()); + // Instantiate PteDataMap if named_data is present. + const auto named_data = flatbuffer_program->named_data(); + std::optional pte_data_map = std::nullopt; + if (named_data != nullptr) { + Result pte_data_map_result = + internal::PteDataMap::create( + loader, + segment_base_offset, + named_data, + flatbuffer_program->segments()); + if (!pte_data_map_result.ok()) { + return pte_data_map_result.error(); + } + pte_data_map.emplace(std::move(pte_data_map_result.get())); + } + // Constant data may live inside the flatbuffer data (constant_buffer) or in a // separate segment (constant_segment). It should not be in both. // Check constant_segment->offsets()->size() > 1, as the offsets list will @@ -199,7 +215,8 @@ Result get_execution_plan( segment_base_offset, std::move(program_data.get()), flatbuffer_program, - std::move(constant_segment_data.get())); + std::move(constant_segment_data.get()), + std::move(pte_data_map)); } else { // The constant data is stored inside the flatbuffer, so this program does // not contain a separate segment for it. 
@@ -208,7 +225,8 @@ Result get_execution_plan( segment_base_offset, std::move(program_data.get()), flatbuffer_program, - /*constant_segment_data=*/FreeableBuffer{}); + /*constant_segment_data=*/FreeableBuffer{}, + std::move(pte_data_map)); } } diff --git a/runtime/executor/program.h b/runtime/executor/program.h index 7313b19d66d..8bcb5fe4d97 100644 --- a/runtime/executor/program.h +++ b/runtime/executor/program.h @@ -13,6 +13,7 @@ #include #include +#include #include #include @@ -22,6 +23,7 @@ #include #include #include +#include #include // Forward declare flatbuffer types. This is a public header and must not @@ -266,13 +268,15 @@ class Program final { size_t segment_base_offset, FreeableBuffer&& program_data, const executorch_flatbuffer::Program* internal_program, - FreeableBuffer&& constant_segment_data) + FreeableBuffer&& constant_segment_data, + std::optional&& pte_data_map) : program_data_(std::move(program_data)), // Don't need the loader if there are no segments. loader_(segment_base_offset > 0 ? loader : nullptr), internal_program_(internal_program), segment_base_offset_(segment_base_offset), - constant_segment_data_(std::move(constant_segment_data)) {} + constant_segment_data_(std::move(constant_segment_data)), + pte_data_map_(std::move(pte_data_map)) {} // Not copyable or assignable. Program(const Program& rhs) = delete; @@ -295,6 +299,9 @@ class Program final { /// Constant segment data. FreeableBuffer constant_segment_data_; + + /// NamedDataMap holding named data from the program. + std::optional pte_data_map_; }; } // namespace runtime diff --git a/runtime/executor/pte_data_map.cpp b/runtime/executor/pte_data_map.cpp new file mode 100644 index 00000000000..5829395028a --- /dev/null +++ b/runtime/executor/pte_data_map.cpp @@ -0,0 +1,87 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace executorch { +namespace runtime { +namespace internal { + +/* static */ executorch::runtime::Result PteDataMap::create( + executorch::runtime::DataLoader* loader, + size_t segment_base_offset, + const flatbuffers::FlatbufferNamedData* named_data, + const flatbuffers::FlatbufferDataSegment* segments) { + ET_CHECK_OR_RETURN_ERROR( + loader != nullptr && named_data != nullptr && segments != nullptr, + InvalidArgument, + "PteDataMap loader, named_data or segments is null; most likely the program does not have any named_data segments"); + return PteDataMap(loader, segment_base_offset, named_data, segments); +} + +ET_NODISCARD +executorch::runtime::Result +PteDataMap::get_data(const char* key) const { + for (size_t i = 0; i < named_data_->size(); i++) { + ET_CHECK_OR_RETURN_ERROR( + named_data_->Get(i) != nullptr && named_data_->Get(i)->key() != nullptr, + InvalidArgument, + "Searching for key %s: NamedData at index %zu is null", + key, + i); + if (strcmp(named_data_->Get(i)->key()->c_str(), key) == 0) { + // Get the segment index. + size_t segment_index = named_data_->Get(i)->segment_index(); + + // Bounds-check the segment index, then get the segment offset and size. 
+ ET_CHECK_OR_RETURN_ERROR( + segment_index < segments_->size(), + InvalidArgument, + "Segment index %zu for key %s is out of range for segments size %u", + segment_index, + key, + segments_->size()); + size_t segment_offset = segments_->Get(segment_index)->offset(); + size_t segment_size = segments_->Get(segment_index)->size(); + + return loader_->load( + /*offset=*/segment_base_offset_ + segment_offset, + segment_size, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External)); + } + } + return Error::NotFound; +} + +ET_NODISCARD executorch::runtime::Result PteDataMap::get_num_keys() + const { + return named_data_->size(); +} + +ET_NODISCARD executorch::runtime::Result PteDataMap::get_key( + size_t index) const { + ET_CHECK_OR_RETURN_ERROR( + index < named_data_->size(), + InvalidArgument, + "Index out of range: named_data size is %u, received index %zu", + named_data_->size(), + index); + + ET_CHECK_OR_RETURN_ERROR( + named_data_->Get(index) != nullptr && + named_data_->Get(index)->key() != nullptr, + InvalidArgument, + "NamedData at index %zu is null", + index); + return named_data_->Get(index)->key()->c_str(); +} + +} // namespace internal +} // namespace runtime +} // namespace executorch diff --git a/runtime/executor/pte_data_map.h b/runtime/executor/pte_data_map.h new file mode 100644 index 00000000000..01c15555786 --- /dev/null +++ b/runtime/executor/pte_data_map.h @@ -0,0 +1,151 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +// Forward declare flatbuffer types. This is a public header and must not +// include the generated flatbuffer header. 
+namespace executorch_flatbuffer { +struct NamedData; +struct DataSegment; +} // namespace executorch_flatbuffer + +namespace flatbuffers { +template +struct Offset; +} // namespace flatbuffers + +// @lint-ignore CLANGTIDY facebook-modularize-issue-check +#if EXECUTORCH_INTERNAL_FLATBUFFERS == 1 +// TODO(T216992074): update internal flatbuffers (v1.12) to match OSS (v24.3.5). +namespace flatbuffers { +template +class Vector; +using FlatbufferNamedData = + flatbuffers::Vector>; +using FlatbufferDataSegment = flatbuffers::Vector< + flatbuffers::Offset>; +} // namespace flatbuffers +#else +namespace flatbuffers { +template +class Vector; +using FlatbufferNamedData = flatbuffers:: + Vector, uint32_t>; +using FlatbufferDataSegment = flatbuffers:: + Vector, uint32_t>; +} // namespace flatbuffers +#endif + +namespace executorch { +namespace runtime { +namespace internal { + +/** + * A NamedDataMap implementation for Flatbuffer-serialized named data + * originating from a PTE file. + */ +class PteDataMap final : public NamedDataMap { + public: + /** + * Creates a new DataMap that wraps named_data from the PTE file. + * + * @param[in] loader The DataLoader that accesses the PTE file. + * Note: the loader must outlive the PteDataMap instance. + * @param[in] segment_base_offset The offset to the first segment in the PTE + * file, in bytes. + * @param[in] named_data The named_data from the PTE file. Note: the pointer + * passed here must outlive the PteDataMap instance. + * @param[in] segments The segments from the PTE file. Note: the pointer + * passed here must outlive the PteDataMap instance. + */ + static Result create( + DataLoader* loader, + size_t segment_base_offset, + const flatbuffers::FlatbufferNamedData* named_data, + const flatbuffers::FlatbufferDataSegment* segments); + + /** + * The PteDataMap currently only handles opaque data that does not contain + * tensor-specific metadata. 
+ */ + ET_NODISCARD + Result get_metadata( + ET_UNUSED const char* key) const override { + return Error::NotImplemented; + } + + /** + * Retrieve read-only data for the specified key. + * + * @param[in] key The name of the blob to retrieve. + * + * @return error if the key is not present or data cannot be loaded. + */ + ET_NODISCARD + Result get_data(const char* key) const override; + + /** + * Not implemented by PteDataMap; always returns Error::NotImplemented. + */ + ET_NODISCARD Error load_data_into( + ET_UNUSED const char* key, + ET_UNUSED void* buffer, + ET_UNUSED size_t size) const override { + return Error::NotImplemented; + } + + /** + * @returns The number of keys in the map. + */ + ET_NODISCARD Result get_num_keys() const override; + + /** + * @returns The key at the specified index, error if index out of bounds. + */ + ET_NODISCARD Result get_key(size_t index) const override; + + // Move-constructible (but not move-assignable), to be compatible with Result. + PteDataMap(PteDataMap&&) noexcept = default; + ~PteDataMap() override = default; + + private: + PteDataMap( + DataLoader* loader, + size_t segment_base_offset, + const flatbuffers::FlatbufferNamedData* named_data, + const flatbuffers::FlatbufferDataSegment* segments) + : loader_(loader), + segment_base_offset_(segment_base_offset), + named_data_(named_data), + segments_(segments) {} + + // Not copyable or assignable. + PteDataMap(const PteDataMap& rhs) = delete; + PteDataMap& operator=(PteDataMap&& rhs) noexcept = delete; + PteDataMap& operator=(const PteDataMap& rhs) = delete; + + // Data loader, used to load segment data. + DataLoader* loader_; + + // The offset to the first segment in the PTE file, in bytes. + size_t segment_base_offset_; + + // Named data, containing name and segment index. + const flatbuffers::FlatbufferNamedData* named_data_; + + // Segments, to retrieve offset and size for the loader. 
+ const flatbuffers::FlatbufferDataSegment* segments_; +}; + +} // namespace internal +} // namespace runtime +} // namespace executorch diff --git a/runtime/executor/targets.bzl b/runtime/executor/targets.bzl index 8993c5dc473..f2d41105660 100644 --- a/runtime/executor/targets.bzl +++ b/runtime/executor/targets.bzl @@ -42,6 +42,26 @@ def define_common_targets(): ], ) + runtime.cxx_library( + name = "pte_data_map", + srcs = [ + "pte_data_map.cpp", + ], + exported_headers = [ + "pte_data_map.h", + ], + visibility = [ + "//executorch/runtime/executor/...", + "@EXECUTORCH_CLIENTS", + ], + exported_deps = [ + "//executorch/runtime/core:core", + "//executorch/runtime/core:named_data_map", + "//executorch/schema:program", + ], + exported_preprocessor_flags = [] if runtime.is_oss else ["-DEXECUTORCH_INTERNAL_FLATBUFFERS=1"], + ) + for aten_mode in get_aten_mode_options(): aten_suffix = "_aten" if aten_mode else "" runtime.cxx_library( @@ -81,6 +101,7 @@ def define_common_targets(): preprocessor_flags = _program_preprocessor_flags(), exported_deps = [ ":memory_manager", + ":pte_data_map", "//executorch/runtime/backend:interface", "//executorch/runtime/core:core", "//executorch/runtime/core:named_data_map", diff --git a/runtime/executor/test/pte_data_map_test.cpp b/runtime/executor/test/pte_data_map_test.cpp new file mode 100644 index 00000000000..b5312eb4a88 --- /dev/null +++ b/runtime/executor/test/pte_data_map_test.cpp @@ -0,0 +1,277 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include +#include +#include +#include +#include + +#include + +using namespace ::testing; +using executorch::extension::FileDataLoader; +using executorch::extension::testing::TempFile; +using executorch::runtime::DataLoader; +using executorch::runtime::Error; +using executorch::runtime::FreeableBuffer; +using executorch::runtime::Result; +using executorch::runtime::internal::PteDataMap; + +class PteDataMapTest : public ::testing::Test { + protected: + void SetUp() override { + // Since these tests cause ET_LOG to be called, the PAL must be initialized + // first. + executorch::runtime::runtime_init(); + + // Create a sample Program with only named_data and segments. Technically + // not a valid Program; only used to test the PteDataMap. + // Create named data. + std::array, 4> + named_data_arr = { + executorch_flatbuffer::CreateNamedDataDirect( + builder_, "key0", /*segment_index=*/0), + executorch_flatbuffer::CreateNamedDataDirect( + builder_, "key1", /*segment_index=*/1), + // Note: key2 points to the same segment as key0. + executorch_flatbuffer::CreateNamedDataDirect( + builder_, "key2", /*segment_index=*/0), + // This is invalid, as segment_index=10 is out of range when the + // number of segments is 2. + executorch_flatbuffer::CreateNamedDataDirect( + builder_, "key_invalid", /*segment_index=*/10), + }; + const auto named_data = + builder_.CreateVector(named_data_arr.data(), named_data_arr.size()); + + // Create segments. + std::array, 2> + segment_arr = {// @lint-ignore CLANGTIDY facebook-hte-BadArgumentComment + executorch_flatbuffer::CreateDataSegment( + builder_, /*offset=*/0, /*size=*/kSegmentSizes[0]), + // @lint-ignore CLANGTIDY facebook-hte-BadArgumentComment + executorch_flatbuffer::CreateDataSegment( + builder_, + /*offset=*/kSegmentAlignment * 2, + /*size=*/kSegmentSizes[1])}; + const auto segments = + builder_.CreateVector(segment_arr.data(), segment_arr.size()); + + // Create Program. 
+ const auto program = executorch_flatbuffer::CreateProgram( + builder_, 0, 0, 0, 0, segments, 0, 0, named_data); + + builder_.Finish(program); + program_ = executorch_flatbuffer::GetProgram(builder_.GetBufferPointer()); + + // Create sample segment data. + for (int i = 0; i < kSegmentSizes[0]; i++) { + sample_data_[i] = 1; + } + for (int i = kSegmentOffsets[1]; i < kSegmentOffsets[1] + kSegmentSizes[1]; + i++) { + sample_data_[i] = 2; + } + TempFile tf(sample_data_.data(), sizeof(sample_data_)); + + // Wrap the sample data in a loader. + Result loader = + FileDataLoader::from(tf.path().c_str(), kSegmentAlignment); + ASSERT_EQ(loader.error(), Error::Ok); + data_map_loader_ = + std::make_unique(std::move(loader.get())); + } + + // Program builder constants. + static constexpr int kSegmentAlignment = 16; + static constexpr std::array kSegmentSizes{17, 8}; + static constexpr std::array kSegmentOffsets{0, kSegmentAlignment * 2}; + std::array sample_data_; + + // Program builder. + flatbuffers::FlatBufferBuilder builder_; + const executorch_flatbuffer::Program* program_; + + // Data loader for the sample data. + std::unique_ptr data_map_loader_; +}; + +TEST_F(PteDataMapTest, Load) { + Result data_map = PteDataMap::create( + data_map_loader_.get(), 0, program_->named_data(), program_->segments()); + ASSERT_TRUE(data_map.ok()); +} + +TEST_F(PteDataMapTest, LoadFail) { + Result data_map = PteDataMap::create( + /*loader=*/nullptr, + /*segment_base_offset=*/0, + program_->named_data(), + program_->segments()); + EXPECT_EQ(data_map.error(), Error::InvalidArgument); +} + +TEST_F(PteDataMapTest, UnimplementedMethods) { + Result data_map = PteDataMap::create( + data_map_loader_.get(), 0, program_->named_data(), program_->segments()); + ; + + // Check get_metadata is not implemented. + auto result = data_map->get_metadata("sample_key"); + EXPECT_EQ(result.error(), Error::NotImplemented); + + // Check load_data_into is not implemented. 
+ auto err = data_map->load_data_into("sample_key", nullptr, 0); + EXPECT_EQ(err, Error::NotImplemented); +} + +TEST_F(PteDataMapTest, Keys) { + Result data_map = PteDataMap::create( + data_map_loader_.get(), 0, program_->named_data(), program_->segments()); + ASSERT_TRUE(data_map.ok()); + + // Check get_num_keys. + auto num_keys = data_map->get_num_keys(); + EXPECT_EQ(num_keys.error(), Error::Ok); + EXPECT_EQ(num_keys.get(), 4); + + // Check get_key. + auto key0 = data_map->get_key(0); + EXPECT_EQ(strcmp(key0.get(), "key0"), 0); + auto key1 = data_map->get_key(1); + EXPECT_EQ(strcmp(key1.get(), "key1"), 0); + auto key2 = data_map->get_key(2); + EXPECT_EQ(strcmp(key2.get(), "key2"), 0); + + // This key is invalid because it points to segment_index=10, which is out + // of range for this example, which has only 2 segments. + // Note: practically, a PTE should not have invalid keys. + auto key_invalid = data_map->get_key(3); + EXPECT_EQ(strcmp(key_invalid.get(), "key_invalid"), 0); + + // Returns an error on an out-of-range index. + auto nonexistent_key = data_map->get_key(10); + EXPECT_EQ(nonexistent_key.error(), Error::InvalidArgument); +} + +TEST_F(PteDataMapTest, GetData) { + Result data_map = PteDataMap::create( + data_map_loader_.get(), 0, program_->named_data(), program_->segments()); + ASSERT_TRUE(data_map.ok()); + + Result data0 = data_map->get_data("key0"); + EXPECT_EQ(data0.error(), Error::Ok); + EXPECT_EQ(data0.get().size(), kSegmentSizes[0]); + EXPECT_EQ( + memcmp(data0.get().data(), sample_data_.data(), data0.get().size()), 0); + + Result data1 = data_map->get_data("key1"); + EXPECT_EQ(data1.error(), Error::Ok); + EXPECT_EQ(data1.get().size(), kSegmentSizes[1]); + EXPECT_EQ( + memcmp( + data1.get().data(), + sample_data_.data() + kSegmentOffsets[1], + data1.get().size()), + 0); + + Result data2 = data_map->get_data("key2"); + EXPECT_EQ(data2.error(), Error::Ok); + // Expect the same values as data0, as key0 and key2 point to the same + // segment. 
+ EXPECT_EQ(data2.get().size(), kSegmentSizes[0]); + EXPECT_EQ( + memcmp(data2.get().data(), sample_data_.data(), data2.get().size()), 0); + + // Free data. + data0->Free(); + data1->Free(); + data2->Free(); + + // Returns an error, as key_invalid contains segment_index=10, which + // is out of range for segments.size()=2. + Result data_invalid = data_map->get_data("key_invalid"); + EXPECT_EQ(data_invalid.error(), Error::InvalidArgument); + + // Returns NotFound on a nonexistent key. + Result data_nonexistent = + data_map->get_data("nonexistent_key"); + EXPECT_EQ(data_nonexistent.error(), Error::NotFound); +} + +TEST_F(PteDataMapTest, FreeAndReload) { + // Load a key, free it, and then load it again, and ensure that the + // PteDataMap can return a new FreeableBuffer with the same data. + Result data_map = PteDataMap::create( + data_map_loader_.get(), 0, program_->named_data(), program_->segments()); + ASSERT_TRUE(data_map.ok()); + + // Load data0. + Result data0 = data_map->get_data("key0"); + EXPECT_EQ(data0.error(), Error::Ok); + EXPECT_EQ(data0.get().size(), kSegmentSizes[0]); + EXPECT_EQ( + memcmp(data0.get().data(), sample_data_.data(), data0.get().size()), 0); + data0->Free(); + + // Reload data0, ensure that the PteDataMap can return a new + // FreeableBuffer with the same data. + Result data0_reload = data_map->get_data("key0"); + EXPECT_EQ(data0_reload.error(), Error::Ok); + EXPECT_EQ(data0_reload.get().size(), kSegmentSizes[0]); + EXPECT_EQ( + memcmp( + data0_reload.get().data(), + sample_data_.data(), + data0_reload.get().size()), + 0); + data0_reload->Free(); +} + +TEST_F(PteDataMapTest, ReloadAndFree) { + // Load the same key multiple times, and then free one and ensure that the + // data in the other is still valid. + Result data_map = PteDataMap::create( + data_map_loader_.get(), 0, program_->named_data(), program_->segments()); + ASSERT_TRUE(data_map.ok()); + + // Load data0. 
+ Result data0 = data_map->get_data("key0"); + EXPECT_EQ(data0.error(), Error::Ok); + EXPECT_EQ(data0.get().size(), kSegmentSizes[0]); + EXPECT_EQ( + memcmp(data0.get().data(), sample_data_.data(), data0.get().size()), 0); + + // Load the same key again while data0 is still held. + Result data0_reload = data_map->get_data("key0"); + EXPECT_EQ(data0_reload.error(), Error::Ok); + EXPECT_EQ(data0_reload.get().size(), kSegmentSizes[0]); + EXPECT_EQ( + memcmp( + data0_reload.get().data(), + sample_data_.data(), + data0_reload.get().size()), + 0); + + // Free data0 and check that data0_reload is still valid. + data0->Free(); + EXPECT_EQ(data0_reload.get().size(), kSegmentSizes[0]); + EXPECT_EQ( + memcmp( + data0_reload.get().data(), + sample_data_.data(), + data0_reload.get().size()), + 0); + + // Free data0_reload. + data0_reload->Free(); +} diff --git a/runtime/executor/test/targets.bzl b/runtime/executor/test/targets.bzl index 1dbb4ea6108..98c19445f54 100644 --- a/runtime/executor/test/targets.bzl +++ b/runtime/executor/test/targets.bzl @@ -92,6 +92,18 @@ def define_common_targets(is_fbcode = False): ], ) + runtime.cxx_test( + name = "pte_data_map_test", + srcs = [ + "pte_data_map_test.cpp", + ], + deps = [ + "//executorch/extension/data_loader:file_data_loader", + "//executorch/extension/testing_util:temp_file", + "//executorch/runtime/executor:pte_data_map", + ], + ) + # TODO(dbort): Find a way to make these run for ANDROID/APPLE in xplat. The # android and ios test determinators don't like the reference to the model # file in fbcode. See https://fburl.com/9esapdmd