diff --git a/codegen/tools/gen_all_oplist.py b/codegen/tools/gen_all_oplist.py index b417444f8dc..5cb93bb9153 100644 --- a/codegen/tools/gen_all_oplist.py +++ b/codegen/tools/gen_all_oplist.py @@ -47,6 +47,31 @@ def resolve_model_file_path_to_buck_target(model_file_path: str) -> str: return real_path +def _raise_if_check_prim_ops_fail(options): + + # Error out if we have more than one targets registering prim ops. + if options.DEBUG_ONLY_check_prim_ops and len(options.DEBUG_ONLY_check_prim_ops) > 1: + assert ( + options.DEBUG_ONLY_check_prim_ops[0] == "@" + ), "DEBUG_ONLY_check_prim_ops is not a valid file path, or it doesn't start with '@'. This is likely a BUCK issue." + + prim_ops_targets_file = options.DEBUG_ONLY_check_prim_ops[1:] + with open(prim_ops_targets_file, "r") as file: + prim_ops_targets = file.read().split() + if len(prim_ops_targets) > 1: + # Yellow bold: \033[33;1m + # Red bold: \033[31;1m + # Green bold: \033[32;1m + error = ( + "It seems this target is depending on more than 1 `prim_ops_registry` targets: " + + f'\033[33;1m\n{", ".join(prim_ops_targets)}\033[0m. \nThis will likely cause errors such as: ' + + "\n \033[31;1mRe-registering aten::sym_size.int...\033[0m" + + "\nTo find out the dependency chain, run the following command: " + + f'\n \033[32;1mbuck2 cquery "allpaths(, {prim_ops_targets[0]})"\033[0m' + ) + raise Exception(error) + + def main(argv: List[Any]) -> None: """This binary generates 3 files: @@ -95,8 +120,18 @@ def main(argv: List[Any]) -> None: default=False, required=False, ) + parser.add_argument( + "--DEBUG-ONLY-check-prim-ops", + "--DEBUG_ONLY_check_prim_ops", + help=( + "Useful argument to take BUCK targets that registers prim ops and error out if we have more than 1." + ), + required=False, + ) options = parser.parse_args(argv) + _raise_if_check_prim_ops_fail(options) + # Check if the build has any dependency on any selective build target. If we have a target, BUCK shold give us either: # 1. a yaml file containing selected ops (could be empty), or # 2. a non-empty list of yaml files in the `model_file_list_path` or @@ -153,14 +188,17 @@ def main(argv: List[Any]) -> None: debug_info_2 = ",".join( model_dict["operators"][op_name]["debug_info"] ) - error = f"Operator {op_name} is used in 2 models: {debug_info_1} and {debug_info_2}" + # Yellow bold: \033[33;1m + # Red bold: \033[31;1m + # Green bold: \033[32;1m + error = f"\033[31;1mOperator {op_name} is used in 2 models: \033[33;1m{debug_info_1} and {debug_info_2}\033[0m" if "//" not in debug_info_1 and "//" not in debug_info_2: error += "\nWe can't determine what BUCK targets these model files belong to." tail = "." else: error += "\nPlease run the following commands to find out where is the BUCK target being added as a dependency to your target:\n" - error += f'\n buck2 cquery "allpaths(, {debug_info_1})"' - error += f'\n buck2 cquery "allpaths(, {debug_info_2})"' + error += f'\n \033[32;1mbuck2 cquery "allpaths(, {debug_info_1})"\033[0m' + error += f'\n \033[32;1mbuck2 cquery "allpaths(, {debug_info_2})"\033[0m' tail = "as well as results from BUCK commands listed above." error += ( diff --git a/shim/xplat/executorch/codegen/codegen.bzl b/shim/xplat/executorch/codegen/codegen.bzl index 8e0e89eda57..4b69a2cf4a0 100644 --- a/shim/xplat/executorch/codegen/codegen.bzl +++ b/shim/xplat/executorch/codegen/codegen.bzl @@ -692,6 +692,7 @@ def executorch_ops_check( "--model_file_list_path $(@query_outputs \"filter('.*_et_oplist', deps(set({deps})))\") " + "--allow_include_all_overloads " + "--check_ops_not_overlapping " + + "--DEBUG_ONLY_check_prim_ops $(@query_targets \"filter('prim_ops_registry(?:_static|_aten)?$', deps(set({deps})))\") " + "--output_dir $OUT ").format(deps = " ".join(["\'{}\'".format(d) for d in deps])), define_static_target = False, platforms = kwargs.pop("platforms", get_default_executorch_platforms()),