From e15679195aae1aacf3e416140db6f771e147b2a9 Mon Sep 17 00:00:00 2001 From: Sebastian Larsson Date: Wed, 2 Apr 2025 10:41:20 +0200 Subject: [PATCH 1/2] Arm backend: Allow graph_module to be None in get_node_debug_info get_node_debug_info() without graph_module will now only print the node information. Change-Id: I12d9cc30eafc9c1fadfb50dccaf39fc3b4b5663b Signed-off-by: Sebastian Larsson --- backends/arm/tosa_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py index 54efdb1bbf2..92972182097 100644 --- a/backends/arm/tosa_utils.py +++ b/backends/arm/tosa_utils.py @@ -30,9 +30,13 @@ def dbg_node(node: torch.fx.Node, graph_module: torch.fx.GraphModule): logger.info(get_node_debug_info(node, graph_module)) -def get_node_debug_info(node: torch.fx.Node, graph_module: torch.fx.GraphModule) -> str: +def get_node_debug_info( + node: torch.fx.Node, graph_module: torch.fx.GraphModule | None = None +) -> str: output = ( f" {inspect_node(graph=graph_module.graph, node=node)}\n" + if graph_module + else "" "-- NODE DEBUG INFO --\n" f" Op is {node.op}\n" f" Name is {node.name}\n" From 4ac001466ef4d413b43561ad73162a80a294cb53 Mon Sep 17 00:00:00 2001 From: Sebastian Larsson Date: Tue, 1 Apr 2025 16:38:16 +0200 Subject: [PATCH 2/2] Arm backend: Convert assert to throw TypeError in arm_pass_utils There's a risk with using asserts in production code as it might get optimized out. TypeError is more explicit and conveys more information abour what has gone wrong. Change-Id: I2c811f9290e6beefd9a4a99d83ec9220209a56a8 Signed-off-by: Sebastian Larsson --- backends/arm/_passes/arm_pass_utils.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py index afb2d82a2fc..1e0c21239e2 100644 --- a/backends/arm/_passes/arm_pass_utils.py +++ b/backends/arm/_passes/arm_pass_utils.py @@ -13,6 +13,7 @@ import torch import torch.fx +from executorch.backends.arm.tosa_utils import get_node_debug_info from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops @@ -169,9 +170,13 @@ def get_first_fake_tensor(node: torch.fx.Node) -> FakeTensor: else: fake_tensor = node.meta["val"] - assert isinstance( - fake_tensor, FakeTensor - ), f'Found {fake_tensor} in meta["val"] of {node}, expected to find FakeTensor.' + if not isinstance(fake_tensor, FakeTensor): + raise TypeError( + f'Expected a FakeTensor in meta["val"] of node {node}, but got ' + f"{type(fake_tensor).__name__}\n" + f"{get_node_debug_info(node)}" + ) + return fake_tensor