Skip to content

Commit 31b3a87

Browse files
committed
[XNNPACK][Weights Cache] Enable in XNNPACK
Pull Request resolved: #9155 We enable the XNNPACK weights cache in XNNPACK. The weights cache is initialized for the runtime with the named data map and a memory allocator (for now the memory allocator is not used, but I hope in the future this can be used to manage the memory for packed weights). Before creating the runtime, we first initialize the weights cache; this sets the finalization state to false. As we add weight/bias tensors to the graph, we load them through the named data map in the weights cache, and keep a map of the pointer to the name. When XNNPACK creates the runtime and packs the weights, it uses the weights_cache method look_up_or_insert. We use the pointers provided in the cache key to look up their names and append them together like ("weightsbias"). We then insert the packed weights with that key. In future lookups, we just use the pointer cached at the named pack tensor key, saving us from packing in the future. After creating the runtime and packing the weights, we finalize the cache. This sets is_finalized to true. We also free all unpacked buffers loaded from the named data map, as they are no longer needed. We also keep reference counts for all the packed weights, incrementing the counts of the packed weights which were used by this runtime. We return a vector of all the packed weight names to the xnn_executor runner. When the XNNExecutor is destroyed, we decrement the counts of the packed buffers and destroy them if necessary. Note that this feature is gated behind the ENABLE_XNNPACK_WEIGHTS_CACHE flag. Since the weights_cache is a global member of the singleton xnnpack backend class, and it is also read/write, we add a mutex to ensure that access to the weights_cache is thread safe. We added a new mutex, so the mutex hierarchy is: workspace_mutex_ -> weights_cache_mutex_ ghstack-source-id: 271090604 @exported-using-ghexport Differential Revision: [D70885926](https://our.internmc.facebook.com/intern/diff/D70885926/)
1 parent f93bc1e commit 31b3a87

File tree

6 files changed

+87
-30
lines changed

6 files changed

+87
-30
lines changed

backends/xnnpack/runtime/XNNCompiler.cpp

