From e8db57da7e30ad6acc7f31f965cc645a633bf702 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Tue, 11 Mar 2025 12:54:24 -0700 Subject: [PATCH] [ExecuTorch][Weight Sharing][XNNPACK] Serialize constant tensors into named data map We serialize tensors into the named data map, and return the output in preprocess result. Allowing for XNNPACK to share tensors with the same name (instead of duplicating). A key change here is with fused tensors. For BN and Convolution Fusion, we fuse the conv weights and bias with the BN parameters creating new tensors. We then create get_attr nodes for these new parameters. Due to the graph.fx interpreter in export pass base, the new names we create for these new tensors are lost each time. As a result, at the end we introduce a new pass to preserve the names we created. This seems a little hacky for now, but is the only way to preserve the new fused names. Differential Revision: [D70315207](https://our.internmc.facebook.com/intern/diff/D70315207/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D70315207/)! [ghstack-poisoned] --- backends/xnnpack/_passes/TARGETS | 1 + .../_passes/fuse_batch_norm_with_conv.py | 64 ++++++++++++++----- backends/xnnpack/operators/node_visitor.py | 20 ++++-- backends/xnnpack/serialization/schema.fbs | 9 +++ .../serialization/xnnpack_graph_schema.py | 1 + .../xnnpack/utils/gen_xnnpack_constants.sh | 1 + backends/xnnpack/utils/utils.py | 18 ++++++ backends/xnnpack/utils/xnnpack_constants.py | 6 +- backends/xnnpack/xnnpack_preprocess.py | 6 +- 9 files changed, 100 insertions(+), 26 deletions(-) diff --git a/backends/xnnpack/_passes/TARGETS b/backends/xnnpack/_passes/TARGETS index a199e1aab01..972980570ec 100644 --- a/backends/xnnpack/_passes/TARGETS +++ b/backends/xnnpack/_passes/TARGETS @@ -19,5 +19,6 @@ python_library( "//executorch/exir/passes:const_prop_pass", "//executorch/exir/passes:memory_format_ops_pass", "//executorch/exir/program:program", + "//executorch/backends/transforms:utils", ], ) diff --git a/backends/xnnpack/_passes/fuse_batch_norm_with_conv.py b/backends/xnnpack/_passes/fuse_batch_norm_with_conv.py index b0f4779eb4c..f973510deb2 100644 --- a/backends/xnnpack/_passes/fuse_batch_norm_with_conv.py +++ b/backends/xnnpack/_passes/fuse_batch_norm_with_conv.py @@ -9,8 +9,13 @@ import torch from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass +from executorch.backends.transforms.utils import ( + create_constant_placeholder, + delete_constant_placeholder, +) +from torch.export.graph_signature import InputKind -from executorch.backends.xnnpack.utils.utils import get_param_tensor, is_param_node +from executorch.backends.xnnpack.utils.utils import get_param_tensor, is_param_node, get_tensor_name from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import PassResult @@ -28,7 +33,7 @@ class FuseBatchNormWithConvPass(XNNPACKPass): def call(self, graph_module: torch.fx.GraphModule): graph = graph_module.graph - counter = 0 + constant_placeholders_to_delete = set() for conv in graph.nodes: # We want to discover a chain of conv -> batch_norm. # Only proceed if the current node is a conv node, and has a single @@ -55,9 +60,11 @@ def call(self, graph_module: torch.fx.GraphModule): assert len(conv.args) == 9 conv_weight = get_param_tensor(self.exported_program, conv.args[1]) + conv_weight_name = get_tensor_name(self.exported_program, conv.args[1]) assert conv_weight is not None conv_bias = get_param_tensor(self.exported_program, conv.args[2]) + conv_bias_name = get_tensor_name(self.exported_program, conv.args[2]) # Get the parameters from the batchnorm op assert ( @@ -95,23 +102,39 @@ def call(self, graph_module: torch.fx.GraphModule): bn_bias, is_transpose, ) + fused_weight_name = (conv_weight_name + "_fused_bn").replace(".", "_") + fused_bias_name = (conv_bias_name + "_fused_bn").replace(".", "_") # Modify the graph by updating the weight and bias of conv op # with the fused weight and bias params, and replacing all the users # of getitem(batchnorm) with the conv op. - with graph.inserting_before(conv): - fused_weight_name = f"_fused_with_bn_weight_{counter}" - graph_module.register_parameter(fused_weight_name, fused_weight) - fused_weight_node = graph.get_attr(fused_weight_name) - fused_bias_name = f"_fused_with_bn_bias_{counter}" - graph_module.register_parameter(fused_bias_name, fused_bias) - fused_bias_node = graph.get_attr(fused_bias_name) - - # Update the weight and bias of conv op - conv_args = list(conv.args) + ([None] if len(conv.args) == 2 else []) - conv_args[1] = fused_weight_node - conv_args[2] = fused_bias_node - conv.args = tuple(conv_args) + with graph.inserting_before(conv.args[1]): + fused_conv_weight_node = create_constant_placeholder( + exp_program=self.exported_program, + graph=graph_module.graph, + kind=InputKind.PARAMETER, + name=fused_weight_name, + data=fused_weight + ) + if fused_bias is not None: + fused_conv_bias_node = create_constant_placeholder( + exp_program=self.exported_program, + graph=graph_module.graph, + kind=InputKind.PARAMETER, + name=fused_bias_name, + data=fused_bias + ) + else: + fused_conv_bias_node = None + + conv.args = ( + conv.args[0], + fused_conv_weight_node, + fused_conv_bias_node, + *conv.args[3:] + ) + + # Remove any use of batchnorm from the graph for user in bn.users.copy(): assert user.target == operator.getitem @@ -119,8 +142,17 @@ def call(self, graph_module: torch.fx.GraphModule): graph.erase_node(user) graph.erase_node(bn) + constant_placeholders_to_delete.update( + conv.args[1:3] + bn.args[1:5] + ) - counter += 1 + if len(constant_placeholders_to_delete) > 0: + graph_module.graph.eliminate_dead_code() + for node in constant_placeholders_to_delete: + if (node is not None) and ( + len(node.users) == 0 + ): + delete_constant_placeholder(self.exported_program, node) graph_module.recompile() # To Regenerate meta data and shape information, retrace module diff --git a/backends/xnnpack/operators/node_visitor.py b/backends/xnnpack/operators/node_visitor.py index 0a825a94bef..099adb72640 100644 --- a/backends/xnnpack/operators/node_visitor.py +++ b/backends/xnnpack/operators/node_visitor.py @@ -16,6 +16,7 @@ ) from executorch.backends.xnnpack.operators.quant_params import QuantParams +from executorch.exir._serialize._named_data_store import NamedDataStore from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import ( ConstantDataOffset, @@ -30,11 +31,15 @@ XNNTensorValue, XValue, ) +from executorch.backends.xnnpack.utils.xnnpack_constants import ( + UINT64_MAX +) from executorch.backends.xnnpack.utils.utils import ( check_or_raise, get_input_node, get_param_tensor, is_param_node, + get_tensor_name, PERM_NCHW_TO_NHWC, ) @@ -86,11 +91,11 @@ def __init__( self, exported_program: ExportedProgram, external_ids: Dict, - constant_data_bytes: bytearray, + named_data_store: NamedDataStore, ) -> None: self._external_ids = external_ids or {} self._exported_program = exported_program or None - self._constant_data_bytes = constant_data_bytes + self._named_data_store = named_data_store @property def external_ids(self) -> Dict: @@ -579,12 +584,13 @@ def get_serialized_buffer_index( ctypes.POINTER(array_type), ).contents - offset = len(self._constant_data_bytes) + named_key = get_tensor_name(self.exported_program, get_attr_node) + if named_key == "": + raise ValueError(f"Tensor from node: {get_attr_node} has no name") + size = const_val.untyped_storage().nbytes() - xnn_graph.constant_data.append(ConstantDataOffset(offset=offset, size=size)) - self._constant_data_bytes.extend( - _pad_to(bytes(array), _aligned_size(size, CONSTANT_TENSOR_ALIGNMENT)) - ) + xnn_graph.constant_data.append(ConstantDataOffset(offset=UINT64_MAX, size=size, named_key=named_key)) + self._named_data_store.add_named_data(named_key, bytes(array), alignment=CONSTANT_TENSOR_ALIGNMENT) return buffer_idx diff --git a/backends/xnnpack/serialization/schema.fbs b/backends/xnnpack/serialization/schema.fbs index 5a43481b98d..193656c30b1 100644 --- a/backends/xnnpack/serialization/schema.fbs +++ b/backends/xnnpack/serialization/schema.fbs @@ -316,11 +316,20 @@ table XNNLeakyReLU { table ConstantDataOffset { // Constant data offsets are relative to the constant data base offset provided // in the XNNPACKHeader. + // named_key and offset are mutually exclusive, meaning only one of these values + // are valid. If the named key is a non-empty string, then the offset must be UINT64_MAX. + // If the offset is not UINT64_MAX, then the named key must be an empty string offset: uint64; // The size in bytes of valid data starting at the offset. The constant data // may be followed by padding before the next piece of constant data size: uint64; + + // unique string id used to query the offset from the named data store. + // named_key and offset are mutually exclusive, meaning only one of these values + // are valid. If the named key is a non-empty string, then the offset must be UINT64_MAX. + // If the offset is not UINT64_MAX, then the named key must be an empty string + named_key: string; } table XNNGraph { diff --git a/backends/xnnpack/serialization/xnnpack_graph_schema.py b/backends/xnnpack/serialization/xnnpack_graph_schema.py index 3276dac7869..3cb572c66ef 100644 --- a/backends/xnnpack/serialization/xnnpack_graph_schema.py +++ b/backends/xnnpack/serialization/xnnpack_graph_schema.py @@ -470,6 +470,7 @@ class XValue: class ConstantDataOffset: offset: int size: int + named_key: str = "" @dataclass diff --git a/backends/xnnpack/utils/gen_xnnpack_constants.sh b/backends/xnnpack/utils/gen_xnnpack_constants.sh index 6be9d4519f3..5fa92e5b038 100644 --- a/backends/xnnpack/utils/gen_xnnpack_constants.sh +++ b/backends/xnnpack/utils/gen_xnnpack_constants.sh @@ -26,5 +26,6 @@ } > xnnpack_constants.py echo UINT32_MAX = 4294967295 >> xnnpack_constants.py +echo UINT64_MAX = 18446744073709551615 >> xnnpack_constants.py awk '/^#define\s+XNN_/ { print $2,"=",$3} ' "$1"/include/xnnpack.h >> xnnpack_constants.py if ! grep -qc "^XNN_" xnnpack_constants.py; then false; fi diff --git a/backends/xnnpack/utils/utils.py b/backends/xnnpack/utils/utils.py index b802d73c16b..ed81b99279e 100644 --- a/backends/xnnpack/utils/utils.py +++ b/backends/xnnpack/utils/utils.py @@ -131,6 +131,24 @@ def get_param_tensor( raise RuntimeError(f"unsupported param type, {node.op}.") +def get_tensor_name( + exp_prog: ExportedProgram, node: torch.fx.Node +) -> str: + if node is None: + return "" + if is_param(exp_prog, node): + return exp_prog.graph_signature.inputs_to_parameters[node.name] + elif is_buffer(exp_prog, node): + return exp_prog.graph_signature.inputs_to_buffers[node.name] + elif is_lifted_tensor_constant(exp_prog, node): + return exp_prog.graph_signature.inputs_to_lifted_tensor_constants[node.name] + else: + assert(isinstance(node.target, str)) + return node.target + + return "" + + def get_source_fn(node: torch.fx.Node) -> Optional[torch.fx.Node]: """ Returns the source fn of the given node, return None if something goes wrong diff --git a/backends/xnnpack/utils/xnnpack_constants.py b/backends/xnnpack/utils/xnnpack_constants.py index 351cc8ad897..364819a2435 100644 --- a/backends/xnnpack/utils/xnnpack_constants.py +++ b/backends/xnnpack/utils/xnnpack_constants.py @@ -6,8 +6,11 @@ # Auto-generated by gen_xnnpack_constants.sh script. Do not modify UINT32_MAX = 4294967295 +UINT64_MAX = 18446744073709551615 +XNN_EXTRA_BYTES = 128 XNN_EXTRA_BYTES = 16 XNN_MAX_TENSOR_DIMS = 6 +XNN_INVALID_VALUE_ID = UINT32_MAX XNN_FLAG_HINT_SPARSE_INFERENCE = 0x00000001 XNN_FLAG_HINT_FP16_INFERENCE = 0x00000002 XNN_FLAG_FORCE_FP16_INFERENCE = 0x00000004 @@ -26,7 +29,8 @@ XNN_FLAG_YIELD_WORKERS = 0x00000010 XNN_FLAG_TRANSIENT_INDIRECTION_BUFFER = 0x00000020 XNN_FLAG_KEEP_DIMS = 0x00000040 -XNN_EXTRA_QUANTIZATION_PARAMS = 8 +XNN_EXTRA_QUANTIZATION_PARAMS = 10 +XNN_MIN_BLOCKSIZE = 32 XNN_VALUE_FLAG_EXTERNAL_INPUT = 0x00000001 XNN_VALUE_FLAG_EXTERNAL_OUTPUT = 0x00000002 XNN_VALUE_FLAG_PERSISTENT = 0x00000004 diff --git a/backends/xnnpack/xnnpack_preprocess.py b/backends/xnnpack/xnnpack_preprocess.py index 4548de4940a..869ece42689 100644 --- a/backends/xnnpack/xnnpack_preprocess.py +++ b/backends/xnnpack/xnnpack_preprocess.py @@ -38,6 +38,7 @@ PreprocessResult, ) from executorch.exir.verification.verifier import EXIREdgeDialectVerifier +from executorch.exir._serialize._named_data_store import NamedDataStore from torch.export.exported_program import ExportedProgram DEFAULT_DEBUG_HANDLE = 65535 @@ -103,7 +104,7 @@ def preprocess( edge_program: ExportedProgram, compile_specs: List[CompileSpec], ) -> PreprocessResult: - + named_data_store = NamedDataStore() xnnpack_edge_compile_config = get_xnnpack_edge_compile_config() # Need to wrap EP here because xnnpack does addmm to linear @@ -162,7 +163,7 @@ def preprocess( ) constant_data_bytes = bytearray() - node_visitors = get_node_visitors(ep, node_to_external_map, constant_data_bytes) + node_visitors = get_node_visitors(ep, node_to_external_map, named_data_store) for node in graph_module.graph.nodes: if node.op == "call_function": @@ -191,4 +192,5 @@ def preprocess( xnnpack_graph, constant_data_bytes ), debug_handle_map={}, + data_store_output=named_data_store.get_named_data_store_output(), )