Skip to content

[executorch][runtime] Introduce PteDataMap for weight sharing #8887

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Mar 5, 2025
1 change: 1 addition & 0 deletions extension/testing_util/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@ def define_common_targets():
"//executorch/extension/testing_util/test/...",
"//executorch/extension/fb/ptez/decompression_methods/test/...",
"//executorch/extension/fb/ptez/test/...",
"//executorch/runtime/executor/test/...",
],
)
22 changes: 20 additions & 2 deletions runtime/executor/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,22 @@ Result<executorch_flatbuffer::ExecutionPlan*> get_execution_plan(
const executorch_flatbuffer::Program* flatbuffer_program =
executorch_flatbuffer::GetProgram(program_data->data());

// Instantiate PteDataMap if named_data is present.
const auto named_data = flatbuffer_program->named_data();
std::optional<internal::PteDataMap> pte_data_map = std::nullopt;
if (named_data != nullptr) {
Result<internal::PteDataMap> pte_data_map_result =
internal::PteDataMap::create(
loader,
segment_base_offset,
named_data,
flatbuffer_program->segments());
if (!pte_data_map_result.ok()) {
return pte_data_map_result.error();
}
pte_data_map.emplace(std::move(pte_data_map_result.get()));
}

// Constant data may live inside the flatbuffer data (constant_buffer) or in a
// separate segment (constant_segment). It should not be in both.
// Check constant_segment->offsets()->size() > 1, as the offsets list will
Expand Down Expand Up @@ -199,7 +215,8 @@ Result<executorch_flatbuffer::ExecutionPlan*> get_execution_plan(
segment_base_offset,
std::move(program_data.get()),
flatbuffer_program,
std::move(constant_segment_data.get()));
std::move(constant_segment_data.get()),
std::move(pte_data_map));
} else {
// The constant data is stored inside the flatbuffer, so this program does
// not contain a separate segment for it.
Expand All @@ -208,7 +225,8 @@ Result<executorch_flatbuffer::ExecutionPlan*> get_execution_plan(
segment_base_offset,
std::move(program_data.get()),
flatbuffer_program,
/*constant_segment_data=*/FreeableBuffer{});
/*constant_segment_data=*/FreeableBuffer{},
std::move(pte_data_map));
}
}

Expand Down
11 changes: 9 additions & 2 deletions runtime/executor/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include <cinttypes>
#include <cstdint>
#include <optional>

#include <executorch/runtime/core/data_loader.h>
#include <executorch/runtime/core/error.h>
Expand All @@ -22,6 +23,7 @@
#include <executorch/runtime/executor/memory_manager.h>
#include <executorch/runtime/executor/method.h>
#include <executorch/runtime/executor/method_meta.h>
#include <executorch/runtime/executor/pte_data_map.h>
#include <executorch/runtime/platform/compiler.h>

// Forward declare flatbuffer types. This is a public header and must not
Expand Down Expand Up @@ -266,13 +268,15 @@ class Program final {
size_t segment_base_offset,
FreeableBuffer&& program_data,
const executorch_flatbuffer::Program* internal_program,
FreeableBuffer&& constant_segment_data)
FreeableBuffer&& constant_segment_data,
std::optional<internal::PteDataMap>&& pte_data_map)
: program_data_(std::move(program_data)),
// Don't need the loader if there are no segments.
loader_(segment_base_offset > 0 ? loader : nullptr),
internal_program_(internal_program),
segment_base_offset_(segment_base_offset),
constant_segment_data_(std::move(constant_segment_data)) {}
constant_segment_data_(std::move(constant_segment_data)),
pte_data_map_(std::move(pte_data_map)) {}

// Not copyable or assignable.
Program(const Program& rhs) = delete;
Expand All @@ -295,6 +299,9 @@ class Program final {

/// Constant segment data.
FreeableBuffer constant_segment_data_;

/// NamedDataMap holding named data from the program.
std::optional<internal::PteDataMap> pte_data_map_;
};

} // namespace runtime
Expand Down
87 changes: 87 additions & 0 deletions runtime/executor/pte_data_map.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/runtime/executor/pte_data_map.h>
#include <executorch/schema/program_generated.h>

