Skip to content

Commit 3072890

Browse files
Arm backend: Convert asserts to raise errors in op_bmm (#10424)
Asserts are converted to proper raises to ensure graph integrity. Signed-off-by: Sebastian Larsson <[email protected]>
1 parent 9efb909 commit 3072890

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

backends/arm/operators/op_bmm.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,18 +37,24 @@ def define_node(
3737
inputs: List[TosaArg],
3838
output: TosaArg,
3939
) -> None:
40+
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
41+
raise TypeError(
42+
f"All IO needs to have the same data type, got: "
43+
f"{inputs[0].dtype=}, {inputs[1].dtype=} and {output.dtype=}"
44+
)
4045

41-
assert inputs[0].dtype == inputs[1].dtype, "Both inputs must be of same type"
42-
assert inputs[0].dtype in [
43-
ts.DType.INT8,
44-
ts.DType.FP32,
45-
], "Only int8 and float32 supported"
46-
# aten.bmm maps directly to MATMUL
4746
# NOTE: For now, only INT8 & FP32 is supported
47+
supported_dtypes = [ts.DType.INT8, ts.DType.FP32]
48+
for input in inputs:
49+
if input.dtype not in supported_dtypes:
50+
raise TypeError(
51+
f'IO data type needs to be {supported_dtypes}, got "{input.dtype}"'
52+
)
53+
54+
# aten.bmm maps directly to MATMUL
4855

4956
# For INT8, we need to get the zero points and add an intermediate tensor
5057
# for a later rescale.
51-
5258
if inputs[0].dtype == ts.DType.INT8:
5359
input_qparams = get_input_qparams(node)
5460
input0_zp = input_qparams[0].zp

0 commit comments

Comments
 (0)