diff --git a/examples/portable/scripts/export.py b/examples/portable/scripts/export.py
index 353d8a034e0..0e25bd3f8ec 100644
--- a/examples/portable/scripts/export.py
+++ b/examples/portable/scripts/export.py
@@ -6,6 +6,8 @@
 
 # Example script for exporting simple models to flatbuffer
 
+# pyre-unsafe
+
 import argparse
 import logging
 
@@ -48,6 +50,15 @@ def main() -> None:
         required=False,
         help="specify segment alignment in hex. Default is 0x1000. Use 0x4000 for iOS",
     )
+
+    parser.add_argument(
+        "-e",
+        "--external_constants",
+        action=argparse.BooleanOptionalAction,
+        default=False,
+        help="Save constants in external .ptd file. Default is False",
+    )
+
     parser.add_argument("-o", "--output_dir", default=".", help="output directory")
 
     args = parser.parse_args()
@@ -62,7 +73,7 @@ def main() -> None:
         *MODEL_NAME_TO_MODEL[args.model_name]
     )
 
-    backend_config = ExecutorchBackendConfig()
+    backend_config = ExecutorchBackendConfig(external_constants=args.external_constants)
     if args.segment_alignment is not None:
         backend_config.segment_alignment = int(args.segment_alignment, 16)
     if dynamic_shapes is not None:
diff --git a/exir/capture/_config.py b/exir/capture/_config.py
index 24865e7a841..eb699c53d99 100644
--- a/exir/capture/_config.py
+++ b/exir/capture/_config.py
@@ -89,3 +89,7 @@ class ExecutorchBackendConfig:
     # If set to true, view_copy operations will be converted to lightweight
     # view operations in the ET runtime
     remove_view_copy: bool = True
+
+    # If set to true, all constant tensors will be stored in a separate file,
+    # external to the PTE file.
+    external_constants: bool = False
diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py
index 1ebe9b2224d..fc10c1db66f 100644
--- a/exir/emit/test/test_emit.py
+++ b/exir/emit/test/test_emit.py
@@ -1633,3 +1633,30 @@ def forward(self, x):
         input = torch.zeros(1)
         executorch_model(input)
         self.assertEqual(input, torch.ones(1))
+
+    def test_constant_tagged_tensors(self) -> None:
+        class LinearModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 5)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        model = to_edge(export(LinearModule(), (torch.ones(5, 5),))).to_executorch(
+            config=ExecutorchBackendConfig(
+                external_constants=True,
+            )
+        )
+        emitter_output = model._emitter_output
+        # Check that constant_buffer is empty besides the non-constant placeholder 0.
+        self.assertEqual(len(emitter_output.program.constant_buffer), 1)
+        # Check that constant weights are in the external constant buffer.
+        self.assertEqual(len(emitter_output.external_constant_buffer), 2)
+        # Setting external_constants=True, saves all constants to the key
+        # '_default_external_constant'.
+        external_map = emitter_output.external_constant_map[
+            "_default_external_constant"
+        ]
+        self.assertEqual(external_map["linear.weight"], 0)
+        self.assertEqual(external_map["linear.bias"], 1)
diff --git a/exir/passes/TARGETS b/exir/passes/TARGETS
index a3251589ac0..cda07a9423e 100644
--- a/exir/passes/TARGETS
+++ b/exir/passes/TARGETS
@@ -10,6 +10,7 @@ python_library(
     deps = [
         ":const_prop_pass",
         ":debug_handle_generator_pass",
+        ":external_constants_pass",
         ":insert_write_back_for_buffers_pass",
         ":memory_format_ops_pass",
         ":memory_planning_pass",
@@ -54,6 +55,16 @@ python_library(
     ],
 )
 
+python_library(
+    name = "external_constants_pass",
+    srcs = [
+        "external_constants_pass.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
+
 python_library(
     name = "insert_write_back_for_buffers_pass",
     srcs = [
diff --git a/exir/passes/external_constants_pass.py b/exir/passes/external_constants_pass.py
new file mode 100644
index 00000000000..1429e15cbb1
--- /dev/null
+++ b/exir/passes/external_constants_pass.py
@@ -0,0 +1,29 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import torch
+from executorch.exir.tensor import TensorSpec
+from torch.export.exported_program import ExportedProgram
+
+
+def external_constants_pass(
+    ep: ExportedProgram,
+) -> ExportedProgram:
+    """
+    Move all constants to external file.
+    """
+    for module in ep.graph_module.modules():
+        if not isinstance(module, torch.fx.GraphModule):
+            continue
+
+        for node in module.graph.nodes:
+            if node.op == "placeholder":
+                spec = node.meta.get("spec")
+                if isinstance(spec, TensorSpec) and spec.const:
+                    node.meta["constant_tag"] = "_default_external_constant"
+    return ep
diff --git a/exir/program/_program.py b/exir/program/_program.py
index 0521088dfdb..fd1d0aca3dc 100644
--- a/exir/program/_program.py
+++ b/exir/program/_program.py
@@ -33,6 +33,7 @@
     MemoryFormatOpsPass,
     OpReplacePass,
 )
+from executorch.exir.passes.external_constants_pass import external_constants_pass
 from executorch.exir.passes.insert_write_back_for_buffers_pass import (
     insert_write_back_for_buffers_pass,
 )
@@ -1380,6 +1381,9 @@ def to_executorch(
                 )
             else:
                 new_gm_res = memory_planning_pass(new_gm)  # pyre-ignore[29]
+
+            if config.external_constants:
+                new_gm_res = external_constants_pass(new_gm_res)
             assert new_gm_res is not None
             new_gm = new_gm_res.graph_module
 