diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index eb8dd462378..9bca126f027 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -89,7 +89,10 @@ def __init__(
         dynamic_shapes: Optional[Any] = None,
     ):
         self.model = model
-        self.pre_autograd_exported_program: Optional[ExportedProgram] = None
+        # Note: treat this as the source of truth for the result of
+        # torch.export'ing a model. If the overall ExportedProgram is needed,
+        # make sure to re-export this graph module to persist any changes. See
+        # https://github.com/pytorch/pytorch/blob/main/torch/export/exported_program.py#L921
         self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
         self.modelname = modelname
         self.max_seq_len = max_seq_len
@@ -184,7 +187,7 @@ def _get_edge_config(self) -> EdgeCompileConfig:
         )
         return edge_config
 
-    def export(self) -> "LLMEdgeManager":
+    def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
         dynamic_shape = self._get_dynamic_shape()
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
@@ -201,29 +204,42 @@ def export(self) -> "LLMEdgeManager":
                 # TODO: this is temporary and export_for_training doesn't work with qnn either. We need a
                 # functional graph. See issue https://github.com/pytorch/executorch/pull/4627 for more details
                 exported_module = torch.export.export(
-                    self.model,
+                    self.model if not module else module,
                     self.example_inputs,
                     self.example_kwarg_inputs,
                     dynamic_shapes=dynamic_shape,
                     strict=True,
                 )
             else:
-                logging.info("Exporting with:")
+                if module:
+                    logging.info("Re-exporting with:")
+                else:
+                    logging.info("Exporting with:")
                 logging.info(f"inputs: {self.example_inputs}")
                 logging.info(f"kwargs: {self.example_kwarg_inputs}")
                 logging.info(f"dynamic shapes: {dynamic_shape}")
                 exported_module = export_for_training(
-                    self.model,
+                    self.model if not module else module,
                     self.example_inputs,
                     kwargs=self.example_kwarg_inputs,
                     dynamic_shapes=dynamic_shape,
                 )
-            # `Module`.
-            self.pre_autograd_exported_program = exported_module
-            self.pre_autograd_graph_module = exported_module.module()
-            if hasattr(self.args, "export_only") and self.args.export_only:
-                torch.export.save(exported_module, self.args.output_name)
+        return exported_module
 
+    def export(self) -> "LLMEdgeManager":
+        """
+        Exports the model pre-autograd. This is not a full export, since it uses
+        torch.export_for_training() to keep autograd-safe ops from getting decomposed.
+        The full torch.export() is called later on during to_edge() or
+        to_edge_transform_and_lower().
+        """
+        exported_module = self._export()
+        # Need to store the graph module to record transformation passes.
+        # Persisting those changes back to an ExportedProgram will require
+        # an additional export().
+        self.pre_autograd_graph_module = exported_module.module()
+        if hasattr(self.args, "export_only") and self.args.export_only:
+            torch.export.save(exported_module, self.args.output_name)
         return self
 
     def run_canonical_optimizations(self):
@@ -441,9 +457,13 @@ def to_edge_transform_and_lower(
     ) -> "LLMEdgeManager":
         if partitioners is None:
             logging.info("No partitioner provided, skipping backend lowering...")
+
+        # Need to construct ExportedProgram with the new transformed graph module.
+        exported_module = self._export(self.pre_autograd_graph_module)
+
         edge_config = self._get_edge_config()
         self.edge_manager = to_edge_transform_and_lower(
-            self.pre_autograd_exported_program,
+            exported_module,
            partitioner=partitioners,
             compile_config=edge_config,
             constant_methods=self.metadata,
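
Not part of the patch itself, but for reviewer context: the change relies on the fact that a graph module obtained from an ExportedProgram via .module() can be mutated by transformation passes and then re-exported to persist those changes. The sketch below is a minimal, standalone illustration of that pattern; TinyModel and the example inputs are hypothetical stand-ins, while torch.export.export_for_training, ExportedProgram.module(), and torch.export.export are the actual APIs builder.py uses.

    import torch
    from torch.export import export_for_training


    class TinyModel(torch.nn.Module):
        # Hypothetical stand-in for the model handed to LLMEdgeManager.
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(8, 8)

        def forward(self, x):
            return torch.relu(self.linear(x))


    example_inputs = (torch.randn(2, 8),)

    # Stage 1: pre-autograd capture, analogous to what export() now stores in
    # self.pre_autograd_graph_module.
    pre_autograd_ep = export_for_training(TinyModel(), example_inputs)
    graph_module = pre_autograd_ep.module()

    # ... transformation passes would mutate `graph_module` here ...

    # Stage 2: re-export the (possibly transformed) graph module so the changes
    # are persisted in a fresh ExportedProgram, analogous to _export(module=...)
    # being called from to_edge_transform_and_lower().
    full_ep = torch.export.export(graph_module, example_inputs, strict=True)
    print(type(full_ep).__name__)  # ExportedProgram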