From 329a92f9b6623fe260e6f60a9e5ba6843e4db3eb Mon Sep 17 00:00:00 2001 From: Sebastian Larsson Date: Thu, 3 Apr 2025 14:02:51 +0200 Subject: [PATCH] Arm backend: Convert assert to throw TypeError in op_add Asserts are converted to proper raises to ensure graph integrity. Change-Id: I6a4e9b4c6d37e8b10599e9551c437a8a18b731e3 Signed-off-by: Sebastian Larsson --- backends/arm/operators/op_add.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_add.py index cb14dcb43d8..1be4a218232 100644 --- a/backends/arm/operators/op_add.py +++ b/backends/arm/operators/op_add.py @@ -41,9 +41,18 @@ def define_node( ) -> None: # Specification (0.80) states that input and output types # should all be the same - assert inputs[0].dtype == inputs[1].dtype == output.dtype + if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: + raise TypeError( + f"All IO needs to have the same data type, got input 1: " + f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: " + f"{output.dtype}" + ) # Handle int8 (quantized) and int32 - assert inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32] + supported_dtypes = [ts.DType.INT8, ts.DType.INT32] + if inputs[0].dtype not in supported_dtypes: + raise TypeError( + f'IO data type needs to be {supported_dtypes}, got "{inputs[0].dtype}"' + ) dim_order = ( inputs[0].dim_order @@ -105,15 +114,22 @@ def define_node( ) -> None: # Specification (0.80) states that input and output types # should all be the same - assert inputs[0].dtype == inputs[1].dtype == output.dtype + if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: + raise TypeError( + f"All IO needs to have the same data type, got input 1: " + f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: " + f"{output.dtype}" + ) if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]: # Call the inherited define_node for handling integers super().define_node(node, tosa_graph, inputs, output) else: # FP32 Add lowering - assert inputs[0].dtype == ts.DType.FP32 - assert output.dtype == ts.DType.FP32 + if inputs[0].dtype != ts.DType.FP32: + raise TypeError( + f"Expected IO data type to be FP32, got {inputs[0].dtype}" + ) input1, input2 = tutils.reshape_for_broadcast(tosa_graph, inputs)