diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py
index d2b66a34be2..93b3f85c529 100644
--- a/backends/cadence/aot/compiler.py
+++ b/backends/cadence/aot/compiler.py
@@ -178,18 +178,14 @@ def export_program(
     return expo_program
 
 
-# Export the model and lower it to an EdgeProgramManager (in edge IR).
-def export_to_edge(
-    model: torch.nn.Module,
-    inputs: tuple[object, ...],
+def lower_ep_to_edge(
+    expo_program: ExportedProgram,
     dump_graphs: bool = False,
     constant_methods: Optional[dict[str, object]] = None,
 ) -> EdgeProgramManager:
-    assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
-
-    # Export the model into an ExportedProgram.
-    expo_program = export_program(model, inputs)
-
+    """
+    Lower an ExportedProgram to an EdgeProgramManager (in edge IR).
+    """
     # Call to_edge to convert the graph to edge IR.
     # Note: dim_order is skipped (https://github.com/pytorch/executorch/issues/3704)
     edge_prog_manager = to_edge(
@@ -215,6 +211,23 @@ def export_to_edge(
     logging.info(
         edge_prog_manager.exported_program().graph_module.graph.print_tabular()
     )
+    return edge_prog_manager
+
+
+# Export the model and lower it to an EdgeProgramManager (in edge IR).
+def export_to_edge(
+    model: torch.nn.Module,
+    inputs: tuple[object, ...],
+    dump_graphs: bool = False,
+    constant_methods: Optional[dict[str, object]] = None,
+) -> EdgeProgramManager:
+    assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
+
+    # Export the model into an ExportedProgram.
+    expo_program = export_program(model, inputs)
+
+    # Lower the model to edge IR.
+    edge_prog_manager = lower_ep_to_edge(expo_program, dump_graphs, constant_methods)
     return edge_prog_manager
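
For orientation, a minimal usage sketch of the API after this refactor. The import path (executorch.backends.cadence.aot.compiler) and the toy model below are illustrative assumptions, not part of the change; the call signatures follow the diff above.

    # Hypothetical example, not part of the diff; import path and model are assumed.
    import torch
    from executorch.backends.cadence.aot.compiler import (
        export_program,
        export_to_edge,
        lower_ep_to_edge,
    )

    class AddOne(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + 1.0

    model = AddOne()
    inputs = (torch.randn(4),)

    # One-step path: export and lower in a single call (behavior unchanged).
    edge_mgr = export_to_edge(model, inputs)

    # Two-step path enabled by the refactor: reuse an already-exported program
    # instead of re-exporting the nn.Module.
    ep = export_program(model, inputs)
    edge_mgr_from_ep = lower_ep_to_edge(ep, dump_graphs=True)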