Skip to content

Commit db611cf

Browse files
committed
[executorch][flat_tensor] DataMap implementation
Pull Request resolved: #7900 DataMap implementation that * Loads a flat_tensor file * Populates a map with {key: tensor} and {key: TensorLayout}. * Makes tensor information available via the named_data_map.h interface. For now, DataMap doesn't store the DataLoader. - If/when tensors are in their own segments, DataMap should also store a DataLoader. ghstack-source-id: 264110292 Differential Revision: [D67064580](https://our.internmc.facebook.com/intern/diff/D67064580/)
1 parent f9e1c1d commit db611cf

File tree

8 files changed

+530
-2
lines changed

8 files changed

+530
-2
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
load(":targets.bzl", "define_common_targets")
3+
4+
oncall("executorch")
5+
6+
define_common_targets()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/extension/flat_tensor/named_data_map/data_map.h>
10+
11+
#include <executorch/extension/flat_tensor/serialize/flat_tensor_header.h>
12+
#include <executorch/extension/flat_tensor/serialize/schema_generated.h>
13+
14+
#include <executorch/runtime/core/error.h>
15+
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
16+
#include <executorch/runtime/core/freeable_buffer.h>
17+
#include <executorch/runtime/core/result.h>
18+
#include <executorch/runtime/core/span.h>
19+
#include <executorch/runtime/platform/compiler.h>
20+
21+
using executorch::runtime::Error;
22+
using executorch::runtime::FreeableBuffer;
23+
using executorch::runtime::Result;
24+
using executorch::runtime::Span;
25+
26+
using executorch::aten::ScalarType;
27+
using executorch::runtime::DataLoader;
28+
using executorch::runtime::TensorLayout;
29+
30+
namespace executorch {
31+
namespace extension {
32+
33+
namespace {
34+
/**
35+
* FlatTensor data must be aligned to this value to properly parse it. Must be a
36+
* power of 2. Note that max_align_t is the alignment that malloc() and new
37+
* guarantee.
38+
*/
39+
constexpr size_t kMinimumAlignment = alignof(std::max_align_t);
40+
41+
bool IsAligned(const void* data) {
42+
uintptr_t addr = reinterpret_cast<uintptr_t>(data);
43+
return addr % kMinimumAlignment == 0;
44+
}
45+
} // namespace
46+
47+
ET_NODISCARD Result<const TensorLayout> DataMap::get_metadata(
48+
const char* key) const {
49+
auto tensor_metadata = _flat_tensor->tensors();
50+
// Linear search by name here.
51+
for (int i = 0; i < tensor_metadata->size(); i++) {
52+
if (std::strcmp(
53+
tensor_metadata->Get(i)->fully_qualified_name()->c_str(), key) ==
54+
0) {
55+
// create TensorLayout.
56+
ScalarType scalar_type =
57+
static_cast<ScalarType>(tensor_metadata->Get(i)->scalar_type());
58+
const int dim = tensor_metadata->Get(i)->sizes()->size();
59+
const auto serialized_sizes = tensor_metadata->Get(i)->sizes()->data();
60+
const auto serialized_dim_order =
61+
tensor_metadata->Get(i)->dim_order()->data();
62+
return TensorLayout::create(
63+
Span<const int32_t>(serialized_sizes, dim),
64+
Span<const uint8_t>(serialized_dim_order, dim),
65+
scalar_type);
66+
}
67+
}
68+
return Error::InvalidArgument;
69+
}
70+
71+
ET_NODISCARD Result<FreeableBuffer> DataMap::get_data(const char* key) const {
72+
auto tensor_metadata = _flat_tensor->tensors();
73+
// Linear search by name here.
74+
int segment_index = -1;
75+
int offset = -1;
76+
int nbytes = 0;
77+
for (int i = 0; i < tensor_metadata->size(); i++) {
78+
if (std::strcmp(
79+
tensor_metadata->Get(i)->fully_qualified_name()->c_str(), key) ==
80+
0) {
81+
// Load data.
82+
segment_index = tensor_metadata->Get(i)->segment_index();
83+
// Assert one segment, for now.
84+
assert(segment_index == 0);
85+
offset = tensor_metadata->Get(i)->offset();
86+
87+
// Find nbytes.
88+
ScalarType scalar_type =
89+
static_cast<ScalarType>(tensor_metadata->Get(i)->scalar_type());
90+
const int dim = tensor_metadata->Get(i)->sizes()->size();
91+
const auto serialized_sizes = tensor_metadata->Get(i)->sizes()->data();
92+
const auto serialized_dim_order =
93+
tensor_metadata->Get(i)->dim_order()->data();
94+
Result<const TensorLayout> tensor_layout = TensorLayout::create(
95+
Span<const int32_t>(serialized_sizes, dim),
96+
Span<const uint8_t>(serialized_dim_order, dim),
97+
scalar_type);
98+
nbytes = tensor_layout.get().nbytes();
99+
}
100+
}
101+
102+
if (segment_index == -1 || offset == -1) {
103+
// Key doesn't exist.
104+
return Error::InvalidArgument;
105+
}
106+
return FreeableBuffer(
107+
static_cast<const uint8_t*>(_data_ro.data()) + offset, nbytes, nullptr);
108+
}
109+
110+
ET_NODISCARD Result<size_t>
111+
DataMap::load_data_into(const char* key, void* buffer, size_t size) const {
112+
return Error::NotImplemented;
113+
}
114+
115+
ET_NODISCARD Result<size_t> DataMap::get_num_keys() const {
116+
return _flat_tensor->tensors()->size();
117+
}
118+
119+
ET_NODISCARD Result<const char*> DataMap::get_key(size_t index) const {
120+
if (index < 0 || index >= _flat_tensor->tensors()->size()) {
121+
return Error::InvalidArgument;
122+
}
123+
return _flat_tensor->tensors()->Get(index)->fully_qualified_name()->c_str();
124+
}
125+
126+
/* static */ Result<DataMap> DataMap::load(DataLoader* loader) {
127+
// Load data map.
128+
size_t flatbuffer_offset = 0;
129+
size_t flatbuffer_size = 0;
130+
size_t segment_base_offset = 0;
131+
size_t segment_data_size = 0;
132+
{
133+
// Check header.
134+
Result<FreeableBuffer> header = loader->load(
135+
/*offset=*/0,
136+
FlatTensorHeader::kNumHeadBytes,
137+
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
138+
if (!header.ok()) {
139+
return header.error();
140+
}
141+
Result<FlatTensorHeader> fh =
142+
FlatTensorHeader::Parse(header->data(), header->size());
143+
if (fh.ok()) {
144+
// The header has the data map size.
145+
flatbuffer_offset = fh->flatbuffer_offset;
146+
flatbuffer_size = fh->flatbuffer_size;
147+
segment_base_offset = fh->segment_base_offset;
148+
segment_data_size = fh->segment_data_size;
149+
} else if (fh.error() == Error::NotFound) {
150+
// No header, throw error.
151+
ET_LOG(Error, "No FlatTensorHeader found.");
152+
return fh.error();
153+
} else {
154+
// corruption, throw error.
155+
ET_LOG(Error, "Flat tensor header may be corrupt.");
156+
return fh.error();
157+
}
158+
}
159+
160+
ET_LOG(
161+
Info,
162+
"Flatbuffer offset %zu, size %zu, segment base offset %zu, segment size: %zu",
163+
flatbuffer_offset,
164+
flatbuffer_size,
165+
segment_base_offset,
166+
segment_data_size);
167+
168+
// Load flatbuffer data as a segment.
169+
Result<FreeableBuffer> flat_tensor_data = loader->load(
170+
/*offset=*/0,
171+
flatbuffer_offset + flatbuffer_size,
172+
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
173+
if (!flat_tensor_data.ok()) {
174+
return flat_tensor_data.error();
175+
}
176+
177+
// Make sure magic matches.
178+
if (!flat_tensor_flatbuffer::FlatTensorBufferHasIdentifier(
179+
flat_tensor_data->data())) {
180+
ET_LOG(
181+
Error,
182+
"FlatTensor identifier '%.4s' != expected '%.4s'",
183+
flatbuffers::GetBufferIdentifier(flat_tensor_data->data()),
184+
flat_tensor_flatbuffer::FlatTensorIdentifier());
185+
return Error::InvalidExternalData;
186+
}
187+
188+
// The flatbuffer data must start at an aligned address to ensure internal
189+
// alignment of flatbuffer fields.
190+
ET_CHECK_OR_RETURN_ERROR(
191+
IsAligned(flat_tensor_data->data()),
192+
InvalidArgument,
193+
"FlatTensor data 0x%p must be aligned to %zu",
194+
flat_tensor_data->data(),
195+
kMinimumAlignment);
196+
197+
// Get pointer to root of flatbuffer table.
198+
const flat_tensor_flatbuffer::FlatTensor* flat_tensor =
199+
flat_tensor_flatbuffer::GetFlatTensor(flat_tensor_data->data());
200+
201+
// Get pointer to tensor metadata.
202+
const auto* s_tensor_metadata = flat_tensor->tensors();
203+
assert(s_tensor_metadata != nullptr);
204+
205+
// Load constant data.
206+
const auto* s_data_segment = flat_tensor->segments();
207+
208+
// Only support one segment for now.
209+
assert(s_data_segment->size() == 1);
210+
// First segment offset should be 0.
211+
int segment_offset = s_data_segment->Get(0)->offset();
212+
assert(segment_offset == 0);
213+
// First segment size should be <= the total segment data size.
214+
int segment_size = s_data_segment->Get(0)->size();
215+
assert(segment_size <= segment_data_size);
216+
217+
Result<FreeableBuffer> _data_ro = loader->load(
218+
/*offset=*/segment_base_offset + segment_offset,
219+
segment_size,
220+
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
221+
if (!_data_ro.ok()) {
222+
return _data_ro.error();
223+
}
224+
225+
return DataMap(
226+
loader,
227+
segment_base_offset,
228+
std::move(flat_tensor_data.get()),
229+
flat_tensor,
230+
std::move(_data_ro.get()));
231+
}
232+
233+
DataMap::~DataMap() {
234+
_data_ro.Free();
235+
}
236+
237+
} // namespace extension
238+
} // namespace executorch
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/runtime/core/named_data_map.h>
12+
13+
#include <executorch/runtime/core/data_loader.h>
14+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
15+
#include <executorch/runtime/core/result.h>
16+
#include <executorch/runtime/core/tensor_layout.h>
17+
#include <executorch/runtime/platform/compiler.h>
18+
19+
#include <utility>
20+
21+
// Forward declare flatbuffer types. This is a public header and must not
22+
// include the generated flatbuffer header.
23+
namespace flat_tensor_flatbuffer {
24+
struct FlatTensor;
25+
} // namespace flat_tensor_flatbuffer
26+
27+
namespace executorch {
28+
namespace extension {
29+
30+
class DataMap final : public executorch::runtime::NamedDataMap {
31+
public:
32+
static executorch::runtime::Result<DataMap> load(
33+
executorch::runtime::DataLoader* loader);
34+
35+
ET_NODISCARD
36+
executorch::runtime::Result<const executorch::runtime::TensorLayout>
37+
get_metadata(const char* key) const override;
38+
ET_NODISCARD
39+
executorch::runtime::Result<executorch::runtime::FreeableBuffer> get_data(
40+
const char* key) const override;
41+
ET_NODISCARD executorch::runtime::Result<size_t>
42+
load_data_into(const char* key, void* buffer, size_t size) const override;
43+
44+
ET_NODISCARD executorch::runtime::Result<size_t> get_num_keys()
45+
const override;
46+
ET_NODISCARD executorch::runtime::Result<const char*> get_key(
47+
size_t index) const override;
48+
49+
DataMap(DataMap&&) noexcept = default;
50+
~DataMap() override;
51+
52+
private:
53+
DataMap(
54+
executorch::runtime::DataLoader* loader,
55+
size_t segment_base_offset,
56+
executorch::runtime::FreeableBuffer&& flat_tensor_data,
57+
const flat_tensor_flatbuffer::FlatTensor* flat_tensor,
58+
executorch::runtime::FreeableBuffer&& data_ro)
59+
: _loader(loader),
60+
_segment_base_offset(segment_base_offset),
61+
_flat_tensor_data(std::move(flat_tensor_data)),
62+
_flat_tensor(flat_tensor),
63+
_data_ro(std::move(data_ro)){};
64+
65+
// Not copyable or assignable.
66+
DataMap(const DataMap& rhs) = delete;
67+
DataMap& operator=(DataMap&& rhs) noexcept = delete;
68+
DataMap& operator=(const DataMap& rhs) = delete;
69+
70+
// Data loader used to load segment data.
71+
executorch::runtime::DataLoader* _loader;
72+
73+
// Segment base offset.
74+
size_t _segment_base_offset;
75+
76+
// Serialized flat_tensor data.
77+
executorch::runtime::FreeableBuffer _flat_tensor_data;
78+
79+
// Flatbuffer representation of the flat_tensor.
80+
const flat_tensor_flatbuffer::FlatTensor* _flat_tensor;
81+
82+
// Loaded read-only tensor data.
83+
executorch::runtime::FreeableBuffer _data_ro;
84+
};
85+
86+
} // namespace extension
87+
} // namespace executorch
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
def define_common_targets():
4+
runtime.cxx_library(
5+
name = "data_map",
6+
srcs = [
7+
"data_map.cpp",
8+
],
9+
exported_headers = ["data_map.h"],
10+
deps = [
11+
"//executorch/extension/flat_tensor/serialize:schema",
12+
"//executorch/extension/flat_tensor/serialize:serialize",
13+
"//executorch/extension/flat_tensor/serialize:generated_headers",
14+
"//executorch/extension/flat_tensor/serialize:flat_tensor_header",
15+
"//executorch/runtime/core:core",
16+
"//executorch/runtime/core:evalue",
17+
"//executorch/runtime/core:named_data_map",
18+
"//executorch/runtime/core/exec_aten:lib",
19+
"//executorch/runtime/core/exec_aten/util:tensor_util",
20+
],
21+
visibility = [
22+
"//executorch/...",
23+
],
24+
)

extension/flat_tensor/test/TARGETS

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ load(":targets.bzl", "define_common_targets")
66

77
oncall("executorch")
88

9-
define_common_targets()
9+
define_common_targets(is_fbcode=True)
1010

1111
python_unittest(
1212
name = "serialize",

0 commit comments

Comments
 (0)