+31-19
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,7 @@ const uint8_t* getConstantDataPtr(
166166
const fb_xnnpack::XNNTensorValue* tensor_value,
167167
GraphPtr flatbuffer_graph,
168168
const uint8_t* constant_data_ptr,
169-
const NamedDataMap* named_data_map,
170-
std::vector<FreeableBuffer>& loaded_buffers_from_map) {
169+
XNNWeightsCache* weights_cache) {
171170
auto buffer_idx = tensor_value->constant_buffer_idx();
172171
if (buffer_idx) {
173172
if (!constant_data_ptr) {
@@ -185,14 +184,12 @@ const uint8_t* getConstantDataPtr(
185184
return constant_data_ptr + offset;
186185
} else {
187186
const std::string &data_name = constant_data_offset->named_key()->str();
188-
Result<FreeableBuffer> buffer = named_data_map->get_data(data_name.c_str());
189-
if (!buffer.ok()) {
190-
ET_LOG(Error, "Failed to get constant data for key %s", data_name.c_str());
187+
Result<const uint8_t*> data_ptr = weights_cache->load_unpacked_data(data_name);
188+
if (!data_ptr.ok()){
189+
ET_LOG(Error, "Failed to load weights from cache");
191190
return nullptr;
192191
}
193-
const uint8_t* data_ptr = static_cast<const uint8_t*>(buffer.get().data());
194-
loaded_buffers_from_map.push_back(std::move(buffer.get()));
195-
return data_ptr;
192+
return data_ptr.get();
196193
}
197194
}
198195
}
@@ -214,8 +211,7 @@ Error defineTensor(
214211
std::vector<uint32_t>& input_ids,
215212
std::vector<uint32_t>& output_ids,
216213
CompileAllocator& allocator,
217-
const NamedDataMap* named_data_map,
218-
std::vector<FreeableBuffer>& loaded_buffers_from_map) {
214+
XNNWeightsCache* weights_cache) {
219215
const fb_xnnpack::XNNTensorValue* tensor_value = nullptr;
220216
const fb_xnnpack::XNNQuantizedTensorValue* qtensor_value = nullptr;
221217

@@ -256,8 +252,7 @@ Error defineTensor(
256252
tensor_value,
257253
flatbuffer_graph,
258254
constant_data_ptr,
259-
named_data_map,
260-
loaded_buffers_from_map
255+
weights_cache
261256
);
262257

263258
xnn_status status;
@@ -1993,8 +1988,7 @@ ET_NODISCARD Error XNNCompiler::compileModel(
19931988
const void* buffer_pointer,
19941989
size_t num_bytes,
19951990
XNNExecutor* executor,
1996-
MemoryAllocator* runtime_allocator,
1997-
const NamedDataMap* named_data_map,
1991+
XNNWeightsCache* weights_cache,
19981992
xnn_workspace_t workspace) {
19991993
Result<XNNHeader> header = XNNHeader::Parse(buffer_pointer, num_bytes);
20001994
const uint8_t* flatbuffer_data = nullptr;
@@ -2074,8 +2068,7 @@ ET_NODISCARD Error XNNCompiler::compileModel(
20742068
input_ids,
20752069
output_ids,
20762070
compile_allocator,
2077-
named_data_map,
2078-
loaded_buffers_from_map);
2071+
weights_cache);
20792072

20802073
if (err != Error::Ok) {
20812074
return err;
@@ -2097,20 +2090,30 @@ ET_NODISCARD Error XNNCompiler::compileModel(
20972090

20982091
xnn_runtime_t runtime_ptr = nullptr;
20992092

2093+
// XNNWeightsCache if weights cache is not enabled, then XNNWeightsCache
2094+
// just manages the unpacked weights until the runtime is created.
2095+
#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
2096+
xnn_weights_cache_t weights_cache_ptr =
2097+
weights_cache->get_num_unpacked_data() > 0 ? weights_cache->get() : nullptr;
2098+
#else
2099+
xnn_weights_cache_t weights_cache_ptr = nullptr;
2100+
#endif
2101+
2102+
21002103
#ifdef ENABLE_XNNPACK_SHARED_WORKSPACE
21012104
ET_CHECK_OR_RETURN_ERROR(
21022105
workspace != nullptr, Internal, "Failed to initialize XNNPACK workspace");
21032106
status = xnn_create_runtime_v4(
21042107
subgraph.get(),
2105-
/*weight_cache=*/nullptr, // TODO - support weight cache
2108+
weights_cache_ptr,
21062109
workspace,
21072110
::executorch::extension::threadpool::get_pthreadpool(),
21082111
runtime_flags,
21092112
&runtime_ptr);
21102113
#else
21112114
status = xnn_create_runtime_v3(
21122115
subgraph.get(),
2113-
/*weight_cache=*/nullptr, // TODO - support weight cache
2116+
weights_cache_ptr,
21142117
::executorch::extension::threadpool::get_pthreadpool(),
21152118
runtime_flags,
21162119
&runtime_ptr);
@@ -2122,10 +2125,19 @@ ET_NODISCARD Error XNNCompiler::compileModel(
21222125
"XNN Runtime creation failed with code: %s",
21232126
xnn_status_to_string(status));
21242127

2128+
auto packed_weights_names = weights_cache->finalize_for_runtime();
2129+
ET_CHECK_OR_RETURN_ERROR(
2130+
packed_weights_names.ok(),
2131+
Internal,
2132+
"Failed to finalize weights cache after creating the xnn runtime"
2133+
)
2134+
2135+
21252136
err = executor->initialize( // NOLINT: runtime_ptr is non-null
21262137
runtime_ptr,
21272138
std::move(input_ids),
2128-
std::move(output_ids));
2139+
std::move(output_ids),
2140+
std::move(packed_weights_names.get()));
21292141

21302142
return err;
21312143
};

backends/xnnpack/runtime/XNNCompiler.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#pragma once
1010

1111
#include <executorch/backends/xnnpack/runtime/XNNExecutor.h>
12+
#include <executorch/backends/xnnpack/runtime/XNNWeightsCache.h>
1213
#include <executorch/runtime/platform/compiler.h>
1314

1415
#include <xnnpack.h>
@@ -29,8 +30,7 @@ class XNNCompiler {
2930
const void* buffer_pointer,
3031
size_t num_bytes,
3132
XNNExecutor* executor,
32-
executorch::runtime::MemoryAllocator* runtime_allocator,
33-
const executorch::runtime::NamedDataMap* named_data_map,
33+
XNNWeightsCache* weights_cache,
3434
xnn_workspace_t workspace);
3535
};
3636

backends/xnnpack/runtime/XNNExecutor.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ using executorch::runtime::kTensorDimensionLimit;
3030
ET_NODISCARD Error XNNExecutor::initialize(
3131
xnn_runtime_t runtime,
3232
std::vector<uint32_t>&& input_ids,
33-
std::vector<uint32_t>&& output_ids) {
33+
std::vector<uint32_t>&& output_ids,
34+
std::vector<std::string>&& packed_data_names) {
3435
runtime_ = std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)>(
3536
runtime, xnn_delete_runtime);
3637

@@ -51,6 +52,7 @@ ET_NODISCARD Error XNNExecutor::initialize(
5152
std::sort(output_ids_.begin(), output_ids_.end());
5253

5354
externals_.resize(input_ids_.size() + output_ids_.size());
55+
packed_data_names_ = std::move(packed_data_names);
5456

5557
return Error::Ok;
5658
}

backends/xnnpack/runtime/XNNExecutor.h

+7-1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class XNNExecutor {
3434
std::vector<uint32_t> input_ids_;
3535
std::vector<uint32_t> output_ids_;
3636
std::vector<xnn_external_value> externals_;
37+
std::vector<std::string> packed_data_names_;
3738

3839
public:
3940
XNNExecutor() = default;
@@ -46,6 +47,10 @@ class XNNExecutor {
4647
return output_ids_.size();
4748
}
4849

50+
inline std::vector<std::string> get_packed_data_names(){
51+
return packed_data_names_;
52+
}
53+
4954
/**
5055
* Initialize the XNNExecutor with a given runtime and input/output ids.
5156
* The input/output ids are expected to be sorted in order of their
@@ -54,7 +59,8 @@ class XNNExecutor {
5459
ET_NODISCARD executorch::runtime::Error initialize(
5560
xnn_runtime_t runtime,
5661
std::vector<uint32_t>&& input_ids,
57-
std::vector<uint32_t>&& output_ids);
62+
std::vector<uint32_t>&& output_ids,
63+
std::vector<std::string>&& packed_data_names);
5864

5965
/**
6066
* Prepares the arguments for runtime graph execution.

backends/xnnpack/runtime/XNNPACKBackend.cpp

+37-4
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <executorch/runtime/core/error.h>
1212
#include <executorch/runtime/core/evalue.h>
1313
#include <executorch/runtime/executor/pte_data_map.h>
14+
#include <executorch/backends/xnnpack/runtime/XNNWeightsCache.h>
1415

1516
#include <memory>
1617
#include <mutex>
@@ -31,6 +32,7 @@ using executorch::runtime::EValue;
3132
using executorch::runtime::FreeableBuffer;
3233
using executorch::runtime::Result;
3334
using executorch::runtime::NamedDataMap;
35+
using executorch::backends::xnnpack::delegate::XNNWeightsCache;
3436

3537
class XnnpackBackend final : public ::executorch::runtime::BackendInterface {
3638
public:
@@ -81,13 +83,23 @@ class XnnpackBackend final : public ::executorch::runtime::BackendInterface {
8183
}
8284

8385
const NamedDataMap* named_data_map = context.get_named_data_map();
86+
weights_cache_->initialize_for_runtime(
87+
context.get_runtime_allocator(),
88+
named_data_map
89+
);
8490

85-
#ifdef ENABLE_XNNPACK_SHARED_WORKSPACE
8691
// This is needed to serialize access to xnn_create_runtime which is not
8792
// thread safe. This can heppen when multiple threads call init() on
8893
// the same backend instance.
94+
#ifdef ENABLE_XNNPACK_SHARED_WORKSPACE
8995
const std::lock_guard<std::mutex> lock(workspace_mutex_);
9096
#endif
97+
98+
#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
99+
const std::lock_guard<std::mutex> lock(weights_cache_mutex_);
100+
#endif
101+
102+
91103
// Executor has been allocated but not constructed, ensure that runtime_ is
92104
// nullptr by constructing it in place here. NOTE: Since we use placement
93105
// new and since this type is not trivially destructible, we must call the
@@ -97,8 +109,7 @@ class XnnpackBackend final : public ::executorch::runtime::BackendInterface {
97109
processed->data(),
98110
processed->size(),
99111
executor,
100-
context.get_runtime_allocator(),
101-
named_data_map,
112+
weights_cache_.get(),
102113
workspace_.get());
103114
// This backend does not need its processed data after compiling the model.
104115
processed->Free();
@@ -125,6 +136,10 @@ class XnnpackBackend final : public ::executorch::runtime::BackendInterface {
125136
const std::lock_guard<std::mutex> lock(workspace_mutex_);
126137
#endif
127138

139+
#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
140+
const std::lock_guard<std::mutex> lock(weights_cache_mutex_);
141+
#endif
142+
128143
// Prepare Inputs/Outputs and Propagate Input Shapes
129144
Error err = executor->prepare_args(args);
130145
if (err != Error::Ok) {
@@ -145,16 +160,24 @@ class XnnpackBackend final : public ::executorch::runtime::BackendInterface {
145160

146161
void destroy(DelegateHandle* handle) const override {
147162
if (handle != nullptr) {
148-
#ifdef ENABLE_XNNPACK_SHARED_WORKSPACE
149163
// This is needed to serialize access to xnn_delete_runtime which is not
150164
// thread safe. This can happen when multiple threads call destroy() on
151165
// the same backend instance.
166+
#ifdef ENABLE_XNNPACK_SHARED_WORKSPACE
152167
const std::lock_guard<std::mutex> lock(workspace_mutex_);
153168
#endif
169+
170+
154171
auto executor = static_cast<xnnpack::delegate::XNNExecutor*>(handle);
172+
155173
#ifdef ENABLE_XNNPACK_PROFILING
156174
executor->print_avg_op_timings();
157175
#endif
176+
177+
#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
178+
const std::lock_guard<std::mutex> lock(weights_cache_mutex_);
179+
weights_cache_->delete_packed_data(executor->get_packed_data_names());
180+
#endif
158181
// XNNExecutor is not trivially destructible. Since this was constructed
159182
// manually in init(), we must destroy it manually here.
160183
executor->~XNNExecutor();
@@ -167,6 +190,16 @@ class XnnpackBackend final : public ::executorch::runtime::BackendInterface {
167190
std::unique_ptr<xnn_workspace, decltype(&xnn_release_workspace)> workspace_{
168191
nullptr,
169192
&xnn_release_workspace};
193+
194+
// Weights cache is global to all delegate instances.
195+
mutable std::mutex weights_cache_mutex_;
196+
std::unique_ptr<XNNWeightsCache> weights_cache_ =
197+
std::make_unique<XNNWeightsCache>();
198+
199+
200+
// Lock Hierarchy for Mutexes:
201+
// workspace_mutex_
202+
// weights_cache_mutex_
170203
};
171204

172205
namespace {

backends/xnnpack/targets.bzl

+7-3
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,15 @@ def _get_preprocessor_flags():
66
Disable if someone explicitly specified a config option,
77
else Enable otherwise
88
"""
9-
if native.read_config("executorch", "xnnpack_workspace_sharing", "0") == "0":
10-
return []
9+
preprocessor_flags = []
10+
if native.read_config("executorch", "xnnpack_workspace_sharing", "0") != "0":
11+
preprocessor_flags.append("-DENABLE_XNNPACK_SHARED_WORKSPACE")
12+
13+
if native.read_config("executorch", "xnnpack_weights_cache", "0") != "0":
14+
preprocessor_flags.append("-DENABLE_XNNPACK_WEIGHTS_CACHE")
1115

1216
# Enable if not disabled through config
13-
return ["-DENABLE_XNNPACK_SHARED_WORKSPACE"]
17+
return preprocessor_flags
1418

1519
def define_common_targets():
1620
runtime.cxx_library(

0 commit comments

Comments
 (0)