From affe92dda354a413615646e8828ee91964427259 Mon Sep 17 00:00:00 2001
From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com>
Date: Mon, 10 Mar 2025 18:17:49 -0700
Subject: [PATCH 1/2] Fix pre-autograd transforms not getting persisted during
 xnnpack export

---
 extension/llm/export/builder.py | 49 ++++++++++++++++++++++++++-------
 1 file changed, 39 insertions(+), 10 deletions(-)

diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index eb8dd462378..6ac577bae59 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -89,7 +89,9 @@ def __init__(
         dynamic_shapes: Optional[Any] = None,
     ):
         self.model = model
-        self.pre_autograd_exported_program: Optional[ExportedProgram] = None
+        self.exported_program: Optional[ExportedProgram] = None
+        # self.exported_program's pre-autograd graph module, for running
+        # transform passes on the graph prior to torch.export().
         self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
         self.modelname = modelname
         self.max_seq_len = max_seq_len
@@ -184,7 +186,21 @@ def _get_edge_config(self) -> EdgeCompileConfig:
         )
         return edge_config
 
-    def export(self) -> "LLMEdgeManager":
+    def export(self, module: Optional[torch.nn.Module] = None) -> "LLMEdgeManager":
+        """
+        Exports the model pre-autograd. This is not a full export, since it uses
+        torch.export_for_training() to keep autograd-safe ops from getting decomposed.
+        The full torch.export() is called later on during to_edge() or
+        to_edge_transform_and_lower().
+
+        The optional `module` argument is included so that the user can re-export
+        an already-exported module's ExportedProgram's graph module, to persist
+        the changes into a new ExportedProgram.
+
+        Args:
+            module (Optional[torch.nn.Module]): module to export.
+
+        """
         dynamic_shape = self._get_dynamic_shape()
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
@@ -201,25 +217,30 @@ def export(self) -> "LLMEdgeManager":
             # TODO: this is temporary and export_for_training doesn't work with qnn either. We need a
             # functional graph. See issue https://github.com/pytorch/executorch/pull/4627 for more details
             exported_module = torch.export.export(
-                self.model,
+                self.model if not module else module,
                 self.example_inputs,
                 self.example_kwarg_inputs,
                 dynamic_shapes=dynamic_shape,
                 strict=True,
             )
         else:
-            logging.info("Exporting with:")
+            if module:
+                logging.info("Re-exporting with:")
+            else:
+                logging.info("Exporting with:")
             logging.info(f"inputs: {self.example_inputs}")
             logging.info(f"kwargs: {self.example_kwarg_inputs}")
             logging.info(f"dynamic shapes: {dynamic_shape}")
             exported_module = export_for_training(
-                self.model,
+                self.model if not module else module,
                 self.example_inputs,
                 kwargs=self.example_kwarg_inputs,
                 dynamic_shapes=dynamic_shape,
             )
-            # `Module`.
-            self.pre_autograd_exported_program = exported_module
+            self.exported_program = exported_module
+            # Need to store the graph module to record transformation passes.
+            # Persisting those changes back to the ExportedProgram will require
+            # an additional export().
             self.pre_autograd_graph_module = exported_module.module()
             if hasattr(self.args, "export_only") and self.args.export_only:
                 torch.export.save(exported_module, self.args.output_name)
@@ -382,7 +403,7 @@ def export_to_edge(self) -> "LLMEdgeManager":
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
         with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-            if self.pre_autograd_graph_module is None:
+            if self.exported_program is None:
                 # Run export() if it didn't run
                 self.export()
 
@@ -394,9 +415,12 @@ def export_to_edge(self) -> "LLMEdgeManager":
                 return_value=False,
             )
 
+            # Prior to export, persist the changes to the pre-autograd
+            # graph module back to the source-of-truth ExportedProgram.
+            self.export(self.pre_autograd_graph_module)
             with override_export_behaviour:
                 self.edge_manager = export_to_edge(
-                    self.pre_autograd_graph_module,  # pyre-fixme[6]
+                    self.exported_program.module(),  # pyre-fixme[6]
                     self.example_inputs,
                     example_kwarg_inputs=self.example_kwarg_inputs,
                     dynamic_shapes=dynamic_shape,
@@ -441,9 +465,14 @@ def to_edge_transform_and_lower(
     ) -> "LLMEdgeManager":
         if partitioners is None:
             logging.info("No partitioner provided, skipping backend lowering...")
+
+        # Prior to export, persist the changes to the pre-autograd
+        # graph module back to the source-of-truth ExportedProgram.
+        self.export(self.pre_autograd_graph_module)
+
         edge_config = self._get_edge_config()
         self.edge_manager = to_edge_transform_and_lower(
-            self.pre_autograd_exported_program,
+            self.exported_program,
             partitioner=partitioners,
             compile_config=edge_config,
             constant_methods=self.metadata,

From 9978148b8779eea8ec15d7a0d7120b2f505e75e3 Mon Sep 17 00:00:00 2001
From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com>
Date: Tue, 11 Mar 2025 11:16:25 -0700
Subject: [PATCH 2/2] Graph module as SOT

---
 extension/llm/export/builder.py | 59 ++++++++++++++------------------
 1 file changed, 25 insertions(+), 34 deletions(-)

diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index 6ac577bae59..9bca126f027 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -89,9 +89,10 @@ def __init__(
         dynamic_shapes: Optional[Any] = None,
     ):
         self.model = model
-        self.exported_program: Optional[ExportedProgram] = None
-        # self.exported_program's pre-autograd graph module, for running
-        # transform passes on the graph prior to torch.export().
+        # Note: treat this as the source of truth for the result of
+        # torch.export'ing a model. If the overall ExportedProgram is needed,
+        # make sure to re-export this graph module to persist any changes. See
+        # https://github.com/pytorch/pytorch/blob/main/torch/export/exported_program.py#L921
         self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
         self.modelname = modelname
         self.max_seq_len = max_seq_len
@@ -186,21 +187,7 @@ def _get_edge_config(self) -> EdgeCompileConfig:
         )
         return edge_config
 
-    def export(self, module: Optional[torch.nn.Module] = None) -> "LLMEdgeManager":
-        """
-        Exports the model pre-autograd. This is not a full export, since it uses
-        torch.export_for_training() to keep autograd-safe ops from getting decomposed.
-        The full torch.export() is called later on during to_edge() or
-        to_edge_transform_and_lower().
-
-        The optional `module` argument is included so that the user can re-export
-        an already-exported module's ExportedProgram's graph module, to persist
-        the changes into a new ExportedProgram.
-
-        Args:
-            module (Optional[torch.nn.Module]): module to export.
-
-        """
+    def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
         dynamic_shape = self._get_dynamic_shape()
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
@@ -237,14 +224,22 @@ def export(self, module: Optional[torch.nn.Module] = None) -> "LLMEdgeManager":
                 kwargs=self.example_kwarg_inputs,
                 dynamic_shapes=dynamic_shape,
             )
-            self.exported_program = exported_module
-            # Need to store the graph module to record transformation passes.
-            # Persisting those changes back to the ExportedProgram will require
-            # an additional export().
-            self.pre_autograd_graph_module = exported_module.module()
-            if hasattr(self.args, "export_only") and self.args.export_only:
-                torch.export.save(exported_module, self.args.output_name)
+        return exported_module
+
+    def export(self) -> "LLMEdgeManager":
+        """
+        Exports the model pre-autograd. This is not a full export, since it uses
+        torch.export_for_training() to keep autograd-safe ops from getting decomposed.
+        The full torch.export() is called later on during to_edge() or
+        to_edge_transform_and_lower().
+        """
+        exported_module = self._export()
+        # Need to store the graph module to record transformation passes.
+        # Persisting those changes back to an ExportedProgram will require
+        # an additional export().
+        self.pre_autograd_graph_module = exported_module.module()
+        if hasattr(self.args, "export_only") and self.args.export_only:
+            torch.export.save(exported_module, self.args.output_name)
 
         return self
 
     def run_canonical_optimizations(self):
@@ -403,7 +398,7 @@ def export_to_edge(self) -> "LLMEdgeManager":
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
         with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-            if self.exported_program is None:
+            if self.pre_autograd_graph_module is None:
                 # Run export() if it didn't run
                 self.export()
 
@@ -415,12 +410,9 @@ def export_to_edge(self) -> "LLMEdgeManager":
                 return_value=False,
             )
 
-            # Prior to export, persist the changes to the pre-autograd
-            # graph module back to the source-of-truth ExportedProgram.
-            self.export(self.pre_autograd_graph_module)
             with override_export_behaviour:
                 self.edge_manager = export_to_edge(
-                    self.exported_program.module(),  # pyre-fixme[6]
+                    self.pre_autograd_graph_module,  # pyre-fixme[6]
                     self.example_inputs,
                     example_kwarg_inputs=self.example_kwarg_inputs,
                     dynamic_shapes=dynamic_shape,
@@ -466,13 +458,12 @@ def to_edge_transform_and_lower(
     ) -> "LLMEdgeManager":
         if partitioners is None:
             logging.info("No partitioner provided, skipping backend lowering...")
 
-        # Prior to export, persist the changes to the pre-autograd
-        # graph module back to the source-of-truth ExportedProgram.
-        self.export(self.pre_autograd_graph_module)
+        # Need to construct ExportedProgram with the new transformed graph module.
+        exported_module = self._export(self.pre_autograd_graph_module)
 
         edge_config = self._get_edge_config()
         self.edge_manager = to_edge_transform_and_lower(
-            self.exported_program,
+            exported_module,
            partitioner=partitioners,
             compile_config=edge_config,
             constant_methods=self.metadata,
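
Note (not part of the patch series): the sketch below illustrates the torch.export behavior the second commit relies on, namely that edits made to ExportedProgram.module() do not flow back into the original ExportedProgram until the modified graph module is exported again. The toy model, the relu-to-gelu pass, and all names are illustrative assumptions rather than executorch code, and it assumes a PyTorch version that provides torch.export.export_for_training.

# Illustrative sketch only; toy model and pass are assumptions, not LLMEdgeManager code.
import torch
from torch.export import export_for_training


class ToyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


example_inputs = (torch.randn(2, 4),)

# Pre-autograd export, analogous in spirit to LLMEdgeManager._export().
exported_program = export_for_training(ToyModel(), example_inputs)

# The unlifted graph module plays the role of self.pre_autograd_graph_module.
graph_module = exported_program.module()

# Run a transform pass directly on the graph module (a stand-in for the source
# transforms applied during the XNNPACK lowering flow): swap relu for gelu.
for node in graph_module.graph.nodes:
    if node.op == "call_function" and node.target == torch.ops.aten.relu.default:
        node.target = torch.ops.aten.gelu.default
graph_module.recompile()

# The original ExportedProgram still contains relu; only the graph module changed.
# Re-exporting the transformed module persists the change into a fresh
# ExportedProgram, which is what to_edge_transform_and_lower() now does via
# _export(self.pre_autograd_graph_module).
persisted_program = export_for_training(graph_module, example_inputs)
print(persisted_program.graph_module.code)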