Skip to content

Commit 3d0d2f7

Browse files
committed
[executorch][flat_tensor] DataMap implementation
Pull Request resolved: #7900. DataMap implementation that (1) loads a flat_tensor file and (2) makes tensor information available via the named_data_map.h interface. TODO: in a later diff, update the ET runtime to hold onto the FreeableBuffers returned by the NDM; once that is done, the NDM will no longer persist the segment. T214294528 ghstack-source-id: 264905837 Differential Revision: [D67064580](https://our.internmc.facebook.com/intern/diff/D67064580/)
1 parent 15c8bdf commit 3d0d2f7

File tree

8 files changed

+543
-2
lines changed

8 files changed

+543
-2
lines changed

extension/flat_tensor/TARGETS

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
load(":targets.bzl", "define_common_targets")
3+
4+
oncall("executorch")
5+
6+
define_common_targets()
Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>

#include <cinttypes>
#include <cstddef>
#include <cstdint>
#include <cstring>

#include <executorch/extension/flat_tensor/serialize/flat_tensor_header.h>
#include <executorch/extension/flat_tensor/serialize/schema_generated.h>

#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
#include <executorch/runtime/core/freeable_buffer.h>
#include <executorch/runtime/core/result.h>
#include <executorch/runtime/core/span.h>
#include <executorch/runtime/platform/compiler.h>
20+
21+
using executorch::runtime::Error;
22+
using executorch::runtime::FreeableBuffer;
23+
using executorch::runtime::Result;
24+
using executorch::runtime::Span;
25+
26+
using executorch::aten::ScalarType;
27+
using executorch::runtime::DataLoader;
28+
using executorch::runtime::TensorLayout;
29+
30+
namespace executorch {
31+
namespace extension {
32+
33+
namespace {
34+
/**
35+
* FlatTensor data must be aligned to this value to properly parse it. Must be a
36+
* power of 2. Note that max_align_t is the alignment that malloc() and new
37+
* guarantee.
38+
*/
39+
constexpr size_t kMinimumAlignment = alignof(std::max_align_t);
40+
41+
bool is_aligned(const void* data) {
42+
uintptr_t addr = reinterpret_cast<uintptr_t>(data);
43+
return addr % kMinimumAlignment == 0;
44+
}
45+
46+
/**
 * Finds the metadata entry whose fully qualified name equals `key`.
 *
 * @param[in] key NUL-terminated tensor name to look up. Must not be null.
 * @param[in] tensors Serialized tensor metadata entries. Must not be null;
 *     FlatTensorDataMap::load() rejects files that have no such vector.
 *
 * @returns The matching entry on success. Error::NotFound if no entry carries
 *     that name. Error::InvalidExternalData if the matching entry refers to a
 *     segment other than segment 0, which is the only one supported today.
 */
Result<const flat_tensor_flatbuffer::TensorMetadata*> get_flat_tensor_metadata(
    const char* key,
    const flatbuffers::Vector<
        flatbuffers::Offset<flat_tensor_flatbuffer::TensorMetadata>>* tensors) {
  // Linear search by name. The index uses the vector's own unsigned size type
  // so the loop condition never mixes signed and unsigned operands.
  for (flatbuffers::uoffset_t i = 0; i < tensors->size(); ++i) {
    const flat_tensor_flatbuffer::TensorMetadata* candidate = tensors->Get(i);
    // A table field that was never written is returned as nullptr by the
    // generated accessor; an unnamed entry can never match, so skip it
    // instead of dereferencing it.
    const auto* name = candidate->fully_qualified_name();
    if (name == nullptr || std::strcmp(name->c_str(), key) != 0) {
      continue;
    }
    // TODO(T214294528): Support multiple segments in FlatTensor.
    if (candidate->segment_index() != 0) {
      return Error::InvalidExternalData;
    }
    return candidate;
  }
  return Error::NotFound;
}
63+
64+
/**
 * Builds a TensorLayout from one serialized tensor metadata entry.
 *
 * The returned layout is created from spans that point directly at the
 * serialized `sizes` and `dim_order` vectors of `tensor_metadata`.
 *
 * @param[in] tensor_metadata Entry to convert. Must not be null.
 *
 * @returns The layout produced by TensorLayout::create(), or
 *     Error::InvalidExternalData if `sizes` or `dim_order` is missing or the
 *     two vectors differ in length.
 */
Result<const TensorLayout> create_tensor_layout(
    const flat_tensor_flatbuffer::TensorMetadata* tensor_metadata) {
  const auto* sizes = tensor_metadata->sizes();
  const auto* dim_order = tensor_metadata->dim_order();
  // Both spans below are built with the same element count, so a missing
  // vector or a dim_order shorter than sizes would be read out of bounds.
  if (sizes == nullptr || dim_order == nullptr ||
      sizes->size() != dim_order->size()) {
    return Error::InvalidExternalData;
  }

  const ScalarType scalar_type =
      static_cast<ScalarType>(tensor_metadata->scalar_type());
  const size_t dim = sizes->size();
  return TensorLayout::create(
      Span<const int32_t>(sizes->data(), dim),
      Span<const uint8_t>(dim_order->data(), dim),
      scalar_type);
}
76+
77+
} // namespace
78+
79+
/**
 * Returns the layout of the tensor stored under `key`.
 *
 * @param[in] key NUL-terminated tensor name.
 * @returns The tensor's layout, or the error produced by the name lookup or
 *     by the layout construction.
 */
ET_NODISCARD Result<const TensorLayout> FlatTensorDataMap::get_metadata(
    const char* key) const {
  auto lookup = get_flat_tensor_metadata(key, flat_tensor_->tensors());
  if (lookup.ok()) {
    return create_tensor_layout(lookup.get());
  }
  // Propagate the lookup failure unchanged.
  return lookup.error();
}
88+
89+
/**
 * Returns a view of the bytes of the tensor stored under `key`.
 *
 * The returned FreeableBuffer does not own the memory: it points into the
 * segment held by this data map and is constructed with a null free function.
 *
 * @param[in] key NUL-terminated tensor name.
 *
 * @returns A buffer covering the tensor's bytes. Error::NotFound (from the
 *     name lookup) if no tensor has that name. Error::InvalidExternalData if
 *     the metadata is malformed or the tensor does not fit inside the loaded
 *     segment.
 */
ET_NODISCARD Result<FreeableBuffer> FlatTensorDataMap::get_data(
    const char* key) const {
  Result<const flat_tensor_flatbuffer::TensorMetadata*> metadata_res =
      get_flat_tensor_metadata(key, flat_tensor_->tensors());
  if (!metadata_res.ok()) {
    return metadata_res.error();
  }
  const auto metadata = metadata_res.get();
  // NOTE(review): these comparisons only have an effect if the schema declares
  // the fields as signed; kept so existing behavior is unchanged either way.
  if (metadata->segment_index() < 0 || metadata->offset() < 0) {
    // Invalid segment_index/offset; malformed PTD file.
    return Error::InvalidExternalData;
  }

  Result<const TensorLayout> tensor_layout_res = create_tensor_layout(metadata);
  if (!tensor_layout_res.ok()) {
    return tensor_layout_res.error();
  }

  // The offset and shape come from the file, so make sure the tensor lies
  // entirely inside the loaded segment before handing out a pointer. The
  // subtraction form cannot overflow.
  const uint64_t offset = static_cast<uint64_t>(metadata->offset());
  const uint64_t nbytes = tensor_layout_res.get().nbytes();
  const uint64_t available = data_ro_.size();
  if (offset > available || nbytes > available - offset) {
    ET_LOG(
        Error, "Tensor '%s' extends past the end of the loaded segment.", key);
    return Error::InvalidExternalData;
  }

  // This FreeableBuffer doesn't own the underlying data, and will not free it,
  // which is why the free function is a nullptr.
  // TODO(T214294528): Remove data_ro_ and instead load the data here, letting
  // FreeableBuffer own it.
  return FreeableBuffer(
      static_cast<const uint8_t*>(data_ro_.data()) +
          static_cast<size_t>(offset),
      tensor_layout_res.get().nbytes(),
      nullptr);
}
118+
119+
/**
 * Not implemented for this data map: every call fails with
 * Error::NotImplemented and none of the arguments are read.
 */
ET_NODISCARD Result<size_t> FlatTensorDataMap::load_data_into(
    ET_UNUSED const char* key,
    ET_UNUSED void* buffer,
    ET_UNUSED size_t size) const {
  return Result<size_t>(Error::NotImplemented);
}
125+
126+
/**
 * Returns how many tensor entries the serialized metadata contains.
 */
ET_NODISCARD Result<size_t> FlatTensorDataMap::get_num_keys() const {
  const auto* entries = flat_tensor_->tensors();
  return entries->size();
}
129+
130+
/**
 * Returns the name of the tensor entry at position `index`.
 *
 * The returned pointer refers to the string stored inside the serialized
 * flatbuffer held by this data map.
 *
 * @param[in] index Position of the entry; valid values are
 *     [0, get_num_keys()).
 *
 * @returns The entry's fully qualified name. Error::InvalidArgument if
 *     `index` is out of range. Error::InvalidExternalData if the entry has no
 *     name.
 */
ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
    size_t index) const {
  const auto* tensors = flat_tensor_->tensors();
  // `index` is unsigned, so only the upper bound can be violated.
  if (index >= tensors->size()) {
    return Error::InvalidArgument;
  }
  const auto* name = tensors->Get(static_cast<flatbuffers::uoffset_t>(index))
                         ->fully_qualified_name();
  // A name field that was never written is returned as nullptr by the
  // generated accessor; report it rather than dereferencing it.
  if (name == nullptr) {
    return Error::InvalidExternalData;
  }
  return name->c_str();
}
137+
138+
/**
 * Creates a FlatTensorDataMap from the file wrapped by `loader`.
 *
 * Steps: parse the FlatTensorHeader, load and verify the flatbuffer region,
 * then load the single supported data segment.
 *
 * @param[in] loader Loader wrapping the FlatTensor file. Must not be null.
 *
 * @returns The data map on success. Error::InvalidArgument if `loader` is null
 *     or the flatbuffer data is misaligned. Error::InvalidExternalData if the
 *     identifier, flatbuffer verification, tensor metadata or segment
 *     description is invalid. Otherwise, whatever error the header parser or
 *     the loader reported.
 */
/* static */ Result<FlatTensorDataMap> FlatTensorDataMap::load(
    DataLoader* loader) {
  ET_CHECK_OR_RETURN_ERROR(
      loader != nullptr, InvalidArgument, "DataLoader must not be null");

  // Load data map.
  size_t flatbuffer_offset = 0;
  size_t flatbuffer_size = 0;
  size_t segment_base_offset = 0;
  size_t segment_data_size = 0;
  {
    // Check header.
    Result<FreeableBuffer> header = loader->load(
        /*offset=*/0,
        FlatTensorHeader::kNumHeadBytes,
        DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
    if (!header.ok()) {
      return header.error();
    }
    Result<FlatTensorHeader> fh =
        FlatTensorHeader::Parse(header->data(), header->size());
    if (!fh.ok()) {
      if (fh.error() == Error::NotFound) {
        // No header, throw error.
        ET_LOG(Error, "No FlatTensorHeader found.");
      } else {
        // corruption, throw error.
        ET_LOG(Error, "Flat tensor header may be corrupt.");
      }
      return fh.error();
    }
    // The header has the data map size.
    flatbuffer_offset = fh->flatbuffer_offset;
    flatbuffer_size = fh->flatbuffer_size;
    segment_base_offset = fh->segment_base_offset;
    segment_data_size = fh->segment_data_size;
  }

  // Load flatbuffer data as a segment.
  Result<FreeableBuffer> flat_tensor_data = loader->load(
      /*offset=*/0,
      flatbuffer_offset + flatbuffer_size,
      DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
  if (!flat_tensor_data.ok()) {
    return flat_tensor_data.error();
  }

  // Make sure magic matches.
  if (!flat_tensor_flatbuffer::FlatTensorBufferHasIdentifier(
          flat_tensor_data->data())) {
    ET_LOG(
        Error,
        "FlatTensor identifier '%.4s' != expected '%.4s'",
        flatbuffers::GetBufferIdentifier(flat_tensor_data->data()),
        flat_tensor_flatbuffer::FlatTensorIdentifier());
    return Error::InvalidExternalData;
  }

  // The flatbuffer data must start at an aligned address to ensure internal
  // alignment of flatbuffer fields. Plain "%p" is used because the pointer
  // conversion may already print its own "0x" prefix.
  ET_CHECK_OR_RETURN_ERROR(
      is_aligned(flat_tensor_data->data()),
      InvalidArgument,
      "FlatTensor data %p must be aligned to %zu",
      flat_tensor_data->data(),
      kMinimumAlignment);

  // Validate flatbuffer data before walking any of its tables.
  flatbuffers::Verifier verifier(
      reinterpret_cast<const uint8_t*>(flat_tensor_data->data()),
      flat_tensor_data->size());
  bool ok = flat_tensor_flatbuffer::VerifyFlatTensorBuffer(verifier);
  ET_CHECK_OR_RETURN_ERROR(
      ok,
      InvalidExternalData,
      "Verification failed; data may be truncated or corrupt");

  // Get pointer to root of flatbuffer table.
  const flat_tensor_flatbuffer::FlatTensor* flat_tensor =
      flat_tensor_flatbuffer::GetFlatTensor(flat_tensor_data->data());

  // Get pointer to tensor metadata.
  const auto* s_tensor_metadata = flat_tensor->tensors();
  if (s_tensor_metadata == nullptr) {
    ET_LOG(Error, "FlatTensor has no tensor metadata.");
    return Error::InvalidExternalData;
  }

  // Load constant data.
  const auto* s_data_segment = flat_tensor->segments();
  if (s_data_segment == nullptr) {
    ET_LOG(Error, "FlatTensor has no segments.");
    return Error::InvalidExternalData;
  }

  // TODO(T214294528): Support multiple segments in FlatTensor.
  // Reject the file instead of only logging: with zero segments the Get(0)
  // below would read past the end of the vector, and tensor lookups already
  // refuse any segment index other than 0.
  if (s_data_segment->size() != 1) {
    ET_LOG(
        Error,
        "FlatTensor has %u segments, only 1 supported.",
        s_data_segment->size());
    return Error::InvalidExternalData;
  }

  // Keep the serialized values at 64 bits so large segments are not
  // truncated before they are validated.
  const uint64_t segment_size =
      static_cast<uint64_t>(s_data_segment->Get(0)->size());
  const uint64_t segment_offset =
      static_cast<uint64_t>(s_data_segment->Get(0)->offset());

  // First segment size should be <= the total segment data size.
  if (segment_size > segment_data_size) {
    ET_LOG(
        Error,
        "FlatTensor segment size %" PRIu64 " > segment data size %zu",
        segment_size,
        segment_data_size);
    return Error::InvalidExternalData;
  }
  // The absolute file offset must be representable as size_t.
  if (segment_offset > SIZE_MAX - segment_base_offset) {
    ET_LOG(
        Error,
        "FlatTensor segment offset %" PRIu64 " overflows with base offset %zu",
        segment_offset,
        segment_base_offset);
    return Error::InvalidExternalData;
  }

  Result<FreeableBuffer> data_ro = loader->load(
      /*offset=*/segment_base_offset + static_cast<size_t>(segment_offset),
      static_cast<size_t>(segment_size),
      DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
  if (!data_ro.ok()) {
    return data_ro.error();
  }

  return FlatTensorDataMap(
      std::move(flat_tensor_data.get()), flat_tensor, std::move(data_ro.get()));
}
255+
256+
} // namespace extension
257+
} // namespace executorch
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/runtime/core/named_data_map.h>
12+
13+
#include <executorch/runtime/core/data_loader.h>
14+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
15+
#include <executorch/runtime/core/result.h>
16+
#include <executorch/runtime/core/tensor_layout.h>
17+
#include <executorch/runtime/platform/compiler.h>
18+
19+
#include <utility>
20+
21+
// Forward declare flatbuffer types. This is a public header and must not
22+
// include the generated flatbuffer header.
23+
namespace flat_tensor_flatbuffer {
24+
struct FlatTensor;
25+
} // namespace flat_tensor_flatbuffer
26+
27+
namespace executorch {
28+
namespace extension {
29+
30+
/**
31+
* A NamedDataMap implementation for FlatTensor-serialized data.
32+
*/
33+
class FlatTensorDataMap final : public executorch::runtime::NamedDataMap {
34+
public:
35+
/**
36+
* Creates a new DataMap that wraps FlatTensor data.
37+
*
38+
* @param[in] loader The DataLoader that wraps the FlatTensor file.
39+
* Note: the loader must outlive the FlatTensorDataMap instance.
40+
*/
41+
static executorch::runtime::Result<FlatTensorDataMap> load(
42+
executorch::runtime::DataLoader* loader);
43+
44+
ET_NODISCARD
45+
executorch::runtime::Result<const executorch::runtime::TensorLayout>
46+
get_metadata(const char* key) const override;
47+
ET_NODISCARD
48+
executorch::runtime::Result<executorch::runtime::FreeableBuffer> get_data(
49+
const char* key) const override;
50+
ET_NODISCARD executorch::runtime::Result<size_t>
51+
load_data_into(const char* key, void* buffer, size_t size) const override;
52+
53+
ET_NODISCARD executorch::runtime::Result<size_t> get_num_keys()
54+
const override;
55+
ET_NODISCARD executorch::runtime::Result<const char*> get_key(
56+
size_t index) const override;
57+
58+
FlatTensorDataMap(FlatTensorDataMap&&) noexcept = default;
59+
60+
~FlatTensorDataMap() = default;
61+
62+
private:
63+
FlatTensorDataMap(
64+
executorch::runtime::FreeableBuffer&& flat_tensor_data,
65+
const flat_tensor_flatbuffer::FlatTensor* flat_tensor,
66+
executorch::runtime::FreeableBuffer&& data_ro)
67+
: flat_tensor_data_(std::move(flat_tensor_data)),
68+
flat_tensor_(flat_tensor),
69+
data_ro_(std::move(data_ro)) {}
70+
71+
// Not copyable or assignable.
72+
FlatTensorDataMap(const FlatTensorDataMap& rhs) = delete;
73+
FlatTensorDataMap& operator=(FlatTensorDataMap&& rhs) noexcept = delete;
74+
FlatTensorDataMap& operator=(const FlatTensorDataMap& rhs) = delete;
75+
76+
// Serialized flat_tensor flatbuffer data.
77+
executorch::runtime::FreeableBuffer flat_tensor_data_;
78+
79+
// Flatbuffer representation of the flat_tensor.
80+
const flat_tensor_flatbuffer::FlatTensor* flat_tensor_;
81+
82+
// Loaded read-only tensor data.
83+
executorch::runtime::FreeableBuffer data_ro_;
84+
};
85+
86+
} // namespace extension
87+
} // namespace executorch

extension/flat_tensor/targets.bzl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
def define_common_targets():
    """Declares the `flat_tensor_data_map` C++ library target."""

    # Targets the data map implementation depends on.
    data_map_deps = [
        "//executorch/extension/flat_tensor/serialize:generated_headers",
        "//executorch/extension/flat_tensor/serialize:flat_tensor_header",
        "//executorch/runtime/core:core",
        "//executorch/runtime/core:evalue",
        "//executorch/runtime/core:named_data_map",
        "//executorch/runtime/core/exec_aten:lib",
        "//executorch/runtime/core/exec_aten/util:tensor_util",
    ]

    runtime.cxx_library(
        name = "flat_tensor_data_map",
        srcs = ["flat_tensor_data_map.cpp"],
        exported_headers = ["flat_tensor_data_map.h"],
        deps = data_map_deps,
        visibility = ["//executorch/..."],
    )

extension/flat_tensor/test/TARGETS

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ load(":targets.bzl", "define_common_targets")
66

77
oncall("executorch")
88

9-
define_common_targets()
9+
define_common_targets(is_fbcode=True)
1010

1111
python_unittest(
1212
name = "serialize",

0 commit comments

Comments
 (0)