Commit b5344c1

Use to_edge_lower_and_transform for XNNPack (#8624)

1 parent 9484c01 commit b5344c1
File tree

3 files changed: +113 -48 lines

examples/models/llama/export_llama_lib.py
examples/models/llava/export_llava.py
extension/llm/export/builder.py

examples/models/llama/export_llama_lib.py  (+91 -41)

@@ -676,47 +676,62 @@ def _validate_args(args):
         )
 
 
-def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
-    _validate_args(args)
-
-    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
-
-    # export_to_edge
-    builder_exported = _prepare_for_llama_export(args).export()
-
-    builder_exported.run_canonical_optimizations()
-
-    if args.export_only:
-        exit()
-
-    builder_exported_to_edge = builder_exported.pt2e_quantize(
-        quantizers
-    ).export_to_edge()
-
-    modelname = builder_exported_to_edge.modelname
-
-    # to_backend
+def _to_edge_and_lower_llama_xnnpack(
+    builder_exported,
+    modelname,
+    additional_passes,
+    pt2e_quant_params,
+    quantizers,
+    quant_dtype,
+    args,
+) -> LLMEdgeManager:  # noqa: C901
     partitioners = []
 
     # Order matters here, dynamic quantization should be applied first when both xnnpack and xnnpack_extended_ops are enabled
-    if (
-        pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None
-    ) or (args.xnnpack):
-        partitioners.append(
-            get_xnnpack_partitioner(dynamic_quant_only_partitioner=True)
-        )
+    partitioners.append(get_xnnpack_partitioner(dynamic_quant_only_partitioner=True))
 
-        # force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
-        args.xnnpack = True
-        modelname = f"xnnpack_dq_{modelname}"
+    modelname = f"xnnpack_dq_{modelname}"
 
     if args.xnnpack_extended_ops:
-        assert args.xnnpack, "xnnpack_extended_ops requires xnnpack to be enabled"
         partitioners.append(
             get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
         )
         modelname = f"xnnpack_{modelname}"
 
+    logging.info("Lowering model using following partitioner(s): ")
+    for partitioner in partitioners:
+        logging.info(f"--> {partitioner.__class__.__name__}")
+
+    # TODO: Enable generating ETRecord with XNNPack and to_edge_transform_and_lower().
+    if args.generate_etrecord:
+        raise NotImplementedError(
+            "export_llama does not support XNNPack and generating ETRecord at the moment."
+        )
+
+    builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
+        partitioners
+    )
+    if args.verbose:
+        print_delegation_info(builder.edge_manager.exported_program().graph_module)
+
+    return builder.to_executorch(passes=additional_passes)
+
+
+def _to_edge_and_lower_llama(  # noqa: C901
+    builder_exported,
+    modelname,
+    additional_passes,
+    pt2e_quant_params,
+    quantizers,
+    quant_dtype,
+    args,
+):
+    builder_exported_to_edge = builder_exported.pt2e_quantize(
+        quantizers
+    ).export_to_edge()
+
+    # to_backend
+    partitioners = []
     if args.vulkan:
         partitioners.append(
             get_vulkan_partitioner(
@@ -731,7 +746,6 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
         modelname = f"vulkan_{modelname}"
 
     # Need to remove asserts from the graph to prevent graph breaks
-    # pyre-ignore: Undefined attribute [16]: `Optional` has no attribute `exported_program`.
     remove_asserts(builder_exported_to_edge.edge_manager.exported_program())
 
     if args.mps:
@@ -760,13 +774,11 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
         # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
         from executorch.backends.qualcomm.utils.utils import _transform, tag_quant_io
 
-        # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`, Optional type has no attribute `exported_program`
         _transform(builder_exported_to_edge.edge_manager.exported_program())
 
         if args.num_sharding > 0:
             model_sharding.split_graph(
                 builder_exported_to_edge.edge_manager.exported_program(),
-                # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
                 builder_exported_to_edge.metadata["get_n_layers"],
                 shares=args.num_sharding,
             )
@@ -792,19 +804,15 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
                     atten.head_dim,
                 )
             )
-        # pyre-ignore
         tag_quant_io(
             builder_exported_to_edge.edge_manager.exported_program().graph_module,
-            partial(get_custom_quant_ios_dtype, cache_shape),  # pyre-ignore
+            partial(get_custom_quant_ios_dtype, cache_shape),
        )
 
     logging.info("Lowering model using following partitioner(s): ")
     for partitioner in partitioners:
         logging.info(f"--> {partitioner.__class__.__name__}")
 
-    additional_passes = []
-    if args.model in TORCHTUNE_DEFINED_MODELS:
-        additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
     if args.generate_etrecord:
         if not builder_exported_to_edge.edge_manager:
             raise ValueError("Unable to generate etrecord due to missing edge manager.")
@@ -818,7 +826,6 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
         if args.num_sharding > 0 and args.qnn:
             from executorch.backends.qualcomm.utils.utils import canonicalize_program
 
-            # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
             canonicalize_program(builder.edge_manager.exported_program())
 
         builder = builder.to_executorch(
@@ -840,11 +847,55 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
         if args.num_sharding > 0 and args.qnn:
             from executorch.backends.qualcomm.utils.utils import canonicalize_program
 
-            # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
             canonicalize_program(builder.edge_manager.exported_program())
 
         builder = builder.to_executorch(passes=additional_passes)
 
+    return builder
+
+
+def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
+    _validate_args(args)
+
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
+
+    additional_passes = []
+    if args.model in TORCHTUNE_DEFINED_MODELS:
+        additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
+
+    # export_to_edge
+    builder_exported = _prepare_for_llama_export(args).export()
+    builder_exported.run_canonical_optimizations()
+    modelname = builder_exported.modelname
+
+    if args.export_only:
+        exit()
+
+    if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None:
+        # Force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
+        args.xnnpack = True
+
+    if args.xnnpack:
+        builder = _to_edge_and_lower_llama_xnnpack(
+            builder_exported,
+            modelname,
+            additional_passes,
+            pt2e_quant_params,
+            quantizers,
+            quant_dtype,
+            args,
+        )
+    else:
+        builder = _to_edge_and_lower_llama(
+            builder_exported,
+            modelname,
+            additional_passes,
+            pt2e_quant_params,
+            quantizers,
+            quant_dtype,
+            args,
+        )
+
     if args.profile_memory:
         generate_memory_trace(builder.export_program, "memory_profile.json")
 
@@ -866,7 +917,6 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
     output_file = f"{builder.output_dir}/{modelname}.pte"
 
     builder.save_to_pte(output_file)
-
     return builder
 
 
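The shape of the refactor above is easier to see outside the diff: on the XNNPack path the builder no longer calls export_to_edge() followed by to_backend(), but quantizes and lowers in one step through the new to_edge_transform_and_lower() entry point. Below is a condensed, illustrative sketch of that flow; it mirrors _to_edge_and_lower_llama_xnnpack, but the standalone wrapper function and the import locations are assumptions, and the ETRecord check, logging, and extra passes are omitted.

# Condensed sketch of the new XNNPack lowering flow (illustrative; import paths assumed).
from executorch.examples.models.llama.export_llama_lib import (
    _prepare_for_llama_export,
    get_quantizer_and_quant_params,
    get_xnnpack_partitioner,
)


def lower_llama_with_xnnpack(args):
    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)

    # Export the eager model; the builder now keeps the pre-autograd ExportedProgram around.
    builder = _prepare_for_llama_export(args).export()
    builder.run_canonical_optimizations()

    # Dynamic-quant partitioner goes first; the extended-ops partitioner only if requested.
    partitioners = [get_xnnpack_partitioner(dynamic_quant_only_partitioner=True)]
    if args.xnnpack_extended_ops:
        partitioners.append(get_xnnpack_partitioner(dynamic_quant_only_partitioner=False))

    # New in this commit: quantize, then convert to edge and lower to XNNPACK in one call,
    # replacing the old export_to_edge() + to_backend() sequence.
    builder = builder.pt2e_quantize(quantizers).to_edge_transform_and_lower(partitioners)
    return builder.to_executorch()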

examples/models/llava/export_llava.py  (-1)

@@ -67,7 +67,6 @@ def export(self) -> "LlavaEdgeManager":
            dynamic_shapes=dynamic_shape,
            strict=False,
        )
-        # pyre-ignore: Incompatible attribute type [8]: Attribute `pre_autograd_graph_module` declared in class `LLMEdgeManager` has type `Optional[GraphModule]` but is used as type `Module`.
        self.pre_autograd_graph_module = self.export_program.module()
        return self
 
extension/llm/export/builder.py  (+22 -6)

@@ -21,7 +21,7 @@
     DuplicateDynamicQuantChainPass,
 )
 from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
-from executorch.exir import EdgeProgramManager
+from executorch.exir import EdgeProgramManager, to_edge_transform_and_lower
 from executorch.exir.backend.partitioner import Partitioner
 
 from executorch.exir.backend.utils import format_delegated_graph
@@ -39,7 +39,7 @@
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer import Quantizer
 from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
-from torch.export import export_for_training
+from torch.export import export_for_training, ExportedProgram
 from torch.nn.attention import SDPBackend
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -89,8 +89,8 @@ def __init__(
         dynamic_shapes: Optional[Any] = None,
     ):
         self.model = model
-        # graph module returned from export()
-        self.pre_autograd_graph_module: Optional[torch.fx.GraphModule] = None
+        self.pre_autograd_exported_program: Optional[ExportedProgram] = None
+        self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
         self.modelname = modelname
         self.max_seq_len = max_seq_len
         self.dtype = dtype
@@ -218,8 +218,8 @@ def export(self) -> "LLMEdgeManager":
                 kwargs=self.example_kwarg_inputs,
                 dynamic_shapes=dynamic_shape,
             )
-            # pyre-fixme[8]: Attribute has type `Optional[GraphModule]`; used as
             #  `Module`.
+            self.pre_autograd_exported_program = exported_module
             self.pre_autograd_graph_module = exported_module.module()
             if hasattr(self.args, "export_only") and self.args.export_only:
                 torch.export.save(exported_module, self.args.output_name)
@@ -330,7 +330,10 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
                 assert (
                     self.pre_autograd_graph_module is not None
                 ), "Please run export() first"
-                m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
+                m = prepare_pt2e(
+                    self.pre_autograd_graph_module,  # pyre-ignore[6]
+                    composed_quantizer,
+                )
                 logging.info(
                     f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
                 )
@@ -430,6 +433,19 @@ def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManager":
 
         return self
 
+    def to_edge_transform_and_lower(
+        self, partitioners: Optional[List[Partitioner]]
+    ) -> "LLMEdgeManager":
+        if partitioners is None:
+            logging.info("No partitioner provided, skipping backend lowering...")
+        edge_config = self._get_edge_config()
+        self.edge_manager = to_edge_transform_and_lower(
+            self.pre_autograd_exported_program,
+            partitioner=partitioners,
+            compile_config=edge_config,
+        )
+        return self
+
     def to_executorch(
         self, passes: Optional[List[ExportPass]] = None
     ) -> "LLMEdgeManager":