namespace executorch {
namespace runtime {
namespace internal {

// Factory for PteDataMap. Fails with InvalidArgument when any of `loader`,
// `named_data` or `segments` is null; otherwise wraps the given pointers
// without copying what they point to.
/* static */ executorch::runtime::Result<PteDataMap> PteDataMap::create(
    executorch::runtime::DataLoader* loader,
    size_t segment_base_offset,
    const flatbuffers::FlatbufferNamedData* named_data,
    const flatbuffers::FlatbufferDataSegment* segments) {
  // Every one of these pointers is stored and later dereferenced by the
  // lookup methods, so a map is only built when all of them are present.
  const bool any_missing =
      loader == nullptr || named_data == nullptr || segments == nullptr;
  ET_CHECK_OR_RETURN_ERROR(
      !any_missing,
      InvalidArgument,
      "PteDataMap loader, named_data or segments is null; most likely the program does not have any named_data segments");
  return PteDataMap(loader, segment_base_offset, named_data, segments);
}

/**
 * Finds the named data entry whose key equals `key` and loads the segment it
 * refers to through the data loader.
 *
 * Entries are scanned linearly in stored order; the first match wins.
 *
 * @param[in] key Null-terminated name of the blob to load.
 *
 * @returns The loaded buffer on success. Error::InvalidArgument if `key` is
 *     null, if an entry visited during the scan (or its key) is null, or if
 *     the matching entry's segment index is out of range.
 *     Error::NotFound if no entry has this key. Otherwise, whatever the data
 *     loader's load() returns.
 */
ET_NODISCARD
executorch::runtime::Result<executorch::runtime::FreeableBuffer>
PteDataMap::get_data(const char* key) const {
  // strcmp() and the "%s" log arguments below both require a valid C string;
  // reject a null key up front instead of invoking undefined behavior.
  ET_CHECK_OR_RETURN_ERROR(
      key != nullptr, InvalidArgument, "Cannot look up a null key");

  for (size_t i = 0; i < named_data_->size(); i++) {
    const auto* entry = named_data_->Get(i);
    ET_CHECK_OR_RETURN_ERROR(
        entry != nullptr && entry->key() != nullptr,
        InvalidArgument,
        "Searching for key %s: NamedData at index %zu is null",
        key,
        i);
    if (strcmp(entry->key()->c_str(), key) != 0) {
      continue;
    }

    // The entry only records which segment holds its bytes; validate that
    // index before using it to look up the segment's offset and size.
    const size_t segment_index = entry->segment_index();
    ET_CHECK_OR_RETURN_ERROR(
        segment_index < segments_->size(),
        InvalidArgument,
        "Segment index %zu for key %s is out of range for segments size %u",
        segment_index,
        key,
        segments_->size());

    const auto* segment = segments_->Get(segment_index);
    const size_t segment_offset = segment->offset();
    const size_t segment_size = segment->size();

    // Segment offsets are relative to the start of the segment area.
    return loader_->load(
        /*offset=*/segment_base_offset_ + segment_offset,
        segment_size,
        DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
  }
  return Error::NotFound;
}

// Reports how many entries the wrapped named_data vector contains.
ET_NODISCARD executorch::runtime::Result<size_t> PteDataMap::get_num_keys()
    const {
  // The vector reports its length in its own size type; widen it explicitly.
  return static_cast<size_t>(named_data_->size());
}

// Returns the key string stored at position `index` of the named data vector.
// Fails with InvalidArgument when `index` is past the end, or when the entry
// at that position (or its key) is null.
ET_NODISCARD executorch::runtime::Result<const char*> PteDataMap::get_key(
    size_t index) const {
  const auto num_entries = named_data_->size();
  ET_CHECK_OR_RETURN_ERROR(
      index < num_entries,
      InvalidArgument,
      "Index out of range: named_data size is %u, received index %zu",
      num_entries,
      index);

  // Only fetch the entry once the index is known to be in range.
  const auto* entry = named_data_->Get(index);
  const auto* name = (entry != nullptr) ? entry->key() : nullptr;
  ET_CHECK_OR_RETURN_ERROR(
      name != nullptr,
      InvalidArgument,
      "NamedData at index %zu is null",
      index);
  return name->c_str();
}

} // namespace internal
} // namespace runtime
} // namespace executorch
151 changes: 151 additions & 0 deletions runtime/executor/pte_data_map.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <executorch/runtime/core/data_loader.h>
#include <executorch/runtime/core/named_data_map.h>

// Forward declare flatbuffer types. This is a public header and must not
// include the generated flatbuffer header.
namespace executorch_flatbuffer {
struct NamedData;
struct DataSegment;
} // namespace executorch_flatbuffer

namespace flatbuffers {
template <typename T>
struct Offset;
} // namespace flatbuffers

// @lint-ignore CLANGTIDY facebook-modularize-issue-check
#if EXECUTORCH_INTERNAL_FLATBUFFERS == 1
// TODO(T216992074): update internal flatbuffers (v1.12) to match OSS (v24.3.5).
namespace flatbuffers {
template <typename T>
class Vector;
using FlatbufferNamedData =
flatbuffers::Vector<flatbuffers::Offset<executorch_flatbuffer::NamedData>>;
using FlatbufferDataSegment = flatbuffers::Vector<
flatbuffers::Offset<executorch_flatbuffer::DataSegment>>;
} // namespace flatbuffers
#else
namespace flatbuffers {
template <typename T, typename SizeT>
class Vector;
using FlatbufferNamedData = flatbuffers::
Vector<flatbuffers::Offset<executorch_flatbuffer::NamedData>, uint32_t>;
using FlatbufferDataSegment = flatbuffers::
Vector<flatbuffers::Offset<executorch_flatbuffer::DataSegment>, uint32_t>;
} // namespace flatbuffers
#endif

namespace executorch {
namespace runtime {
namespace internal {

/**
 * NamedDataMap backed by the flatbuffer-serialized named data of a PTE file.
 *
 * The map does not own the loader or the flatbuffer vectors it is given; it
 * only keeps pointers to them.
 */
class PteDataMap final : public NamedDataMap {
 public:
  /**
   * Builds a PteDataMap over the named_data of a PTE file.
   *
   * @param[in] loader DataLoader used to read the PTE file. Must outlive the
   *     returned PteDataMap.
   * @param[in] segment_base_offset Offset, in bytes, of the first segment in
   *     the PTE file.
   * @param[in] named_data The PTE file's named_data vector. The pointer must
   *     outlive the returned PteDataMap.
   * @param[in] segments The PTE file's segments vector. The pointer must
   *     outlive the returned PteDataMap.
   */
  static Result<PteDataMap> create(
      DataLoader* loader,
      size_t segment_base_offset,
      const flatbuffers::FlatbufferNamedData* named_data,
      const flatbuffers::FlatbufferDataSegment* segments);

