diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py index afb2d82a2fc..1e0c21239e2 100644 --- a/backends/arm/_passes/arm_pass_utils.py +++ b/backends/arm/_passes/arm_pass_utils.py @@ -13,6 +13,7 @@ import torch import torch.fx +from executorch.backends.arm.tosa_utils import get_node_debug_info from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops @@ -169,9 +170,13 @@ def get_first_fake_tensor(node: torch.fx.Node) -> FakeTensor: else: fake_tensor = node.meta["val"] - assert isinstance( - fake_tensor, FakeTensor - ), f'Found {fake_tensor} in meta["val"] of {node}, expected to find FakeTensor.' + if not isinstance(fake_tensor, FakeTensor): + raise TypeError( + f'Expected a FakeTensor in meta["val"] of node {node}, but got ' + f"{type(fake_tensor).__name__}\n" + f"{get_node_debug_info(node)}" + ) + return fake_tensor diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py index 54efdb1bbf2..92972182097 100644 --- a/backends/arm/tosa_utils.py +++ b/backends/arm/tosa_utils.py @@ -30,9 +30,13 @@ def dbg_node(node: torch.fx.Node, graph_module: torch.fx.GraphModule): logger.info(get_node_debug_info(node, graph_module)) -def get_node_debug_info(node: torch.fx.Node, graph_module: torch.fx.GraphModule) -> str: +def get_node_debug_info( + node: torch.fx.Node, graph_module: torch.fx.GraphModule | None = None +) -> str: output = ( f" {inspect_node(graph=graph_module.graph, node=node)}\n" + if graph_module + else "" "-- NODE DEBUG INFO --\n" f" Op is {node.op}\n" f" Name is {node.name}\n"