Skip to content

Commit 2d07744

Browse files
committed
[executorch][passes] Add config and pass to tag constants for external file
Pull Request resolved: #7193 - Add config 'external_constants' to ExecutorchBackendConfig. - When set to True, run the 'external_constants_pass' - This tags all constants as external, and moves them into a separate buffer to be serialized outside of the PTE file. Note: users can write their own passes to tag weights to specific files / multiple files. TODO: write example pass and test for the case where we have two constant files. ghstack-source-id: 257349683 Differential Revision: [D66560903](https://our.internmc.facebook.com/intern/diff/D66560903/)
1 parent 37363c9 commit 2d07744

File tree

6 files changed

+85
-1
lines changed

6 files changed

+85
-1
lines changed

examples/portable/scripts/export.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,15 @@ def main() -> None:
4848
required=False,
4949
help="specify segment alignment in hex. Default is 0x1000. Use 0x4000 for iOS",
5050
)
51+
52+
parser.add_argument(
53+
"-e",
54+
"--external_constants",
55+
action=argparse.BooleanOptionalAction,
56+
default=False,
57+
help="Save constants in external .ptd file. Default is False",
58+
)
59+
5160
parser.add_argument("-o", "--output_dir", default=".", help="output directory")
5261

5362
args = parser.parse_args()
@@ -62,7 +71,7 @@ def main() -> None:
6271
*MODEL_NAME_TO_MODEL[args.model_name]
6372
)
6473

65-
backend_config = ExecutorchBackendConfig()
74+
backend_config = ExecutorchBackendConfig(external_constants=args.external_constants)
6675
if args.segment_alignment is not None:
6776
backend_config.segment_alignment = int(args.segment_alignment, 16)
6877
if dynamic_shapes is not None:

exir/capture/_config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,7 @@ class ExecutorchBackendConfig:
8989
# If set to true, view_copy operations will be converted to lightweight
9090
# view operations in the ET runtime
9191
remove_view_copy: bool = True
92+
93+
# If set to true, all constant tensors will be stored in a separate file,
94+
# external to the PTE file.
95+
external_constants: bool = False

exir/emit/test/test_emit.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1633,3 +1633,30 @@ def forward(self, x):
16331633
input = torch.zeros(1)
16341634
executorch_model(input)
16351635
self.assertEqual(input, torch.ones(1))
1636+
1637+
def test_constant_tagged_tensors(self) -> None:
1638+
class LinearModule(torch.nn.Module):
1639+
def __init__(self):
1640+
super().__init__()
1641+
self.linear = torch.nn.Linear(5, 5)
1642+
1643+
def forward(self, x):
1644+
return self.linear(x)
1645+
1646+
model = to_edge(export(LinearModule(), (torch.ones(5, 5),))).to_executorch(
1647+
config=ExecutorchBackendConfig(
1648+
external_constants=True,
1649+
)
1650+
)
1651+
emitter_output = model._emitter_output
1652+
# Check that constant_buffer is empty besides the non-constant placeholder 0.
1653+
self.assertEqual(len(emitter_output.program.constant_buffer), 1)
1654+
# Check that constant weights are in the external constant buffer.
1655+
self.assertEqual(len(emitter_output.external_constant_buffer), 2)
1656+
# Setting external_constants=True, saves all constants to the key
1657+
# '_default_external_constant'.
1658+
external_map = emitter_output.external_constant_map[
1659+
"_default_external_constant"
1660+
]
1661+
self.assertEqual(external_map["linear.weight"], 0)
1662+
self.assertEqual(external_map["linear.bias"], 1)

exir/passes/TARGETS

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ python_library(
1010
deps = [
1111
":const_prop_pass",
1212
":debug_handle_generator_pass",
13+
":external_constants_pass",
1314
":insert_write_back_for_buffers_pass",
1415
":memory_format_ops_pass",
1516
":memory_planning_pass",
@@ -54,6 +55,16 @@ python_library(
5455
],
5556
)
5657

58+
python_library(
59+
name = "external_constants_pass",
60+
srcs = [
61+
"external_constants_pass.py",
62+
],
63+
deps = [
64+
"//caffe2:torch",
65+
],
66+
)
67+
5768
python_library(
5869
name = "insert_write_back_for_buffers_pass",
5970
srcs = [
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import torch
10+
from executorch.exir.tensor import TensorSpec
11+
from torch.export.exported_program import ExportedProgram
12+
13+
14+
def external_constants_pass(
15+
ep: ExportedProgram,
16+
) -> ExportedProgram:
17+
"""
18+
Move all constants to external file.
19+
"""
20+
for module in ep.graph_module.modules():
21+
if not isinstance(module, torch.fx.GraphModule):
22+
continue
23+
24+
for node in module.graph.nodes:
25+
if node.op == "placeholder":
26+
spec = node.meta.get("spec")
27+
if isinstance(spec, TensorSpec) and spec.const:
28+
node.meta["constant_tag"] = "_default_external_constant"
29+
return ep

exir/program/_program.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
MemoryFormatOpsPass,
3434
OpReplacePass,
3535
)
36+
from executorch.exir.passes.external_constants_pass import external_constants_pass
3637
from executorch.exir.passes.insert_write_back_for_buffers_pass import (
3738
insert_write_back_for_buffers_pass,
3839
)
@@ -1380,6 +1381,9 @@ def to_executorch(
13801381
)
13811382
else:
13821383
new_gm_res = memory_planning_pass(new_gm) # pyre-ignore[29]
1384+
1385+
if config.external_constants:
1386+
new_gm_res = external_constants_pass(new_gm_res)
13831387
assert new_gm_res is not None
13841388
new_gm = new_gm_res.graph_module
13851389

0 commit comments

Comments
 (0)