  /**
   * Not implemented: PteDataMap currently serves opaque data only, with no
   * tensor-specific metadata. Always returns Error::NotImplemented.
   */
  ET_NODISCARD
  Result<const TensorLayout> get_metadata(
      ET_UNUSED const char* key) const override {
    return Error::NotImplemented;
  }

  /**
   * Fetches the read-only data stored under `key`.
   *
   * @param[in] key Name of the blob to fetch.
   *
   * @return An error if the key is absent or the data cannot be loaded.
   */
  ET_NODISCARD
  Result<FreeableBuffer> get_data(const char* key) const override;

  /**
   * Not implemented: always returns Error::NotImplemented.
   */
  ET_NODISCARD Error load_data_into(
      ET_UNUSED const char* key,
      ET_UNUSED void* buffer,
      ET_UNUSED size_t size) const override {
    return Error::NotImplemented;
  }

  /**
   * @returns How many keys the map holds.
   */
  ET_NODISCARD Result<size_t> get_num_keys() const override;

  /**
   * @returns The key at position `index`, or an error when `index` is out of
   *     bounds.
   */
  ET_NODISCARD Result<const char*> get_key(size_t index) const override;

  // Move-constructible so that instances can be carried inside a Result.
  PteDataMap(PteDataMap&&) noexcept = default;
  ~PteDataMap() override = default;

 private:
  // Construction goes through create(), which validates the pointers first.
  PteDataMap(
      DataLoader* loader,
      size_t segment_base_offset,
      const flatbuffers::FlatbufferNamedData* named_data,
      const flatbuffers::FlatbufferDataSegment* segments)
      : loader_(loader),
        segment_base_offset_(segment_base_offset),
        named_data_(named_data),
        segments_(segments) {}

  // Copying and every form of assignment are disabled.
  PteDataMap(const PteDataMap&) = delete;
  PteDataMap& operator=(PteDataMap&&) noexcept = delete;
  PteDataMap& operator=(const PteDataMap&) = delete;

  // Loader through which segment data is read.
  DataLoader* loader_;

  // Offset, in bytes, of the first segment in the PTE file.
  size_t segment_base_offset_;

  // Named data entries; each pairs a key with a segment index.
  const flatbuffers::FlatbufferNamedData* named_data_;

  // Segment table, consulted for the offset and size handed to the loader.
  const flatbuffers::FlatbufferDataSegment* segments_;
};

} // namespace internal
} // namespace runtime
} // namespace executorch
21 changes: 21 additions & 0 deletions runtime/executor/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,26 @@ def define_common_targets():
],
)

runtime.cxx_library(
name = "pte_data_map",
srcs = [
"pte_data_map.cpp",
],
exported_headers = [
"pte_data_map.h",
],
visibility = [
"//executorch/runtime/executor/...",
"@EXECUTORCH_CLIENTS",
],
exported_deps = [
"//executorch/runtime/core:core",
"//executorch/runtime/core:named_data_map",
"//executorch/schema:program",
],
exported_preprocessor_flags = [] if runtime.is_oss else ["-DEXECUTORCH_INTERNAL_FLATBUFFERS=1"],
)

for aten_mode in get_aten_mode_options():
aten_suffix = "_aten" if aten_mode else ""
runtime.cxx_library(
Expand Down Expand Up @@ -81,6 +101,7 @@ def define_common_targets():
preprocessor_flags = _program_preprocessor_flags(),
exported_deps = [
":memory_manager",
":pte_data_map",
"//executorch/runtime/backend:interface",
"//executorch/runtime/core:core",
"//executorch/runtime/core:named_data_map",
Expand Down
Loading
Loading