
Commit a29dad6

Add named data map merge

Pull Request resolved: #11578

Adds merge functionality to the named_data_map interface and to flat_tensor_data_map; pte_data_map does not implement merge.

ghstack-source-id: 289886457
@exported-using-ghexport
Differential Revision: [D76351013](https://our.internmc.facebook.com/intern/diff/D76351013/)

1 parent: c4c4763
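
For context, a minimal usage sketch of the new merge() API (not part of this commit): it assumes the FileDataLoader extension, and the file names and the "linear.weight" key below are placeholders; error handling is reduced to early returns.

#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>

using executorch::extension::FileDataLoader;
using executorch::extension::FlatTensorDataMap;
using executorch::runtime::Error;
using executorch::runtime::FreeableBuffer;
using executorch::runtime::Result;

Error merge_example() {
  // Load two separate .ptd files into their own data maps.
  Result<FileDataLoader> addmul_loader = FileDataLoader::from("addmul.ptd");
  Result<FileDataLoader> linear_loader = FileDataLoader::from("linear.ptd");
  if (!addmul_loader.ok() || !linear_loader.ok()) {
    return Error::AccessFailed;
  }
  Result<FlatTensorDataMap> addmul_map =
      FlatTensorDataMap::load(&addmul_loader.get());
  Result<FlatTensorDataMap> linear_map =
      FlatTensorDataMap::load(&linear_loader.get());
  if (!addmul_map.ok() || !linear_map.ok()) {
    return Error::InvalidExternalData;
  }

  // Merge linear_map into addmul_map. No deep copy is made: addmul_map keeps
  // a pointer to linear_map, so linear_map must outlive addmul_map.
  Error err = addmul_map.get().merge(&linear_map.get());
  if (err != Error::Ok) {
    return err; // InvalidArgument if a key exists in both maps.
  }

  // Keys from both .ptd files are now reachable through addmul_map.
  Result<FreeableBuffer> data = addmul_map.get().get_data("linear.weight");
  return data.ok() ? Error::Ok : data.error();
}

Note that merge() rejects duplicate keys with Error::InvalidArgument and stores only a pointer to the other map, so the merged-in map must stay alive for as long as the merged-into map is used.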

File tree: 8 files changed (+264, -99 lines)

backends/xnnpack/test/runtime/test_xnn_data_separation.cpp

Lines changed: 4 additions & 4 deletions
@@ -87,15 +87,15 @@ TEST_F(DataSeparationTest, TestExternalData) {
   // Check that accessing keys out of bounds fails.
   EXPECT_EQ(data_map->get_key(2).error(), Error::InvalidArgument);
 
-  // Linear.weight
+  // Linear.bias
   Result<FreeableBuffer> data0 = data_map->get_data(key0.get());
   EXPECT_EQ(data0.error(), Error::Ok);
-  EXPECT_EQ(data0.get().size(), 36); // 3*3*4 (3*3 matrix, 4 bytes per float)
+  EXPECT_EQ(data0.get().size(), 12); // 3*4 (3 vector, 4 bytes per float)
 
-  // Linear.bias
+  // Linear.weight
   Result<FreeableBuffer> data1 = data_map->get_data(key1.get());
   EXPECT_EQ(data1.error(), Error::Ok);
-  EXPECT_EQ(data1.get().size(), 12); // 3*4 (3 vector, 4 bytes per float)
+  EXPECT_EQ(data1.get().size(), 36); // 3*3*4 (3*3 matrix, 4 bytes per float)
 
   // Check that accessing non-existent data fails.
   Result<FreeableBuffer> data2 = data_map->get_data("nonexistent");

extension/flat_tensor/flat_tensor_data_map.cpp

Lines changed: 132 additions & 60 deletions
@@ -19,15 +19,15 @@
 #include <executorch/runtime/core/span.h>
 #include <executorch/runtime/platform/compiler.h>
 
+using executorch::aten::ScalarType;
+using executorch::ET_RUNTIME_NAMESPACE::NamedDataMap;
+using executorch::ET_RUNTIME_NAMESPACE::TensorLayout;
+using executorch::runtime::DataLoader;
 using executorch::runtime::Error;
 using executorch::runtime::FreeableBuffer;
 using executorch::runtime::Result;
 using executorch::runtime::Span;
 
-using executorch::aten::ScalarType;
-using executorch::ET_RUNTIME_NAMESPACE::TensorLayout;
-using executorch::runtime::DataLoader;
-
 namespace executorch {
 namespace extension {
 
@@ -103,82 +103,111 @@ Result<const TensorLayout> create_tensor_layout(
 
 ET_NODISCARD Result<const TensorLayout> FlatTensorDataMap::get_tensor_layout(
     executorch::aten::string_view key) const {
-  Result<const flat_tensor_flatbuffer::NamedData*> named_data = get_named_data(
-      key,
-      flat_tensor_->named_data(),
-      flat_tensor_->segments(),
-      header_.segment_base_offset + header_.segment_data_size);
-  if (!named_data.ok()) {
+  if (key_to_map_index_.find(key.data()) == key_to_map_index_.end()) {
+    return Error::NotFound;
+  }
+  auto index = key_to_map_index_.at(key.data());
+  if (index == -1) {
+    Result<const flat_tensor_flatbuffer::NamedData*> named_data =
+        get_named_data(
+            key,
+            flat_tensor_->named_data(),
+            flat_tensor_->segments(),
+            header_.segment_base_offset + header_.segment_data_size);
+    if (named_data.ok()) {
+      return create_tensor_layout(named_data.get()->tensor_layout());
+    }
     return named_data.error();
+  } else {
+    return merged_maps_[index]->get_tensor_layout(key);
   }
-  return create_tensor_layout(named_data.get()->tensor_layout());
 }
 
 ET_NODISCARD Result<FreeableBuffer> FlatTensorDataMap::get_data(
     executorch::aten::string_view key) const {
-  Result<const flat_tensor_flatbuffer::NamedData*> named_data = get_named_data(
-      key,
-      flat_tensor_->named_data(),
-      flat_tensor_->segments(),
-      header_.segment_base_offset + header_.segment_data_size);
-  if (!named_data.ok()) {
+  if (key_to_map_index_.find(key.data()) == key_to_map_index_.end()) {
+    return Error::NotFound;
+  }
+  auto index = key_to_map_index_.at(key.data());
+  if (index == -1) {
+    Result<const flat_tensor_flatbuffer::NamedData*> named_data =
+        get_named_data(
+            key,
+            flat_tensor_->named_data(),
+            flat_tensor_->segments(),
+            header_.segment_base_offset + header_.segment_data_size);
+    if (named_data.ok()) {
+      uint32_t segment_index = named_data.get()->segment_index();
+      uint64_t segment_offset =
+          flat_tensor_->segments()->Get(segment_index)->offset();
+      uint64_t segment_size =
+          flat_tensor_->segments()->Get(segment_index)->size();
+
+      return loader_->load(
+          /*offset=*/header_.segment_base_offset + segment_offset,
+          segment_size,
+          DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
+    }
     return named_data.error();
+  } else {
+    return merged_maps_[index]->get_data(key);
   }
-
-  uint32_t segment_index = named_data.get()->segment_index();
-  uint64_t segment_offset =
-      flat_tensor_->segments()->Get(segment_index)->offset();
-  uint64_t segment_size = flat_tensor_->segments()->Get(segment_index)->size();
-
-  return loader_->load(
-      /*offset=*/header_.segment_base_offset + segment_offset,
-      segment_size,
-      DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
 }
 
 ET_NODISCARD Error FlatTensorDataMap::load_data_into(
     ET_UNUSED executorch::aten::string_view key,
     ET_UNUSED void* buffer,
     ET_UNUSED size_t size) const {
-  Result<const flat_tensor_flatbuffer::NamedData*> named_data = get_named_data(
-      key,
-      flat_tensor_->named_data(),
-      flat_tensor_->segments(),
-      header_.segment_base_offset + header_.segment_data_size);
-  if (!named_data.ok()) {
-    return named_data.error();
+  if (key_to_map_index_.find(key.data()) == key_to_map_index_.end()) {
+    return Error::NotFound;
   }
+  auto index = key_to_map_index_.at(key.data());
+  if (index == -1) {
+    Result<const flat_tensor_flatbuffer::NamedData*> named_data =
+        get_named_data(
+            key,
+            flat_tensor_->named_data(),
+            flat_tensor_->segments(),
+            header_.segment_base_offset + header_.segment_data_size);
+    if (!named_data.ok()) {
+      return named_data.error();
+    }
 
-  uint32_t segment_index = named_data.get()->segment_index();
-  uint64_t segment_offset =
-      flat_tensor_->segments()->Get(segment_index)->offset();
+    uint32_t segment_index = named_data.get()->segment_index();
+    uint64_t segment_offset =
+        flat_tensor_->segments()->Get(segment_index)->offset();
 
-  Result<const TensorLayout> tensor_layout =
-      create_tensor_layout(named_data.get()->tensor_layout());
+    Result<const TensorLayout> tensor_layout =
+        create_tensor_layout(named_data.get()->tensor_layout());
 
-  if (!tensor_layout.ok()) {
-    return tensor_layout.error();
-  }
+    if (!tensor_layout.ok()) {
+      return tensor_layout.error();
+    }
 
-  ET_CHECK_OR_RETURN_ERROR(
-      size <= tensor_layout.get().nbytes(),
-      InvalidArgument,
-      "Buffer size %zu is smaller than tensor size %zu",
-      size,
-      tensor_layout.get().nbytes());
-
-  // Load mutable data.
-  DataLoader::SegmentInfo info = DataLoader::SegmentInfo(
-      DataLoader::SegmentInfo::Type::Mutable, 0, nullptr);
-  return loader_->load_into(
-      header_.segment_base_offset + segment_offset,
-      tensor_layout.get().nbytes(),
-      info,
-      buffer);
+    ET_CHECK_OR_RETURN_ERROR(
+        size <= tensor_layout.get().nbytes(),
+        InvalidArgument,
+        "Buffer size %zu is smaller than tensor size %zu",
+        size,
+        tensor_layout.get().nbytes());
+
+    // Load mutable data.
+    DataLoader::SegmentInfo info = DataLoader::SegmentInfo(
+        DataLoader::SegmentInfo::Type::Mutable, 0, nullptr);
+    return loader_->load_into(
+        header_.segment_base_offset + segment_offset,
+        tensor_layout.get().nbytes(),
+        info,
+        buffer);
+  } else {
+    return merged_maps_[index]->load_data_into(key, buffer, size);
+  }
 }
 
 ET_NODISCARD Result<uint32_t> FlatTensorDataMap::get_num_keys() const {
-  return flat_tensor_->named_data()->size();
+  // Guaranteed safe, as the segment_index is a uint32_t, which means
+  // that there can't be more than uint32_t keys.
+  return static_cast<uint32_t>(key_to_map_index_.size());
 }
 
 ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
@@ -190,7 +219,40 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
       "Index %u out of range of size %u",
       index,
       num_keys);
-  return flat_tensor_->named_data()->Get(index)->key()->c_str();
+
+  uint32_t current_index = 0;
+  for (const auto& pair : key_to_map_index_) {
+    if (current_index == index) {
+      return pair.first.c_str();
+    }
+    current_index++;
+  }
+  return Error::NotFound;
+}
+
+ET_NODISCARD Error FlatTensorDataMap::merge(const NamedDataMap* other) {
+  ET_CHECK_OR_RETURN_ERROR(
+      other != nullptr, InvalidArgument, "Merge error: other is nullptr.");
+
+  // Check if any duplicate keys exist.
+  uint32_t num_keys = other->get_num_keys().get();
+
+  for (uint32_t i = 0; i < num_keys; i++) {
+    const char* key = other->get_key(i).get();
+    ET_CHECK_OR_RETURN_ERROR(
+        key_to_map_index_.find(key) == key_to_map_index_.end(),
+        InvalidArgument,
+        "Merge error: key %s already exists in the named_data_map.",
+        key);
+  }
+  // Place keys into the map.
+  for (uint32_t i = 0; i < num_keys; i++) {
+    const char* key = other->get_key(i).get();
+    key_to_map_index_[key] = static_cast<int64_t>(merged_maps_.size());
+  }
+
+  merged_maps_.push_back(other);
+  return Error::Ok;
 }
 
 /* static */ Result<FlatTensorDataMap> FlatTensorDataMap::load(
@@ -261,8 +323,18 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
       InvalidExternalData,
       "FlatTensor segments is nullptr, malformed PTD file.");
 
+  // Add keys to the map.
+  std::unordered_map<std::string, int64_t> key_to_map_index;
+  for (int i = 0; i < flat_tensor->named_data()->size(); i++) {
+    const auto* named_data = flat_tensor->named_data()->Get(i);
+    key_to_map_index[named_data->key()->c_str()] = -1;
+  }
   return FlatTensorDataMap(
-      fh.get(), std::move(flat_tensor_data.get()), flat_tensor, loader);
+      fh.get(),
+      std::move(flat_tensor_data.get()),
+      flat_tensor,
+      loader,
+      std::move(key_to_map_index));
 }
 
 } // namespace extension
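
The three accessors above share one dispatch pattern. The following condensed, self-contained sketch (illustrative only, with placeholder keys; not the committed code) shows the idea: a key either maps to the sentinel index -1 and is served from this map's own flat_tensor segments, or it maps to an index into merged_maps_ and the call is delegated.

#include <cstdint>
#include <cstdio>
#include <string>
#include <unordered_map>

int main() {
  // key -> map index; -1 is the sentinel for "owned by this map's own
  // flat_tensor segments", anything else indexes into merged_maps_.
  std::unordered_map<std::string, int64_t> key_to_map_index = {
      {"local_key", -1}, {"merged_key", 0}};

  for (const char* key : {"local_key", "merged_key", "missing_key"}) {
    auto it = key_to_map_index.find(key);
    if (it == key_to_map_index.end()) {
      std::printf("%s: Error::NotFound\n", key);
    } else if (it->second == -1) {
      std::printf("%s: resolve from the local flatbuffer segments\n", key);
    } else {
      std::printf("%s: delegate to merged_maps_[%lld]\n", key,
                  static_cast<long long>(it->second));
    }
  }
  return 0;
}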

extension/flat_tensor/flat_tensor_data_map.h

Lines changed: 23 additions & 2 deletions
@@ -94,6 +94,17 @@ class FlatTensorDataMap final
   ET_NODISCARD executorch::runtime::Result<const char*> get_key(
       uint32_t index) const override;
 
+  /**
+   * Merge a named_data_map into the current one.
+   * @param[in] other The named_data_map to merge.
+   * @return Error indicating if the merge was successful or not.
+   *
+   * Note: The FlatTensorDataMap does not perform a deep copy; it holds a
+   * reference to other, so other must outlive the FlatTensorDataMap instance.
+   */
+  ET_NODISCARD executorch::runtime::Error merge(
+      const NamedDataMap* other) override;
+
   FlatTensorDataMap(FlatTensorDataMap&&) noexcept = default;
 
   ~FlatTensorDataMap() override = default;
@@ -103,11 +114,14 @@ class FlatTensorDataMap final
       const FlatTensorHeader& header,
       executorch::runtime::FreeableBuffer&& flat_tensor_data,
       const flat_tensor_flatbuffer::FlatTensor* flat_tensor,
-      executorch::runtime::DataLoader* loader)
+      executorch::runtime::DataLoader* loader,
+      std::unordered_map<std::string, int64_t> key_to_map_index)
       : header_(header),
         flat_tensor_data_(std::move(flat_tensor_data)),
         flat_tensor_(flat_tensor),
-        loader_(loader) {}
+        loader_(loader),
+        key_to_map_index_(std::move(key_to_map_index)),
+        merged_maps_({}) {}
 
   // Not copyable or assignable.
   FlatTensorDataMap(const FlatTensorDataMap& rhs) = delete;
@@ -125,6 +139,13 @@ class FlatTensorDataMap final
 
   // Data loader, used to load segment data.
   executorch::runtime::DataLoader* loader_;
+
+  // Cache of keys to data map index.
+  // index=-1 is used for the flat_tensor data map.
+  std::unordered_map<std::string, int64_t> key_to_map_index_;
+
+  // Other NamedDataMaps.
+  std::vector<const NamedDataMap*> merged_maps_;
 };
 
 } // namespace extension
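
A minimal sketch of the lifetime rule from the doc comment above, using assumed caller code with placeholder paths and error checks omitted: because merge() stores a raw pointer to other, the merged-in map must be destroyed after the map it was merged into; declaring it earlier in the same scope guarantees that, since locals are destroyed in reverse declaration order.

#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>

using executorch::extension::FileDataLoader;
using executorch::extension::FlatTensorDataMap;

void lifetime_example() {
  // Declared first, destroyed last: "extra" outlives "primary".
  auto extra_loader = FileDataLoader::from("extra.ptd");
  auto extra_map = FlatTensorDataMap::load(&extra_loader.get());

  auto primary_loader = FileDataLoader::from("primary.ptd");
  auto primary_map = FlatTensorDataMap::load(&primary_loader.get());

  // primary_map now holds a pointer to extra_map for the rest of this scope.
  (void)primary_map.get().merge(&extra_map.get());
}  // primary_map is destroyed before extra_map, so the pointer never dangles.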

extension/flat_tensor/test/CMakeLists.txt

Lines changed: 6 additions & 4 deletions
@@ -21,21 +21,23 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake)
 add_custom_command(
   OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.pte"
          "${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.ptd"
+         "${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.pte"
+         "${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.ptd"
   COMMAND
-    ${PYTHON_EXECUTABLE} -m test.models.export_program --modules "ModuleAddMul"
+    ${PYTHON_EXECUTABLE} -m test.models.export_program --modules "ModuleAddMul,ModuleLinear"
     --external-constants --outdir "${CMAKE_CURRENT_BINARY_DIR}" 2> /dev/null
   WORKING_DIRECTORY ${EXECUTORCH_ROOT}
 )
 
 add_custom_target(
   extension_flat_tensor_test_resources
-  DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.pte"
-          "${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.ptd"
+  DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.ptd"
+          "${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.ptd"
 )
 
 set(test_env
-    "ET_MODULE_ADD_MUL_PROGRAM_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.pte"
     "ET_MODULE_ADD_MUL_DATA_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.ptd"
+    "ET_MODULE_LINEAR_DATA_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.ptd"
 )
 
 set(_test_srcs flat_tensor_data_map_test.cpp flat_tensor_header_test.cpp)
