[executorch][flat_tensor] DataMap implementation #8280

Merged · 3 commits · Feb 6, 2025
6 changes: 6 additions & 0 deletions extension/flat_tensor/TARGETS
@@ -0,0 +1,6 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()
257 changes: 257 additions & 0 deletions extension/flat_tensor/flat_tensor_data_map.cpp
@@ -0,0 +1,257 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>

#include <executorch/extension/flat_tensor/serialize/flat_tensor_header.h>
#include <executorch/extension/flat_tensor/serialize/schema_generated.h>

#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
#include <executorch/runtime/core/freeable_buffer.h>
#include <executorch/runtime/core/result.h>
#include <executorch/runtime/core/span.h>
#include <executorch/runtime/platform/compiler.h>

using executorch::runtime::Error;
using executorch::runtime::FreeableBuffer;
using executorch::runtime::Result;
using executorch::runtime::Span;

using executorch::aten::ScalarType;
using executorch::runtime::DataLoader;
using executorch::runtime::TensorLayout;

namespace executorch {
namespace extension {

namespace {
/**
 * FlatTensor data must be aligned to this value to properly parse it. Must be a
 * power of 2. Note that max_align_t is the alignment that malloc() and new
 * guarantee.
 */
constexpr size_t kMinimumAlignment = alignof(std::max_align_t);

bool is_aligned(const void* data) {
  uintptr_t addr = reinterpret_cast<uintptr_t>(data);
  return addr % kMinimumAlignment == 0;
}

Result<const flat_tensor_flatbuffer::TensorMetadata*> get_flat_tensor_metadata(
    const char* key,
    const flatbuffers::Vector<
        flatbuffers::Offset<flat_tensor_flatbuffer::TensorMetadata>>* tensors) {
  // Linear search by name.
  for (int i = 0; i < tensors->size(); i++) {
    if (std::strcmp(tensors->Get(i)->fully_qualified_name()->c_str(), key) ==
        0) {
      // TODO(T214294528): Support multiple segments in FlatTensor.
      if (tensors->Get(i)->segment_index() != 0) {
        return Error::InvalidExternalData;
      }
      return tensors->Get(i);
    }
  }
  return Error::NotFound;
}

Result<const TensorLayout> create_tensor_layout(
    const flat_tensor_flatbuffer::TensorMetadata* tensor_metadata) {
  ScalarType scalar_type =
      static_cast<ScalarType>(tensor_metadata->scalar_type());
  const int dim = tensor_metadata->sizes()->size();
  const auto serialized_sizes = tensor_metadata->sizes()->data();
  const auto serialized_dim_order = tensor_metadata->dim_order()->data();
  return TensorLayout::create(
      Span<const int32_t>(serialized_sizes, dim),
      Span<const uint8_t>(serialized_dim_order, dim),
      scalar_type);
}

} // namespace

ET_NODISCARD Result<const TensorLayout> FlatTensorDataMap::get_metadata(
    const char* key) const {
  Result<const flat_tensor_flatbuffer::TensorMetadata*> metadata_res =
      get_flat_tensor_metadata(key, flat_tensor_->tensors());
  if (!metadata_res.ok()) {
    return metadata_res.error();
  }
  return create_tensor_layout(metadata_res.get());
}

ET_NODISCARD Result<FreeableBuffer> FlatTensorDataMap::get_data(
    const char* key) const {
  auto tensor_metadata = flat_tensor_->tensors();

  Result<const flat_tensor_flatbuffer::TensorMetadata*> metadata_res =
      get_flat_tensor_metadata(key, tensor_metadata);
  if (!metadata_res.ok()) {
    return metadata_res.error();
  }
  const auto metadata = metadata_res.get();
  if (metadata->segment_index() < 0 || metadata->offset() < 0) {
    // Invalid segment_index/offset; malformed PTD file.
    return Error::InvalidExternalData;
  }

  Result<const TensorLayout> tensor_layout_res = create_tensor_layout(metadata);
  if (!tensor_layout_res.ok()) {
    return tensor_layout_res.error();
  }

  // This FreeableBuffer doesn't own the underlying data, and will not free it,
  // which is why the free function is a nullptr.
  // TODO(T214294528): Remove data_ro_ and instead load the data here, letting
  // FreeableBuffer own it.
  return FreeableBuffer(
      static_cast<const uint8_t*>(data_ro_.data()) + metadata->offset(),
      tensor_layout_res.get().nbytes(),
      nullptr);
}

ET_NODISCARD Result<size_t> FlatTensorDataMap::load_data_into(
    ET_UNUSED const char* key,
    ET_UNUSED void* buffer,
    ET_UNUSED size_t size) const {
  return Error::NotImplemented;
}

ET_NODISCARD Result<size_t> FlatTensorDataMap::get_num_keys() const {
  return flat_tensor_->tensors()->size();
}

ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
    size_t index) const {
  if (index < 0 || index >= flat_tensor_->tensors()->size()) {
    return Error::InvalidArgument;
  }
  return flat_tensor_->tensors()->Get(index)->fully_qualified_name()->c_str();
}

/* static */ Result<FlatTensorDataMap> FlatTensorDataMap::load(
    DataLoader* loader) {
  // Load data map.
  size_t flatbuffer_offset = 0;
  size_t flatbuffer_size = 0;
  size_t segment_base_offset = 0;
  size_t segment_data_size = 0;
  {
    // Check header.
    Result<FreeableBuffer> header = loader->load(
        /*offset=*/0,
        FlatTensorHeader::kNumHeadBytes,
        DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
    if (!header.ok()) {
      return header.error();
    }
    Result<FlatTensorHeader> fh =
        FlatTensorHeader::Parse(header->data(), header->size());
    if (fh.ok()) {
      // The header has the data map size.
      flatbuffer_offset = fh->flatbuffer_offset;
      flatbuffer_size = fh->flatbuffer_size;
      segment_base_offset = fh->segment_base_offset;
      segment_data_size = fh->segment_data_size;
    } else if (fh.error() == Error::NotFound) {
      // No header, throw error.
      ET_LOG(Error, "No FlatTensorHeader found.");
      return fh.error();
    } else {
      // Corruption, throw error.
      ET_LOG(Error, "Flat tensor header may be corrupt.");
      return fh.error();
    }
  }

  // Load flatbuffer data as a segment.
  Result<FreeableBuffer> flat_tensor_data = loader->load(
      /*offset=*/0,
      flatbuffer_offset + flatbuffer_size,
      DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
  if (!flat_tensor_data.ok()) {
    return flat_tensor_data.error();
  }

  // Make sure magic matches.
  if (!flat_tensor_flatbuffer::FlatTensorBufferHasIdentifier(
          flat_tensor_data->data())) {
    ET_LOG(
        Error,
        "FlatTensor identifier '%.4s' != expected '%.4s'",
        flatbuffers::GetBufferIdentifier(flat_tensor_data->data()),
        flat_tensor_flatbuffer::FlatTensorIdentifier());
    return Error::InvalidExternalData;
  }

  // The flatbuffer data must start at an aligned address to ensure internal
  // alignment of flatbuffer fields.
  ET_CHECK_OR_RETURN_ERROR(
      is_aligned(flat_tensor_data->data()),
      InvalidArgument,
      "FlatTensor data 0x%p must be aligned to %zu",
      flat_tensor_data->data(),
      kMinimumAlignment);

  // Get pointer to root of flatbuffer table.
  const flat_tensor_flatbuffer::FlatTensor* flat_tensor =
      flat_tensor_flatbuffer::GetFlatTensor(flat_tensor_data->data());

  // Validate flatbuffer data.
  flatbuffers::Verifier verifier(
      reinterpret_cast<const uint8_t*>(flat_tensor_data->data()),
      flat_tensor_data->size());
  bool ok = flat_tensor_flatbuffer::VerifyFlatTensorBuffer(verifier);
  ET_CHECK_OR_RETURN_ERROR(
      ok,
      InvalidExternalData,
      "Verification failed; data may be truncated or corrupt");

  // Get pointer to tensor metadata.
  const auto* s_tensor_metadata = flat_tensor->tensors();
  if (s_tensor_metadata == nullptr) {
    ET_LOG(Error, "FlatTensor has no tensor metadata.");
    return Error::InvalidExternalData;
  }

  // Load constant data.
  const auto* s_data_segment = flat_tensor->segments();

  // TODO(T214294528): Support multiple segments in FlatTensor.
  if (s_data_segment->size() != 1) {
    ET_LOG(
        Error,
        "FlatTensor has %u segments, only 1 supported.",
        s_data_segment->size());
  }
  // First segment size should be <= the total segment data size.
  int segment_size = s_data_segment->Get(0)->size();
  int segment_offset = s_data_segment->Get(0)->offset();
  if (segment_size > segment_data_size) {
    ET_LOG(
        Error,
        "FlatTensor segment size %d > segment data size %zu",
        segment_size,
        segment_data_size);
  }

  Result<FreeableBuffer> data_ro = loader->load(
      /*offset=*/segment_base_offset + segment_offset,
      segment_size,
      DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
  if (!data_ro.ok()) {
    return data_ro.error();
  }

  return FlatTensorDataMap(
      std::move(flat_tensor_data.get()), flat_tensor, std::move(data_ro.get()));
}

} // namespace extension
} // namespace executorch
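For reference, here is a minimal sketch (not part of the PR) of how get_data() sizes the FreeableBuffer it returns: a TensorLayout is built from the serialized sizes, dim order, and scalar type, and its nbytes() becomes the length of the non-owning buffer that starts at data_ro_.data() + offset. The 2x3 float tensor and its metadata below are made-up example values.

// Sketch only; mirrors the sizing logic in FlatTensorDataMap::get_data().
// The sizes, dim_order, and scalar type here are hypothetical example values.
#include <cstdint>
#include <cstdio>

#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/span.h>
#include <executorch/runtime/core/tensor_layout.h>

using executorch::aten::ScalarType;
using executorch::runtime::Span;
using executorch::runtime::TensorLayout;

int main() {
  // Serialized metadata for a hypothetical 2x3 float tensor.
  const int32_t sizes[] = {2, 3};
  const uint8_t dim_order[] = {0, 1};

  auto layout = TensorLayout::create(
      Span<const int32_t>(sizes, 2),
      Span<const uint8_t>(dim_order, 2),
      ScalarType::Float);
  if (!layout.ok()) {
    return 1;
  }

  // get_data() returns a FreeableBuffer of exactly this many bytes, starting
  // at data_ro_.data() + metadata->offset(), with a nullptr free function.
  std::printf("tensor occupies %zu bytes\n", layout.get().nbytes()); // 2*3*4 = 24
  return 0;
}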
87 changes: 87 additions & 0 deletions extension/flat_tensor/flat_tensor_data_map.h
@@ -0,0 +1,87 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <executorch/runtime/core/named_data_map.h>

#include <executorch/runtime/core/data_loader.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/result.h>
#include <executorch/runtime/core/tensor_layout.h>
#include <executorch/runtime/platform/compiler.h>

#include <utility>

// Forward declare flatbuffer types. This is a public header and must not
// include the generated flatbuffer header.
namespace flat_tensor_flatbuffer {
struct FlatTensor;
} // namespace flat_tensor_flatbuffer

namespace executorch {
namespace extension {

/**
 * A NamedDataMap implementation for FlatTensor-serialized data.
 */
class FlatTensorDataMap final : public executorch::runtime::NamedDataMap {
 public:
  /**
   * Creates a new DataMap that wraps FlatTensor data.
   *
   * @param[in] loader The DataLoader that wraps the FlatTensor file.
   * Note: the loader must outlive the FlatTensorDataMap instance.
   */
  static executorch::runtime::Result<FlatTensorDataMap> load(
      executorch::runtime::DataLoader* loader);

  ET_NODISCARD
  executorch::runtime::Result<const executorch::runtime::TensorLayout>
  get_metadata(const char* key) const override;
  ET_NODISCARD
  executorch::runtime::Result<executorch::runtime::FreeableBuffer> get_data(
      const char* key) const override;
  ET_NODISCARD executorch::runtime::Result<size_t>
  load_data_into(const char* key, void* buffer, size_t size) const override;

  ET_NODISCARD executorch::runtime::Result<size_t> get_num_keys()
      const override;
  ET_NODISCARD executorch::runtime::Result<const char*> get_key(
      size_t index) const override;

  FlatTensorDataMap(FlatTensorDataMap&&) noexcept = default;

  ~FlatTensorDataMap() override = default;

 private:
  FlatTensorDataMap(
      executorch::runtime::FreeableBuffer&& flat_tensor_data,
      const flat_tensor_flatbuffer::FlatTensor* flat_tensor,
      executorch::runtime::FreeableBuffer&& data_ro)
      : flat_tensor_data_(std::move(flat_tensor_data)),
        flat_tensor_(flat_tensor),
        data_ro_(std::move(data_ro)) {}

  // Not copyable or assignable.
  FlatTensorDataMap(const FlatTensorDataMap& rhs) = delete;
  FlatTensorDataMap& operator=(FlatTensorDataMap&& rhs) noexcept = delete;
  FlatTensorDataMap& operator=(const FlatTensorDataMap& rhs) = delete;

  // Serialized flat_tensor flatbuffer data.
  executorch::runtime::FreeableBuffer flat_tensor_data_;

  // Flatbuffer representation of the flat_tensor.
  const flat_tensor_flatbuffer::FlatTensor* flat_tensor_;

  // Loaded read-only tensor data.
  executorch::runtime::FreeableBuffer data_ro_;
};

} // namespace extension
} // namespace executorch
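As a usage sketch (not part of this PR), a caller wraps the .ptd file in a DataLoader, hands it to FlatTensorDataMap::load(), and then looks tensors up by fully qualified name. FileDataLoader is assumed to be the existing loader from //executorch/extension/data_loader; "model.ptd" is a placeholder path.

// Usage sketch under the assumptions above; keep the loader alive for as long
// as the data map is used (see the class doc).
#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>
#include <executorch/runtime/platform/assert.h>
#include <executorch/runtime/platform/log.h>

using executorch::extension::FileDataLoader;
using executorch::extension::FlatTensorDataMap;

int main() {
  auto loader = FileDataLoader::from("model.ptd");
  ET_CHECK_MSG(loader.ok(), "Failed to open .ptd file");

  auto data_map = FlatTensorDataMap::load(&loader.get());
  ET_CHECK_MSG(data_map.ok(), "Failed to parse FlatTensor data");

  // Enumerate keys and fetch each tensor's layout and raw bytes.
  size_t num_keys = data_map->get_num_keys().get();
  for (size_t i = 0; i < num_keys; ++i) {
    const char* key = data_map->get_key(i).get();
    auto layout = data_map->get_metadata(key);
    auto buffer = data_map->get_data(key);
    if (layout.ok() && buffer.ok()) {
      ET_LOG(Info, "%s: %zu bytes", key, buffer->size());
    }
  }
  return 0;
}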
22 changes: 22 additions & 0 deletions extension/flat_tensor/targets.bzl
@@ -0,0 +1,22 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets():
    runtime.cxx_library(
        name = "flat_tensor_data_map",
        srcs = [
            "flat_tensor_data_map.cpp",
        ],
        exported_headers = ["flat_tensor_data_map.h"],
        deps = [
            "//executorch/extension/flat_tensor/serialize:generated_headers",
            "//executorch/extension/flat_tensor/serialize:flat_tensor_header",
            "//executorch/runtime/core:core",
            "//executorch/runtime/core:evalue",
            "//executorch/runtime/core:named_data_map",
            "//executorch/runtime/core/exec_aten:lib",
            "//executorch/runtime/core/exec_aten/util:tensor_util",
        ],
        visibility = [
            "//executorch/...",
        ],
    )
2 changes: 1 addition & 1 deletion extension/flat_tensor/test/TARGETS
@@ -6,7 +6,7 @@ load(":targets.bzl", "define_common_targets")

oncall("executorch")

-define_common_targets()
+define_common_targets(is_fbcode=True)

python_unittest(
    name = "serialize",