From b3d2e87adf92cc6e78ebc6db69c8b2195f7ca5f7 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Tue, 1 Apr 2025 16:23:00 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- exir/backend/backend_details.py | 72 ++++++++++++++++++++++++++++----- 1 file changed, 63 insertions(+), 9 deletions(-) diff --git a/exir/backend/backend_details.py b/exir/backend/backend_details.py index 248d03f2b05..513ae7c64b3 100644 --- a/exir/backend/backend_details.py +++ b/exir/backend/backend_details.py @@ -50,15 +50,6 @@ class BackendDetails(ABC): the decorators, this interface will be static, abstract and all inheritances are enforced to implement this method. - Args: - edge_program: The original exported program. It will not be modified in place. - compile_specs: List of values needed for compilation - - Returns: - PreprocessResult: It wraps the following information: - processed_bytes -> bytes: A compiled blob - a binary that can run the desired program in the backend. - debug_handle_map (Optional[Dict[int, Tuple[int]]]): For profiling purposes, a map from the node_id in the final graph (either EXIR or the user's self-defined IR) - to debug handle id attached in the original exported program. """ @staticmethod @@ -70,6 +61,69 @@ def preprocess( edge_program: ExportedProgram, compile_specs: List[CompileSpec], ) -> PreprocessResult: + """ + Preprocesses an edge program and returns the preprocess result fo the given backend + + Args: + edge_program: The original exported program. It will not be modified in place. + compile_specs: List of values needed for compilation + + Returns: + PreprocessResult: It wraps the following information: + processed_bytes -> bytes: A compiled blob - a binary that can run the desired + program in the backend. + debug_handle_map (Optional[Dict[int, Tuple[int]]]): For profiling purposes, a + map from the node_id in the final graph (either EXIR or the user's self-defined + IR) to debug handle id attached in the original exported program. + """ # Users should return a compiled blob - a binary that can run the desired # program in the backend. pass + + @classmethod + def preprocess_multimethod( + cls, + edge_programs: Dict[str, List[ExportedProgram]], + compile_specs: Dict[str, List[List[CompileSpec]]], + ) -> Dict[str, list[PreprocessResult]]: + """ + Runs preprocess on all partitioned Edge Programs across multiple methods. This allows + backends to share information across partitioned graphs. Backend can serialize shared + data by putting the shared data into the data_store_output of the preprocess results. + This will record the shared data used by that specific partition. + + Default implementation is running the existing preprocess implementation on all + + Args: + edge_programs: Dictionary mapping the method name to a list of all the partitioned + edge_programs from that method to be lowered. + compile_specs: Dictionary mapping the method name to a list of compile_specs. The + list of compile specs maps directly to the list of edge_programs for the + same given method name i.e. edge_program[method_name][i] --> compile_specs[method_name][i] + + Returns: + Dictionary mapping the method name to a list of PreprocessResults. The list of + PreprocessResults maps directly to the list of edge_programs for the same given + method name. i.e. edge_program[method_name][i] --> result[method_name][i] + + + """ + preprocess_results = {} + for method_name, programs in edge_programs.items(): + assert ( + method_name in compile_specs + ), f"Error: missing compile specs for {method_name}" + compile_specs_for_method = compile_specs[method_name] + assert len(compile_specs_for_method) == len( + programs + ), f"Error: method {method_name} has {len(programs)} partitions but only {len(compile_specs_for_method)}" + results_for_method = [] + for program, compile_spec_for_program in zip( + programs, compile_specs_for_method + ): + preprocess_result = cls.preprocess(program, compile_spec_for_program) + results_for_method.append(preprocess_result) + + preprocess_results[method_name] = results_for_method + + return preprocess_results