File tree Expand file tree Collapse file tree 1 file changed +13
-7
lines changed Expand file tree Collapse file tree 1 file changed +13
-7
lines changed Original file line number Diff line number Diff line change @@ -37,18 +37,24 @@ def define_node(
37
37
inputs : List [TosaArg ],
38
38
output : TosaArg ,
39
39
) -> None :
40
+ if inputs [0 ].dtype != inputs [1 ].dtype or inputs [0 ].dtype != output .dtype :
41
+ raise TypeError (
42
+ f"All IO needs to have the same data type, got: "
43
+ f"{ inputs [0 ].dtype = } , { inputs [1 ].dtype = } and { output .dtype = } "
44
+ )
40
45
41
- assert inputs [0 ].dtype == inputs [1 ].dtype , "Both inputs must be of same type"
42
- assert inputs [0 ].dtype in [
43
- ts .DType .INT8 ,
44
- ts .DType .FP32 ,
45
- ], "Only int8 and float32 supported"
46
- # aten.bmm maps directly to MATMUL
47
46
# NOTE: For now, only INT8 & FP32 is supported
47
+ supported_dtypes = [ts .DType .INT8 , ts .DType .FP32 ]
48
+ for input in inputs :
49
+ if input .dtype not in supported_dtypes :
50
+ raise TypeError (
51
+ f'IO data type needs to be { supported_dtypes } , got "{ input .dtype } "'
52
+ )
53
+
54
+ # aten.bmm maps directly to MATMUL
48
55
49
56
# For INT8, we need to get the zero points and add an intermediate tensor
50
57
# for a later rescale.
51
-
52
58
if inputs [0 ].dtype == ts .DType .INT8 :
53
59
input_qparams = get_input_qparams (node )
54
60
input0_zp = input_qparams [0 ].zp
You can’t perform that action at this time.
0 commit comments