
Commit 44d9570

Update on "[executorch][runtime] Introduce PteDataMap for weight sharing"
PteDataMap is the NamedDataMap that will live in the runtime. It is used to give delegates access to opaque named data stored in the PTE file. Open to alternative naming suggestions; maybe 'PTEDataMap' or 'ProgramDataMap'?

**Usage**

The PteDataMap is owned by the program and instantiated at program load time if named_data exists in the PTE file.

We introduce usage of 'std::optional' here. I think we can also use executorch::aten::optional to avoid adding the standard library?

When initializing delegates, the PteDataMap is given to delegate_init. Delegates can retrieve opaque delegate data by key using 'get_data'. This gives them a FreeableBuffer that they can free later.

**Testing**

This test uses the C++ flatbuffer API to build a fake program containing named data. We also create a temp file with sample data that the data loader can wrap around.

TODO: e2e test once delegate AOT is ready and we can generate a file with named data.

**Note**

As the PteDataMap wraps around flatbuffer constructs, the Program must outlive the PteDataMap.

PteDataMap does not implement:
- get_metadata; currently, all data stored is opaque. Later, we can implement get_metadata if a backend stores plain tensor data.
- load_into; this is mostly used for the training case and isn't used by delegates, at least not at the moment.

Differential Revision: [D70213646](https://our.internmc.facebook.com/intern/diff/D70213646/)

[ghstack-poisoned]
2 parents 6c4f055 + eb2e2fc commit 44d9570
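
To make the delegate-facing flow described above concrete, here is a minimal sketch of how a backend's init might consume the map. Only the shape of the API, get_data by key returning a FreeableBuffer handed through delegate_init, comes from the commit message; the header paths, the Result plumbing, the load_shared_weights/consume_weights helpers, and the "encoder_weights" key are assumptions for illustration, not the committed API surface.

```cpp
// A minimal sketch, assuming the header paths and Result/Error plumbing
// below; the key name and helper functions are hypothetical.
#include <cstddef>

#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/freeable_buffer.h>
#include <executorch/runtime/core/named_data_map.h>

using executorch::runtime::Error;
using executorch::runtime::FreeableBuffer;
using executorch::runtime::NamedDataMap;
using executorch::runtime::Result;

// Stand-in for whatever the backend actually does with the raw bytes.
void consume_weights(const void* data, size_t size);

// A backend's init receives the PteDataMap (as a NamedDataMap pointer)
// via delegate_init and can look up opaque blobs by key.
Error load_shared_weights(const NamedDataMap* named_data_map) {
  // "encoder_weights" is a hypothetical key chosen by the AOT flow.
  Result<FreeableBuffer> weights = named_data_map->get_data("encoder_weights");
  if (!weights.ok()) {
    return weights.error();
  }
  consume_weights(weights->data(), weights->size());
  // The delegate owns the buffer and frees it when done. Note that the
  // Program must still outlive the PteDataMap itself, as the map points
  // into flatbuffer constructs owned by the Program.
  weights->Free();
  return Error::Ok;
}
```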

File tree

63 files changed: +1028 additions, -1449 deletions


.ci/scripts/test_ane_static_llama.sh

Lines changed: 27 additions & 0 deletions

```diff
@@ -0,0 +1,27 @@
+#!/bin/bash
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+set -exu
+
+source "$(dirname "${BASH_SOURCE[0]}")/utils.sh"
+
+export EXECUTORCH_ROOT="$(dirname "${BASH_SOURCE[0]}")/../.."
+
+if [[ -z "${PYTHON_EXECUTABLE:-}" ]]; then
+  PYTHON_EXECUTABLE=python3
+fi
+
+which "${PYTHON_EXECUTABLE}"
+
+pushd $EXECUTORCH_ROOT/examples/apple/coreml/llama
+
+# Download stories llama110m artifacts
+download_stories_model_artifacts
+
+python export.py -n model.pte -p params.json -c stories110M.pt --seq_length 32 --max_seq_length 64 --dtype fp16 --coreml-quantize c4w
+
+popd
```

.ci/scripts/test_model.sh

Lines changed: 16 additions & 1 deletion

```diff
@@ -100,6 +100,14 @@ test_model() {
         rm "./${MODEL_NAME}.pte"
         return # Skip running with portable executor runnner since portable doesn't support Qwen's biased linears.
     fi
+    if [[ "${MODEL_NAME}" == "phi4_mini" ]]; then
+        # Install requirements for export_llama
+        bash examples/models/llama/install_requirements.sh
+        # Test export_llama script: python3 -m examples.models.llama.export_llama.
+        "${PYTHON_EXECUTABLE}" -m examples.models.llama.export_llama --model "${MODEL_NAME}" -c examples/models/llama/params/demo_rand_params.pth -p examples/models/phi-4-mini/config.json
+        run_portable_executor_runner
+        rm "./${MODEL_NAME}.pte"
+    fi
 
     # Export a basic .pte and run the model.
     "${PYTHON_EXECUTABLE}" -m examples.portable.scripts.export --model_name="${MODEL_NAME}" "${STRICT}"
@@ -164,6 +172,7 @@ test_model_with_qnn() {
     export LD_LIBRARY_PATH=$QNN_SDK_ROOT/lib/x86_64-linux-clang/
     export PYTHONPATH=$EXECUTORCH_ROOT/..
 
+    EXTRA_FLAGS=""
     if [[ "${MODEL_NAME}" == "dl3" ]]; then
         EXPORT_SCRIPT=deeplab_v3
     elif [[ "${MODEL_NAME}" == "mv3" ]]; then
@@ -176,6 +185,12 @@ test_model_with_qnn() {
         EXPORT_SCRIPT=inception_v3
     elif [[ "${MODEL_NAME}" == "vit" ]]; then
         EXPORT_SCRIPT=torchvision_vit
+    elif [[ "${MODEL_NAME}" == "mb" ]]; then
+        EXPORT_SCRIPT=mobilebert_fine_tune
+        EXTRA_FLAGS="--num_epochs 1"
+        pip install scikit-learn
+    elif [[ "${MODEL_NAME}" == "w2l" ]]; then
+        EXPORT_SCRIPT=wav2letter
     elif [[ "${MODEL_NAME}" == "edsr" ]]; then
         EXPORT_SCRIPT=edsr
         # Additional deps for edsr
@@ -189,7 +204,7 @@ test_model_with_qnn() {
     # TODO(guangyang): Make QNN chipset matches the target device
     QNN_CHIPSET=SM8450
 
-    "${PYTHON_EXECUTABLE}" -m examples.qualcomm.scripts.${EXPORT_SCRIPT} -b ${CMAKE_OUTPUT_DIR} -m ${QNN_CHIPSET} --compile_only
+    "${PYTHON_EXECUTABLE}" -m examples.qualcomm.scripts.${EXPORT_SCRIPT} -b ${CMAKE_OUTPUT_DIR} -m ${QNN_CHIPSET} --compile_only $EXTRA_FLAGS
     EXPORTED_MODEL=$(find "./${EXPORT_SCRIPT}" -type f -name "${MODEL_NAME}*.pte" -print -quit)
 }
```

.github/workflows/trunk.yml

Lines changed: 23 additions & 1 deletion

```diff
@@ -229,6 +229,28 @@ jobs:
         # see if we can import the module successfully
         ${CONDA_RUN} python -c "from executorch.extension.pybindings import portable_lib; print('success!')"
 
+  test-static-llama-ane:
+    name: test-static-llama-ane
+    uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
+    with:
+      runner: macos-m1-stable
+      python-version: '3.11'
+      submodules: 'true'
+      ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
+      script: |
+        set -eux
+        bash .ci/scripts/setup-conda.sh
+        eval "$(conda shell.bash hook)"
+
+        # Install requirements
+        sh install_requirements.sh
+        sh backends/apple/coreml/scripts/install_requirements.sh
+        python install_executorch.py --pybind coreml
+        sh examples/models/llama/install_requirements.sh
+
+        # Test ANE llama
+        sh .ci/scripts/test_ane_static_llama.sh
+
   test-llama-runner-macos:
     name: test-llama-runner-mac
     uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
@@ -311,7 +333,7 @@ jobs:
     strategy:
       matrix:
         dtype: [fp32]
-        model: [dl3, mv3, mv2, ic4, ic3, vit]
+        model: [dl3, mv3, mv2, ic4, ic3, vit, mb, w2l]
       fail-fast: false
     with:
       runner: linux.2xlarge
```

backends/arm/_passes/TARGETS

Lines changed: 1 addition & 0 deletions

```diff
@@ -9,5 +9,6 @@ python_library(
         "//executorch/backends/transforms:replace_scalar_with_tensor",
         "//executorch/backends/xnnpack/_passes:xnnpack_passes",
         "//executorch/exir:lib",
+        "//executorch/backends/transforms:utils",
     ],
 )
```

backends/arm/_passes/arm_pass_utils.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,6 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-# Copyright 2024-2025 Arm Limited and/or its affiliates.
 # All rights reserved.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
```

backends/arm/_passes/fuse_batchnorm2d_pass.py

Lines changed: 82 additions & 48 deletions

```diff
@@ -6,10 +6,15 @@
 # pyre-unsafe
 
 import torch
+from executorch.backends.transforms.utils import (
+    create_constant_placeholder,
+    delete_constant_placeholder,
+)
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch._export.utils import get_buffer, get_param
+from torch.export.graph_signature import InputKind
 from torch.fx import Node
 from torch.nn.utils.fusion import fuse_conv_bn_weights
 
@@ -23,7 +28,7 @@ def __init__(self, exported_program: ExportedProgram):
         self.exported_program = exported_program
         super().__init__()
 
-    def is_fuseable_conv_bn(self, node: Node):
+    def is_fuseable_conv_bn(self, node: Node) -> bool:
         """Returns True if node is a batchnorm that can be fused into
         a parent convolution."""
         if node.op != "call_function":
@@ -44,15 +49,19 @@
         # Since we change the output of the conv, fuse only if it has single user.
         if len(conv.users) > 1:
             return False
-        # For similar reasons, only fuse if conv parameters have single user.
-        if len(conv.all_input_nodes[1].users) > 1:
-            return False
-        if len(conv.all_input_nodes) > 2 and len(conv.all_input_nodes[2].users) > 1:
-            return False
         return True
 
+    def get_bias_name(self, conv_weight_node: Node, conv_bias_node: Node) -> str:
+        if conv_bias_node:
+            return conv_bias_node.name + "_fused_bn"
+        elif "weight" in conv_weight_node.name:
+            return conv_weight_node.name.replace("weight", "bias") + "_fused_bn"
+        else:
+            return conv_weight_node.name + "_bias_fused_bn"
+
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:  # noqa: C901
         modified = False
+        constant_placeholders_to_delete = set()
         for node in graph_module.graph.nodes:
             if not self.is_fuseable_conv_bn(node):
                 continue
@@ -64,68 +73,93 @@ def get_param_or_none(arg) -> torch.nn.Parameter | None:
                 )
 
             # Get weight, bias, mean, var and epsilon from the batchnorm
-            bn = node
-            conv, bn_weight_node, bn_bias_node, bn_mean_node, bn_var_node = bn.args[0:5]
-            bn_weight = get_param_or_none(bn_weight_node)
-            bn_bias = get_param_or_none(bn_bias_node)
-
-            running_mean = get_buffer(self.exported_program, bn_mean_node)
-            running_var = get_buffer(self.exported_program, bn_var_node)
-            if running_mean is None or running_var is None:
+            bn_node = node
+            conv, bn_weight_node, bn_bias_node, bn_mean_node, bn_var_node = (
+                bn_node.args[0:5]
+            )
+            bn_weight_tensor = get_param_or_none(bn_weight_node)
+            bn_bias_tensor = get_param_or_none(bn_bias_node)
+            bn_mean_tensor = get_buffer(self.exported_program, bn_mean_node)
+            bn_var_tensor = get_buffer(self.exported_program, bn_var_node)
+            if bn_mean_tensor is None or bn_var_tensor is None:
                 raise ValueError(
                     "Parameters running_mean and running_var of batchnorm can't be None."
                 )
-            epsilon = bn.args[-1]
+            epsilon = bn_node.args[-1]
 
             # Get weight and bias from conv
             conv_weight_node, conv_bias_node = conv.args[1:3]
-            conv_weight = get_param(self.exported_program, conv_weight_node)
-            conv_bias = get_param_or_none(conv_bias_node)
-            if conv_weight is None:
+            conv_weight_tensor = get_param(self.exported_program, conv_weight_node)
+            conv_bias_tensor = get_param_or_none(conv_bias_node)
+            if conv_weight_tensor is None:
                 raise ValueError("Parameter weight of convolution can't be None.")
 
             # Compute conv parameters folded with batchnorm
             fused_conv_weight, fused_conv_bias = fuse_conv_bn_weights(
-                conv_weight,
-                conv_bias,
-                running_mean,
-                running_var,
+                conv_weight_tensor,
+                conv_bias_tensor,
+                bn_mean_tensor,
+                bn_var_tensor,
                 epsilon,
-                bn_weight,
-                bn_bias,
+                bn_weight_tensor,
+                bn_bias_tensor,
             )
 
-            # Set the conv parameters to fused value
-            def try_set_param(
-                param_node: Node | None, param_value: torch.nn.Parameter
-            ) -> bool:
-                """set_param but check if param_node is None first. Return True if param was set successfully, otherwise False."""
-                if param_node is not None:
-                    param_name = (
-                        self.exported_program.graph_signature.inputs_to_parameters[
-                            param_node.name
-                        ]
+            # Create fused weights and bias to conv and replace conv args
+            with graph_module.graph.inserting_before(conv_weight_node):
+                fused_conv_weight_node = create_constant_placeholder(
+                    exp_program=self.exported_program,
+                    graph=graph_module.graph,
+                    kind=InputKind.PARAMETER,
+                    name=conv_weight_node.name + "_fused_bn",
+                    data=fused_conv_weight,
+                )
+
+                if fused_conv_bias is not None:
+                    fused_conv_bias_node = create_constant_placeholder(
+                        exp_program=self.exported_program,
+                        graph=graph_module.graph,
+                        kind=InputKind.PARAMETER,
+                        name=self.get_bias_name(conv_weight_node, conv_bias_node),
+                        data=fused_conv_bias,
                     )
-                    self.exported_program.state_dict[param_name] = param_value
-                    return True
-                return False
+                else:
+                    fused_conv_bias_node = None
+
+            conv.args = (
+                conv.args[0],
+                fused_conv_weight_node,
+                fused_conv_bias_node,
+                *conv.args[3:],
+            )
 
-            try_set_param(conv_weight_node, fused_conv_weight)
-            if not try_set_param(conv_bias_node, fused_conv_bias) and try_set_param(
-                bn_bias_node, fused_conv_bias
-            ):
-                # pyre-ignore[60]
-                # Conv didn't have bias but batchnorm did, steal bias from batchnorm.
-                conv_args = (*conv.args[0:2], bn_bias_node, *conv.args[3:])
-                conv.args = conv_args
-
-            # Erasing nodes is handled by dead-code elimination.
-            for user in bn.users:
+            # Erasing batch-norm nodes is handled by dead-code elimination. After that we may remove their constant placeholder inputs
+            for user in bn_node.users:
                 user.replace_all_uses_with(conv)
+
+            constant_placeholders_to_delete.update(
+                [
+                    bn_weight_node,
+                    bn_bias_node,
+                    bn_mean_node,
+                    bn_var_node,
+                    conv_weight_node,
+                    conv_bias_node,
+                ]
+            )
             modified = True
 
         if modified:
             graph_module.graph.eliminate_dead_code()
+            for constant_placeholder in constant_placeholders_to_delete:
+                if (constant_placeholder is not None) and (
+                    len(constant_placeholder.users) == 0
+                ):
+                    delete_constant_placeholder(
+                        self.exported_program, constant_placeholder
+                    )
+
             graph_module.recompile()
             graph_module = super().call(graph_module).graph_module
+
         return PassResult(graph_module=graph_module, modified=modified)
```
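
For reference, the folding this pass delegates to `fuse_conv_bn_weights` is the standard batch-norm absorption identity. Writing $\mu$, $\sigma^2$ for the running mean and variance, $\gamma$, $\beta$ for the batch-norm affine parameters, and $W_c$, $b_c$ for the conv weight and bias of output channel $c$ (with $b_c = 0$ when the conv has no bias):

$$
W_c' = W_c \cdot \frac{\gamma_c}{\sqrt{\sigma_c^2 + \epsilon}},
\qquad
b_c' = (b_c - \mu_c) \cdot \frac{\gamma_c}{\sqrt{\sigma_c^2 + \epsilon}} + \beta_c
$$

After this rewrite the batch-norm node computes an identity, so dead-code elimination removes it, and the pass then deletes any constant placeholders left without users.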

backends/arm/test/models/test_w2l_arm.py

Lines changed: 0 additions & 1 deletion

```diff
@@ -131,7 +131,6 @@ def test_w2l_u55_BI(self):
 
     @pytest.mark.slow
     @pytest.mark.corstone_fvp
-    @unittest.skip("Blocked by MLBEDSW-10420")
     @conftest.expectedFailureOnFVP  # TODO: MLBEDSW-10093
     def test_w2l_u85_BI(self):
         tester = self._test_w2l_ethos_BI_pipeline(
```

backends/arm/test/passes/test_fuse_batchnorm_pass.py

Lines changed: 4 additions & 4 deletions

```diff
@@ -85,13 +85,13 @@ def forward(self, x):
         return x
 
 
-class MergeNoBN(torch.nn.Module):
+class MergeMultipleUsersBN(torch.nn.Module):
     ops_before_pass = {
         "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2,
         "executorch_exir_dialects_edge__ops_aten_convolution_default": 3,
     }
     ops_after_pass = {
-        "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2,
+        "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 1,
         "executorch_exir_dialects_edge__ops_aten_convolution_default": 3,
     }
 
@@ -122,7 +122,7 @@ def forward(self, x):
         z = self.conv2d2(x)
         a = self.batch_norm2d(
             y
-        )  # Can't be fused since paramters of conv2d2 have multiple users.
+        )  # Can be fused despite paramters of conv2d2 having multiple users.
 
         return z, a
 
@@ -131,7 +131,7 @@ def forward(self, x):
     "merge_one_of_two_bn_affine": MergeOneOfTwoBN(True),
     "merge_one_of_two_bn": MergeOneOfTwoBN(False),
     "merge_two_of_two_bn_affine": MergeTwosOfTwoBN(True),
-    "merge_no_bn_affine": MergeNoBN(True),
+    "merge_multiple_users_bn_affine": MergeMultipleUsersBN(True),
 }
```

backends/cadence/aot/pass_utils.py

Lines changed: 10 additions & 0 deletions

```diff
@@ -104,6 +104,16 @@ def count_node(graph_module: torch.fx.GraphModule, target: torch.fx.node.Target)
     return total
 
 
+def op_counts_match(
+    graph_module: torch.fx.GraphModule,
+    expected_op_counts: dict[EdgeOpOverload, int],
+) -> bool:
+    for op, count in expected_op_counts.items():
+        if count_node(graph_module, op) != count:
+            return False
+    return True
+
+
 # Testing utils
 # Return the compute/function nodes in the graph
 def get_compute_nodes_in_gm(graph_module: torch.fx.GraphModule) -> List[torch.fx.Node]:
```
