Skip to content

Error out if registering prim ops multiple times #8172

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 41 additions & 3 deletions codegen/tools/gen_all_oplist.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,31 @@ def resolve_model_file_path_to_buck_target(model_file_path: str) -> str:
return real_path


def _raise_if_check_prim_ops_fail(options):

# Error out if we have more than one targets registering prim ops.
if options.DEBUG_ONLY_check_prim_ops and len(options.DEBUG_ONLY_check_prim_ops) > 1:
assert (
options.DEBUG_ONLY_check_prim_ops[0] == "@"
), "DEBUG_ONLY_check_prim_ops is not a valid file path, or it doesn't start with '@'. This is likely a BUCK issue."

prim_ops_targets_file = options.DEBUG_ONLY_check_prim_ops[1:]
with open(prim_ops_targets_file, "r") as file:
prim_ops_targets = file.read().split()
if len(prim_ops_targets) > 1:
# Yellow bold: \033[33;1m
# Red bold: \033[31;1m
# Green bold: \033[32;1m
error = (
"It seems this target is depending on more than 1 `prim_ops_registry` targets: "
+ f'\033[33;1m\n{", ".join(prim_ops_targets)}\033[0m. \nThis will likely cause errors such as: '
+ "\n \033[31;1mRe-registering aten::sym_size.int...\033[0m"
+ "\nTo find out the dependency chain, run the following command: "
+ f'\n \033[32;1mbuck2 cquery <mode> "allpaths(<target>, {prim_ops_targets[0]})"\033[0m'
)
raise Exception(error)


def main(argv: List[Any]) -> None:
"""This binary generates 3 files:

Expand Down Expand Up @@ -95,8 +120,18 @@ def main(argv: List[Any]) -> None:
default=False,
required=False,
)
parser.add_argument(
"--DEBUG-ONLY-check-prim-ops",
"--DEBUG_ONLY_check_prim_ops",
help=(
"Useful argument to take BUCK targets that registers prim ops and error out if we have more than 1."
),
required=False,
)
options = parser.parse_args(argv)

_raise_if_check_prim_ops_fail(options)

# Check if the build has any dependency on any selective build target. If we have a target, BUCK shold give us either:
# 1. a yaml file containing selected ops (could be empty), or
# 2. a non-empty list of yaml files in the `model_file_list_path` or
Expand Down Expand Up @@ -153,14 +188,17 @@ def main(argv: List[Any]) -> None:
debug_info_2 = ",".join(
model_dict["operators"][op_name]["debug_info"]
)
error = f"Operator {op_name} is used in 2 models: {debug_info_1} and {debug_info_2}"
# Yellow bold: \033[33;1m
# Red bold: \033[31;1m
# Green bold: \033[32;1m
error = f"\033[31;1mOperator {op_name} is used in 2 models: \033[33;1m{debug_info_1} and {debug_info_2}\033[0m"
if "//" not in debug_info_1 and "//" not in debug_info_2:
error += "\nWe can't determine what BUCK targets these model files belong to."
tail = "."
else:
error += "\nPlease run the following commands to find out where is the BUCK target being added as a dependency to your target:\n"
error += f'\n buck2 cquery <mode> "allpaths(<target>, {debug_info_1})"'
error += f'\n buck2 cquery <mode> "allpaths(<target>, {debug_info_2})"'
error += f'\n \033[32;1mbuck2 cquery <mode> "allpaths(<target>, {debug_info_1})"\033[0m'
error += f'\n \033[32;1mbuck2 cquery <mode> "allpaths(<target>, {debug_info_2})"\033[0m'
tail = "as well as results from BUCK commands listed above."

error += (
Expand Down
1 change: 1 addition & 0 deletions shim/xplat/executorch/codegen/codegen.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,7 @@ def executorch_ops_check(
"--model_file_list_path $(@query_outputs \"filter('.*_et_oplist', deps(set({deps})))\") " +
"--allow_include_all_overloads " +
"--check_ops_not_overlapping " +
"--DEBUG_ONLY_check_prim_ops $(@query_targets \"filter('prim_ops_registry(?:_static|_aten)?$', deps(set({deps})))\") " +
"--output_dir $OUT ").format(deps = " ".join(["\'{}\'".format(d) for d in deps])),
define_static_target = False,
platforms = kwargs.pop("platforms", get_default_executorch_platforms()),
Expand Down
Loading