diff --git a/backends/xnnpack/_passes/TARGETS b/backends/xnnpack/_passes/TARGETS index a199e1aab01..972980570ec 100644 --- a/backends/xnnpack/_passes/TARGETS +++ b/backends/xnnpack/_passes/TARGETS @@ -19,5 +19,6 @@ python_library( "//executorch/exir/passes:const_prop_pass", "//executorch/exir/passes:memory_format_ops_pass", "//executorch/exir/program:program", + "//executorch/backends/transforms:utils", ], ) diff --git a/backends/xnnpack/_passes/fuse_batch_norm_with_conv.py b/backends/xnnpack/_passes/fuse_batch_norm_with_conv.py index b0f4779eb4c..6f31fe698ba 100644 --- a/backends/xnnpack/_passes/fuse_batch_norm_with_conv.py +++ b/backends/xnnpack/_passes/fuse_batch_norm_with_conv.py @@ -7,13 +7,22 @@ import operator import torch +from executorch.backends.transforms.utils import ( + create_constant_placeholder, + delete_constant_placeholder, +) from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass -from executorch.backends.xnnpack.utils.utils import get_param_tensor, is_param_node +from executorch.backends.xnnpack.utils.utils import ( + get_param_tensor, + get_tensor_name, + is_param_node, +) from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import PassResult +from torch.export.graph_signature import InputKind from torch.nn.utils.fusion import fuse_conv_bn_weights @@ -28,7 +37,7 @@ class FuseBatchNormWithConvPass(XNNPACKPass): def call(self, graph_module: torch.fx.GraphModule): graph = graph_module.graph - counter = 0 + constant_placeholders_to_delete = set() for conv in graph.nodes: # We want to discover a chain of conv -> batch_norm. # Only proceed if the current node is a conv node, and has a single @@ -55,9 +64,11 @@ def call(self, graph_module: torch.fx.GraphModule): assert len(conv.args) == 9 conv_weight = get_param_tensor(self.exported_program, conv.args[1]) + conv_weight_name = get_tensor_name(self.exported_program, conv.args[1]) assert conv_weight is not None conv_bias = get_param_tensor(self.exported_program, conv.args[2]) + conv_bias_name = get_tensor_name(self.exported_program, conv.args[2]) # Get the parameters from the batchnorm op assert ( @@ -95,23 +106,43 @@ def call(self, graph_module: torch.fx.GraphModule): bn_bias, is_transpose, ) + fused_weight_name = (conv_weight_name + "_fused_bn").replace(".", "_") + if conv_bias_name == "": + fused_bias_name = (conv_weight_name + "_bias_fused_bn").replace( + ".", "_" + ) + else: + fused_bias_name = (conv_bias_name + "_fused_bn").replace(".", "_") # Modify the graph by updating the weight and bias of conv op # with the fused weight and bias params, and replacing all the users # of getitem(batchnorm) with the conv op. - with graph.inserting_before(conv): - fused_weight_name = f"_fused_with_bn_weight_{counter}" - graph_module.register_parameter(fused_weight_name, fused_weight) - fused_weight_node = graph.get_attr(fused_weight_name) - fused_bias_name = f"_fused_with_bn_bias_{counter}" - graph_module.register_parameter(fused_bias_name, fused_bias) - fused_bias_node = graph.get_attr(fused_bias_name) - - # Update the weight and bias of conv op - conv_args = list(conv.args) + ([None] if len(conv.args) == 2 else []) - conv_args[1] = fused_weight_node - conv_args[2] = fused_bias_node - conv.args = tuple(conv_args) + with graph.inserting_before(conv.args[1]): + fused_conv_weight_node = create_constant_placeholder( + exp_program=self.exported_program, + graph=graph_module.graph, + kind=InputKind.PARAMETER, + name=fused_weight_name, + data=fused_weight, + ) + if fused_bias is not None: + fused_conv_bias_node = create_constant_placeholder( + exp_program=self.exported_program, + graph=graph_module.graph, + kind=InputKind.PARAMETER, + name=fused_bias_name, + data=fused_bias, + ) + else: + fused_conv_bias_node = None + + conv.args = ( + conv.args[0], + fused_conv_weight_node, + fused_conv_bias_node, + *conv.args[3:], + ) + # Remove any use of batchnorm from the graph for user in bn.users.copy(): assert user.target == operator.getitem @@ -119,8 +150,13 @@ def call(self, graph_module: torch.fx.GraphModule): graph.erase_node(user) graph.erase_node(bn) + constant_placeholders_to_delete.update(conv.args[1:3] + bn.args[1:5]) - counter += 1 + if len(constant_placeholders_to_delete) > 0: + graph_module.graph.eliminate_dead_code() + for node in constant_placeholders_to_delete: + if (node is not None) and (len(node.users) == 0): + delete_constant_placeholder(self.exported_program, node) graph_module.recompile() # To Regenerate meta data and shape information, retrace module diff --git a/backends/xnnpack/operators/node_visitor.py b/backends/xnnpack/operators/node_visitor.py index 0a825a94bef..ec39d287346 100644 --- a/backends/xnnpack/operators/node_visitor.py +++ b/backends/xnnpack/operators/node_visitor.py @@ -34,11 +34,16 @@ check_or_raise, get_input_node, get_param_tensor, + get_tensor_name, is_param_node, PERM_NCHW_TO_NHWC, ) -from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_INVALID_VALUE_ID +from executorch.backends.xnnpack.utils.xnnpack_constants import ( + UINT64_MAX, + XNN_INVALID_VALUE_ID, +) +from executorch.exir._serialize._named_data_store import NamedDataStore from torch.export import ExportedProgram XNN_TYPE_MAP = { @@ -46,8 +51,6 @@ } from executorch.backends.xnnpack.serialization.xnnpack_graph_serialize import ( - _aligned_size, - _pad_to, CONSTANT_TENSOR_ALIGNMENT, ) @@ -86,11 +89,11 @@ def __init__( self, exported_program: ExportedProgram, external_ids: Dict, - constant_data_bytes: bytearray, + named_data_store: NamedDataStore, ) -> None: self._external_ids = external_ids or {} self._exported_program = exported_program or None - self._constant_data_bytes = constant_data_bytes + self._named_data_store = named_data_store @property def external_ids(self) -> Dict: @@ -579,11 +582,16 @@ def get_serialized_buffer_index( ctypes.POINTER(array_type), ).contents - offset = len(self._constant_data_bytes) + named_key = get_tensor_name(self.exported_program, get_attr_node) + if named_key == "": + raise ValueError(f"Tensor from node: {get_attr_node} has no name") + size = const_val.untyped_storage().nbytes() - xnn_graph.constant_data.append(ConstantDataOffset(offset=offset, size=size)) - self._constant_data_bytes.extend( - _pad_to(bytes(array), _aligned_size(size, CONSTANT_TENSOR_ALIGNMENT)) + xnn_graph.constant_data.append( + ConstantDataOffset(offset=UINT64_MAX, size=size, named_key=named_key) + ) + self._named_data_store.add_named_data( + named_key, bytes(array), alignment=CONSTANT_TENSOR_ALIGNMENT ) return buffer_idx diff --git a/backends/xnnpack/runtime/XNNWeightsCache.cpp b/backends/xnnpack/runtime/XNNWeightsCache.cpp new file mode 100644 index 00000000000..f2842851d3a --- /dev/null +++ b/backends/xnnpack/runtime/XNNWeightsCache.cpp @@ -0,0 +1,237 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace xnnpack { +namespace delegate { + +using executorch::runtime::MemoryAllocator; +using executorch::runtime::NamedDataMap; + +XNNWeightsCache::XNNWeightsCache() { + weights_cache_.context = this; + weights_cache_.look_up = (size_t(*)( + void*, const xnn_weights_cache_look_up_key*))XNNWeightsCache::look_up; + weights_cache_.reserve_space = + (void* (*)(void*, size_t))XNNWeightsCache::reserve_space; + weights_cache_.look_up_or_insert = + (size_t(*)(void*, const xnn_weights_cache_look_up_key*, void*, size_t)) + XNNWeightsCache::look_up_or_insert; + weights_cache_.is_finalized = (bool (*)(void*))XNNWeightsCache::is_finalized; + weights_cache_.offset_to_addr = + (void* (*)(void*, size_t))XNNWeightsCache::offset_to_addr; + weights_cache_.delete_cache = + (enum xnn_status(*)(void*))XNNWeightsCache::delete_cache; +} + +Error XNNWeightsCache::initialize_for_runtime( + MemoryAllocator* runtime_allocator, + const NamedDataMap* named_data_map) { + runtime_allocator_ = runtime_allocator; + named_data_map_ = named_data_map; + is_finalized_ = false; + + return Error::Ok; +} + +Result> XNNWeightsCache::finalize_for_runtime() { + is_finalized_ = true; + + // All data has been packed by create_runtime + // so we clear the unpacked data as it is no longer needed + for (FreeableBuffer& buffer : unpacked_data_) { + buffer.Free(); + } + unpacked_data_.clear(); + unpacked_data_to_name_.clear(); + + std::vector packed_data_names; + // update the reference count of all the packed data + // used by this runtime + for (auto& entry : name_to_packed_data_metadata_) { + if (entry.second.in_current_runtime) { + entry.second.ref_count++; + entry.second.in_current_runtime = false; + packed_data_names.push_back(entry.first); + } + } + + return packed_data_names; +} + +Result XNNWeightsCache::load_unpacked_data( + const std::string& name) { + Result named_data = named_data_map_->get_data(name.c_str()); + if (!named_data.ok()) { + ET_LOG(Error, "Failed to load constant data for key %s", name.c_str()); + return Error::InvalidExternalData; + } + const uint8_t* data_pointer = + static_cast(named_data.get().data()); + unpacked_data_.push_back(std::move(named_data.get())); + unpacked_data_to_name_[data_pointer] = name; + + return data_pointer; +} + +Error XNNWeightsCache::delete_packed_data( + const std::vector& packed_data_names) { + if (!is_finalized_) { + ET_LOG( + Error, + "Error, attempted to delete packed data from the cache but the cache is not finalized"); + return Error::InvalidArgument; + } + for (const std::string& name : packed_data_names) { + auto entry = name_to_packed_data_metadata_.find(name); + if (entry == name_to_packed_data_metadata_.end()) { + ET_LOG( + Error, + "Error, attempted to deleted packed data: %s, from the cache but it wasn't found", + name.c_str()); + return Error::InvalidArgument; + } else { + entry->second.ref_count--; + if (entry->second.ref_count == 0) { + void* packed_data_ptr = packed_data_ptrs_[entry->second.offset]; + // Erase the key/value from the map frees the pointer holding the packed + // data + packed_pointer_to_container_.erase(packed_data_ptr); + // remove the pointer from the packed_data_ptrs_ + packed_data_ptrs_[entry->second.offset] = nullptr; + // Erase the name to packed metadata entry + name_to_packed_data_metadata_.erase(entry->first); + } + } + } + + return Error::Ok; +} + +size_t XNNWeightsCache::look_up( + XNNWeightsCache* context, + const xnn_weights_cache_look_up_key* cache_key) { + const void* unpacked_weights_ptr = cache_key->kernel; + const void* unpacked_bias_ptr = cache_key->bias; + auto entry = context->unpacked_data_to_name_.find(unpacked_weights_ptr); + + // Check if weight_pointer has been cached + if (entry == context->unpacked_data_to_name_.end()) { + return SIZE_MAX; + } + + std::string weight_bias_name = entry->second; + + // Check if bias_pointer has been cached + if (unpacked_bias_ptr != nullptr) { + auto bias_entry = context->unpacked_data_to_name_.find(unpacked_bias_ptr); + if (bias_entry != context->unpacked_data_to_name_.end()) { + weight_bias_name.append(bias_entry->second); + } + } + + // check if weight_bias_name has been packed already + auto packed_weight_entry = + context->name_to_packed_data_metadata_.find(weight_bias_name); + if (packed_weight_entry == context->name_to_packed_data_metadata_.end()) { + return SIZE_MAX; + } + packed_weight_entry->second.in_current_runtime = true; + + return packed_weight_entry->second.offset; +} + +void* XNNWeightsCache::reserve_space(XNNWeightsCache* context, size_t n) { + // MemoryAllocator* allocator = context->runtime_allocator_; + // void* reserved_pointer = allocator->allocate(n, + // context->kPackedAllocationAlignment); + + // return reserved_pointer; + std::string data_container; + data_container.resize(n + context->kPackedAllocationAlignment); + void* maybe_aligned_space = data_container.data(); + void* aligned_space = (void*)((intptr_t)maybe_aligned_space + 64 - + (intptr_t)maybe_aligned_space % 64); + + context->packed_pointer_to_container_[aligned_space] = + std::move(data_container); + return aligned_space; +} + +size_t XNNWeightsCache::look_up_or_insert( + XNNWeightsCache* context, + const xnn_weights_cache_look_up_key* cache_key, + void* ptr, + size_t size) { + size_t offset = context->look_up(context, cache_key); + + if (offset != SIZE_MAX) { + void* saved_ptr = context->offset_to_addr(context, offset); + if (0 == memcmp(ptr, saved_ptr, size)) { + return offset; + } + // Failure, cache is out of date + return SIZE_MAX; + } + + // Add to Cache if it is not finalized + size_t next_offset = context->packed_data_ptrs_.size(); + auto entry = context->unpacked_data_to_name_.find(cache_key->kernel); + + // Check if weight_pointer has been cached + if (entry != context->unpacked_data_to_name_.end()) { + std::string weight_bias_name = entry->second; + if (cache_key->bias != nullptr) { + auto bias_entry = context->unpacked_data_to_name_.find(cache_key->bias); + if (bias_entry != context->unpacked_data_to_name_.end()) { + weight_bias_name.append(bias_entry->second); + } + } + PackedDataMeta packed_data_metadata = { + .offset = next_offset, + .ref_count = + 0, // ref_count is only incremented after finalizing for runtime + .in_current_runtime = true}; + context->name_to_packed_data_metadata_[weight_bias_name] = + packed_data_metadata; + } else { + ET_LOG( + Info, + "Warning: Unpacked weight and bias were not registered with names, " + "this will add new cache entries for packed data and may affect performance."); + } + context->packed_data_ptrs_.push_back(ptr); + + return next_offset; +} + +bool XNNWeightsCache::is_finalized(XNNWeightsCache* context) { + return context->is_finalized_; +} + +void* XNNWeightsCache::offset_to_addr(XNNWeightsCache* context, size_t offset) { + return context->packed_data_ptrs_[offset]; +} + +enum xnn_status XNNWeightsCache::delete_cache(XNNWeightsCache* context) { + return xnn_status_success; +} + +} // namespace delegate +} // namespace xnnpack +} // namespace backends +} // namespace executorch diff --git a/backends/xnnpack/runtime/XNNWeightsCache.h b/backends/xnnpack/runtime/XNNWeightsCache.h new file mode 100644 index 00000000000..bc00ac15fd0 --- /dev/null +++ b/backends/xnnpack/runtime/XNNWeightsCache.h @@ -0,0 +1,164 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace xnnpack { +namespace delegate { + +using executorch::runtime::Error; +using executorch::runtime::FreeableBuffer; +using executorch::runtime::MemoryAllocator; +using executorch::runtime::NamedDataMap; +using executorch::runtime::Result; + +struct PackedDataMeta { + size_t offset; + // Count number of xnn_runtime_t this packed data is used in + size_t ref_count; + // true if this packed data was inserted or looked up for the + // current runtime being created + bool in_current_runtime; +}; + +class XNNWeightsCache { + public: + XNNWeightsCache(); + + /** + * Initializes the XNNWeightsCache for the next xnn_create_runtime + */ + Error initialize_for_runtime( + MemoryAllocator* runtime_allocator, + const NamedDataMap* named_data_map); + + /** + * Finalizes the weights cache after the weights have been packed + * in xnn_create_runtime. + * + * This should only be called after creating the runtime. Returns + * the name of all the packed weights used by this runtime + */ + Result> finalize_for_runtime(); + + // Taken from XNN_ALLOCATION_ALIGNMENT in xnnpack/common.h + static const size_t kPackedAllocationAlignment = 64; + + /** + * Returns XNNPACK's underlying weights_cache pointer + */ + inline xnn_weights_cache_t get() { + return (xnn_weights_cache_t)&weights_cache_; + } + + /** + * Returns the number of unpacked data + */ + inline size_t get_num_unpacked_data() { + return unpacked_data_.size(); + }; + + /** + * Returns the names of all unpacked data + */ + inline std::vector get_unpacked_data_names() { + std::vector names; + for (const auto& pair : unpacked_data_to_name_) { + names.push_back(pair.second); + } + return names; + }; + + /** + * Returns the packed data names + */ + inline std::vector get_packed_data_names() { + std::vector names; + for (const auto& pair : name_to_packed_data_metadata_) { + names.push_back(pair.first); + } + return names; + }; + + /** + * Loads unpacked named data from the NamedDataMap into this XNNWeightsCache + * and returns a pointer to the unpacked data. This unpacked data is given + * to XNNPACK's define_tensor APIs, and used as the cache key for + * look_up_or_insert. + * @param[in] name The name of the data to load + * @param[out] out the pointer to the unpacked data that was loaded + */ + Result load_unpacked_data(const std::string& name); + + /** + * Deletes the packed data associated with the names given. + * Decrements the ref_count if the packed data is used by other + * models + * + */ + Error delete_packed_data(const std::vector& packed_names); + + private: + // Runtime Allocator used to reserve memory for packed weights + MemoryAllocator* runtime_allocator_; + + // Named Data Map used to load named data + const NamedDataMap* named_data_map_; + + // Map of unpacked pointers to the data name + std::unordered_map unpacked_data_to_name_; + // Map of data names to offset into the packed data + std::unordered_map name_to_packed_data_metadata_; + // Vector holding list of pointers to the packed data + std::vector packed_data_ptrs_; + // vector holding list of strings which are containers for packed_data_ptrs + std::unordered_map packed_pointer_to_container_; + // Vector hodling list of unpacked freeable buffers + std::vector unpacked_data_; + // xnnpack's weight cache provider + xnn_weights_cache_provider weights_cache_; + // whether or not the weight cache is finalized + bool is_finalized_; + + // Function pointers to override XNNPACK's default xnn_weights_cache_provider + // functions. + static size_t look_up( + XNNWeightsCache* context, + const xnn_weights_cache_look_up_key* cache_key); + + static void* reserve_space(XNNWeightsCache* context, size_t n); + + static size_t look_up_or_insert( + XNNWeightsCache* context, + const xnn_weights_cache_look_up_key* cache_key, + void* ptr, + size_t size); + + static bool is_finalized(XNNWeightsCache* context); + + static void* offset_to_addr(XNNWeightsCache* context, size_t offset); + + static enum xnn_status delete_cache(XNNWeightsCache* context); +}; + +} // namespace delegate +} // namespace xnnpack +} // namespace backends +} // namespace executorch diff --git a/backends/xnnpack/serialization/schema.fbs b/backends/xnnpack/serialization/schema.fbs index 5a43481b98d..193656c30b1 100644 --- a/backends/xnnpack/serialization/schema.fbs +++ b/backends/xnnpack/serialization/schema.fbs @@ -316,11 +316,20 @@ table XNNLeakyReLU { table ConstantDataOffset { // Constant data offsets are relative to the constant data base offset provided // in the XNNPACKHeader. + // named_key and offset are mutually exclusive, meaning only one of these values + // are valid. If the named key is a non-empty string, then the offset must be UINT64_MAX. + // If the offset is not UINT64_MAX, then the named key must be an empty string offset: uint64; // The size in bytes of valid data starting at the offset. The constant data // may be followed by padding before the next piece of constant data size: uint64; + + // unique string id used to query the offset from the named data store. + // named_key and offset are mutually exclusive, meaning only one of these values + // are valid. If the named key is a non-empty string, then the offset must be UINT64_MAX. + // If the offset is not UINT64_MAX, then the named key must be an empty string + named_key: string; } table XNNGraph { diff --git a/backends/xnnpack/serialization/xnnpack_graph_schema.py b/backends/xnnpack/serialization/xnnpack_graph_schema.py index 3276dac7869..3cb572c66ef 100644 --- a/backends/xnnpack/serialization/xnnpack_graph_schema.py +++ b/backends/xnnpack/serialization/xnnpack_graph_schema.py @@ -470,6 +470,7 @@ class XValue: class ConstantDataOffset: offset: int size: int + named_key: str = "" @dataclass diff --git a/backends/xnnpack/test/runtime/test_xnn_weights_cache.cpp b/backends/xnnpack/test/runtime/test_xnn_weights_cache.cpp new file mode 100644 index 00000000000..ca149a67b5e --- /dev/null +++ b/backends/xnnpack/test/runtime/test_xnn_weights_cache.cpp @@ -0,0 +1,286 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +using executorch::backends::xnnpack::delegate::XNNWeightsCache; +using executorch::extension::FileDataLoader; +using executorch::extension::testing::TempFile; +using executorch::runtime::DataLoader; +using executorch::runtime::Error; +using executorch::runtime::FreeableBuffer; +using executorch::runtime::MemoryAllocator; +using executorch::runtime::Result; +using executorch::runtime::internal::PteDataMap; + +class XNNWeightsCacheTest : public ::testing::Test { + protected: + void SetUp() override { + // Creating a NamedDataMap from scratch is a little bit convoluted, so + // we copied a lot of setup from test_pte_data_map.cpp + + // Since these tests cause ET_LOG to be called, the PAL must be initialized + // first. + executorch::runtime::runtime_init(); + + // Create a sample Program with only named_data and segments. Technically + // not a valid Program; only used to test the PteDataMap. + // Create named data. + std::array, 2> + named_data_arr = { + executorch_flatbuffer::CreateNamedDataDirect( + builder_, "weight", /*segment_index=*/0), + executorch_flatbuffer::CreateNamedDataDirect( + builder_, "bias", /*segment_index=*/1), + }; + const auto named_data = + builder_.CreateVector(named_data_arr.data(), named_data_arr.size()); + + // Create segments. + std::array, 2> + segment_arr = {// @lint-ignore CLANGTIDY facebook-hte-BadArgumentComment + executorch_flatbuffer::CreateDataSegment( + builder_, /*offset=*/0, /*size=*/kSegmentSizes[0]), + // @lint-ignore CLANGTIDY facebook-hte-BadArgumentComment + executorch_flatbuffer::CreateDataSegment( + builder_, + /*offset=*/kSegmentAlignment * 2, + /*size=*/kSegmentSizes[1])}; + const auto segments = + builder_.CreateVector(segment_arr.data(), segment_arr.size()); + + // Create Program. + const auto program = executorch_flatbuffer::CreateProgram( + builder_, 0, 0, 0, 0, segments, 0, 0, named_data); + + builder_.Finish(program); + program_ = executorch_flatbuffer::GetProgram(builder_.GetBufferPointer()); + + // Create sample segment data. + for (int i = 0; i < kSegmentSizes[0]; i++) { + sample_data_[i] = 1; + } + for (int i = kSegmentOffsets[1]; i < kSegmentOffsets[1] + kSegmentSizes[1]; + i++) { + sample_data_[i] = 2; + } + TempFile tf(sample_data_.data(), sizeof(sample_data_)); + + // Wrap the sample data in a loader. + Result loader = + FileDataLoader::from(tf.path().c_str(), kSegmentAlignment); + ASSERT_EQ(loader.error(), Error::Ok); + data_map_loader_ = + std::make_unique(std::move(loader.get())); + + Result data_map = PteDataMap::create( + data_map_loader_.get(), + 0, + program_->named_data(), + program_->segments()); + ASSERT_EQ(data_map.error(), Error::Ok); + data_map_ = std::make_unique(std::move(data_map.get())); + + memory_allocator_ = std::make_unique( + memory_allocator_data_.size(), memory_allocator_data_.data()); + + xnn_status status = xnn_initialize(nullptr); + ASSERT_EQ(status, xnn_status_success); + } + + void BuildAndRunGraphWithWeightsCache( + XNNWeightsCache& weight_cache, + const std::vector& batches, + size_t input_channels, + size_t output_channels, + float* input_data, + float* output_data) { + // Defining subgraph + xnn_subgraph_t subgraph_ptr = nullptr; + xnn_status status = xnn_create_subgraph( + /*external_value_ids=*/2, + /*flags=*/0, + &subgraph_ptr); + ASSERT_EQ(status, xnn_status_success); + std::unique_ptr subgraph( + subgraph_ptr, &xnn_delete_subgraph); + + // Define tensors + // Define input + uint32_t input_id; + std::vector input_dims(batches); + input_dims.push_back(input_channels); + status = xnn_define_tensor_value( + subgraph_ptr, + xnn_datatype_fp32, + input_dims.size(), + input_dims.data(), + nullptr, + 0, + XNN_VALUE_FLAG_EXTERNAL_INPUT, + &input_id); + + // Define weight + uint32_t weight_id; + Result weight_pointer = + weight_cache.load_unpacked_data("weight"); + ASSERT_TRUE(weight_pointer.ok()); + ASSERT_TRUE(weight_pointer.get() != nullptr); + std::vector weight_dims{output_channels, input_channels}; + status = xnn_define_tensor_value( + subgraph_ptr, + xnn_datatype_fp32, + weight_dims.size(), + weight_dims.data(), + weight_pointer.get(), + XNN_INVALID_VALUE_ID, + 0, + &weight_id); + ASSERT_EQ(status, xnn_status_success); + + // Define bias + uint32_t bias_id; + Result bias_pointer = + weight_cache.load_unpacked_data("bias"); + ASSERT_TRUE(bias_pointer.ok()); + std::vector bias_dims{output_channels}; + status = xnn_define_tensor_value( + subgraph_ptr, + xnn_datatype_fp32, + bias_dims.size(), + bias_dims.data(), + bias_pointer.get(), + XNN_INVALID_VALUE_ID, + 0, + &bias_id); + + // Define output tensor + uint32_t output_id; + std::vector output_dims(batches); + output_dims.push_back(output_channels); + status = xnn_define_tensor_value( + subgraph_ptr, + xnn_datatype_fp32, + output_dims.size(), + output_dims.data(), + nullptr, + 1, + XNN_VALUE_FLAG_EXTERNAL_OUTPUT, + &output_id); + + // create xecond fully connected + status = xnn_define_fully_connected( + subgraph_ptr, + -std::numeric_limits::infinity(), + std::numeric_limits::infinity(), + input_id, + weight_id, + bias_id, + output_id, + 0); + // Create and Pack Weights + xnn_runtime_t runtime_ptr = nullptr; + status = xnn_create_runtime_v3( + subgraph_ptr, weight_cache.get(), nullptr, 0, &runtime_ptr); + Result> packed_weights_added = + weight_cache.finalize_for_runtime(); + ASSERT_TRUE(packed_weights_added.ok()); + ASSERT_EQ(packed_weights_added.get().size(), 1); + ASSERT_EQ(packed_weights_added.get()[0], "weightbias"); + + auto runtime = std::unique_ptr( + runtime_ptr, xnn_delete_runtime); + + const std::array external = { + xnn_external_value{0, input_data}, + xnn_external_value{1, output_data}, + }; + + status = xnn_reshape_runtime(runtime.get()); + status = + xnn_setup_runtime_v2(runtime.get(), external.size(), external.data()); + + ASSERT_EQ(status, xnn_status_success); + status = xnn_invoke_runtime(runtime.get()); + ASSERT_EQ(status, xnn_status_success); + } + + // Program builder constants. + static constexpr int kSegmentAlignment = 16; + static constexpr std::array kSegmentSizes{384, 128}; + static constexpr std::array kSegmentOffsets{0, kSegmentAlignment * 2}; + std::array sample_data_; + + // Program builder. + flatbuffers::FlatBufferBuilder builder_; + const executorch_flatbuffer::Program* program_; + + // Data loader for the sample data. + std::unique_ptr data_map_loader_; + + // PteDataMap + std::unique_ptr data_map_; + + // MemoryAllocator + std::array memory_allocator_data_; + std::unique_ptr memory_allocator_; +}; + +TEST_F(XNNWeightsCacheTest, ReusePackedWeights) { + XNNWeightsCache weight_cache; + size_t padding = 32; + + std::vector batches{1, 2, 3}; + size_t num_batches = 1; + for (size_t batch_dim : batches) { + num_batches *= batch_dim; + } + size_t input_channels = 3; + size_t output_channels = 4; + std::vector input_tensor(num_batches * input_channels + padding, 1.0f); + std::vector output_tensor(num_batches * output_channels, 0.0f); + float* input_data = input_tensor.data(); + float* output_data = output_tensor.data(); + weight_cache.initialize_for_runtime(memory_allocator_.get(), data_map_.get()); + BuildAndRunGraphWithWeightsCache( + weight_cache, + batches, + input_channels, + output_channels, + input_data, + output_data); + + weight_cache.initialize_for_runtime(memory_allocator_.get(), data_map_.get()); + BuildAndRunGraphWithWeightsCache( + weight_cache, + batches, + input_channels, + output_channels, + input_data, + output_data); + ASSERT_EQ(weight_cache.get_num_unpacked_data(), 0); + weight_cache.delete_packed_data(weight_cache.get_packed_data_names()); + std::vector packed_data_names = + weight_cache.get_packed_data_names(); + // Packed Data Still exists because it has a ref count of 2 + ASSERT_EQ(packed_data_names.size(), 1); + weight_cache.delete_packed_data(weight_cache.get_packed_data_names()); + packed_data_names = weight_cache.get_packed_data_names(); + ASSERT_EQ(packed_data_names.size(), 0); +} diff --git a/backends/xnnpack/test/targets.bzl b/backends/xnnpack/test/targets.bzl index 30ce970a842..58589b70607 100644 --- a/backends/xnnpack/test/targets.bzl +++ b/backends/xnnpack/test/targets.bzl @@ -30,3 +30,16 @@ def define_common_targets(): "//executorch/backends/xnnpack:xnnpack_backend", ], ) + + runtime.cxx_test( + name = "test_xnn_weights_cache", + srcs = ["runtime/test_xnn_weights_cache.cpp"], + deps = [ + third_party_dep("XNNPACK"), + "//executorch/backends/xnnpack:xnnpack_backend", + "//executorch/runtime/executor:pte_data_map", + "//executorch/extension/data_loader:file_data_loader", + "//executorch/extension/testing_util:temp_file", + "//executorch/schema:program", + ], + ) diff --git a/backends/xnnpack/utils/gen_xnnpack_constants.sh b/backends/xnnpack/utils/gen_xnnpack_constants.sh index 6be9d4519f3..5fa92e5b038 100644 --- a/backends/xnnpack/utils/gen_xnnpack_constants.sh +++ b/backends/xnnpack/utils/gen_xnnpack_constants.sh @@ -26,5 +26,6 @@ } > xnnpack_constants.py echo UINT32_MAX = 4294967295 >> xnnpack_constants.py +echo UINT64_MAX = 18446744073709551615 >> xnnpack_constants.py awk '/^#define\s+XNN_/ { print $2,"=",$3} ' "$1"/include/xnnpack.h >> xnnpack_constants.py if ! grep -qc "^XNN_" xnnpack_constants.py; then false; fi diff --git a/backends/xnnpack/utils/utils.py b/backends/xnnpack/utils/utils.py index b802d73c16b..fab95618807 100644 --- a/backends/xnnpack/utils/utils.py +++ b/backends/xnnpack/utils/utils.py @@ -131,6 +131,22 @@ def get_param_tensor( raise RuntimeError(f"unsupported param type, {node.op}.") +def get_tensor_name(exp_prog: ExportedProgram, node: torch.fx.Node) -> str: + if node is None: + return "" + if is_param(exp_prog, node): + return exp_prog.graph_signature.inputs_to_parameters[node.name] + elif is_buffer(exp_prog, node): + return exp_prog.graph_signature.inputs_to_buffers[node.name] + elif is_lifted_tensor_constant(exp_prog, node): + return exp_prog.graph_signature.inputs_to_lifted_tensor_constants[node.name] + else: + assert isinstance(node.target, str) + return node.target + + return "" + + def get_source_fn(node: torch.fx.Node) -> Optional[torch.fx.Node]: """ Returns the source fn of the given node, return None if something goes wrong diff --git a/backends/xnnpack/utils/xnnpack_constants.py b/backends/xnnpack/utils/xnnpack_constants.py index 351cc8ad897..364819a2435 100644 --- a/backends/xnnpack/utils/xnnpack_constants.py +++ b/backends/xnnpack/utils/xnnpack_constants.py @@ -6,8 +6,11 @@ # Auto-generated by gen_xnnpack_constants.sh script. Do not modify UINT32_MAX = 4294967295 +UINT64_MAX = 18446744073709551615 +XNN_EXTRA_BYTES = 128 XNN_EXTRA_BYTES = 16 XNN_MAX_TENSOR_DIMS = 6 +XNN_INVALID_VALUE_ID = UINT32_MAX XNN_FLAG_HINT_SPARSE_INFERENCE = 0x00000001 XNN_FLAG_HINT_FP16_INFERENCE = 0x00000002 XNN_FLAG_FORCE_FP16_INFERENCE = 0x00000004 @@ -26,7 +29,8 @@ XNN_FLAG_YIELD_WORKERS = 0x00000010 XNN_FLAG_TRANSIENT_INDIRECTION_BUFFER = 0x00000020 XNN_FLAG_KEEP_DIMS = 0x00000040 -XNN_EXTRA_QUANTIZATION_PARAMS = 8 +XNN_EXTRA_QUANTIZATION_PARAMS = 10 +XNN_MIN_BLOCKSIZE = 32 XNN_VALUE_FLAG_EXTERNAL_INPUT = 0x00000001 XNN_VALUE_FLAG_EXTERNAL_OUTPUT = 0x00000002 XNN_VALUE_FLAG_PERSISTENT = 0x00000004 diff --git a/backends/xnnpack/xnnpack_preprocess.py b/backends/xnnpack/xnnpack_preprocess.py index 4548de4940a..84cdfd69a48 100644 --- a/backends/xnnpack/xnnpack_preprocess.py +++ b/backends/xnnpack/xnnpack_preprocess.py @@ -31,6 +31,7 @@ XNN_VALUE_FLAG_EXTERNAL_INPUT, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, ) +from executorch.exir._serialize._named_data_store import NamedDataStore from executorch.exir.backend.backend_details import ( BackendDetails, @@ -103,7 +104,7 @@ def preprocess( edge_program: ExportedProgram, compile_specs: List[CompileSpec], ) -> PreprocessResult: - + named_data_store = NamedDataStore() xnnpack_edge_compile_config = get_xnnpack_edge_compile_config() # Need to wrap EP here because xnnpack does addmm to linear @@ -162,7 +163,7 @@ def preprocess( ) constant_data_bytes = bytearray() - node_visitors = get_node_visitors(ep, node_to_external_map, constant_data_bytes) + node_visitors = get_node_visitors(ep, node_to_external_map, named_data_store) for node in graph_module.graph.nodes: if node.op == "call_function": @@ -191,4 +192,5 @@ def preprocess( xnnpack_graph, constant_data_bytes ), debug_handle_map={}, + data_store_output=named_data_store.get_named_data_store_output(), ) diff --git a/extension/testing_util/targets.bzl b/extension/testing_util/targets.bzl index 2b12480dfff..95b1f94d182 100644 --- a/extension/testing_util/targets.bzl +++ b/extension/testing_util/targets.bzl @@ -17,5 +17,6 @@ def define_common_targets(): "//executorch/extension/fb/ptez/decompression_methods/test/...", "//executorch/extension/fb/ptez/test/...", "//executorch/runtime/executor/test/...", + "//executorch/backends/xnnpack/test/...", ], ) diff --git a/schema/targets.bzl b/schema/targets.bzl index 40c6d8d5c8d..c0036c7500a 100644 --- a/schema/targets.bzl +++ b/schema/targets.bzl @@ -78,6 +78,10 @@ def define_common_targets(): # //executorch/runtime/executor/... "//executorch/codegen/tools/...", "//executorch/runtime/executor/...", + # Tests have a set up which uses raw flatbuffer. + # TODO will refactor these setup steps into + # testing utils in runtime/executor/... path + "//executorch/backends/xnnpack/test/...", ], exported_headers = { OUTPUT_PROGRAM_HEADER: ":{}[{}]".format(PROGRAM_GEN_RULE_NAME, OUTPUT_PROGRAM_HEADER),