 import numpy as np
-from hls4ml.model.hls_layers import FixedPrecisionType
+from hls4ml.model.hls_layers import FixedPrecisionType, Constant
 from hls4ml.converters.onnx.quantizer import QuantNodeQuantizer
 from hls4ml.model.optimizer import OptimizerPass
+from hls4ml.model.optimizer.passes.qkeras import ApplyAlpha


 class QuantConstantParameters(OptimizerPass):
     """ Remove Constant from the Quant node parameters (but not input[0]) """
@@ -45,29 +46,90 @@ def transform(self, model, node):
         return True

-
-class QuantToBatchNorm(OptimizerPass):
-    """ Change Quant node to BatchNormalization input[0]"""
+class QuantFactorizeScale(OptimizerPass):
+    '''
+    Extract scale and zero-point from Quant Node
+    '''
     def match(self, node):
+        # only matches after the other inputs are already folded
+
         is_match = (node.__class__.__name__ == 'Quant'
                     and not node.get_input_node(node.inputs[1])
                     and not node.get_input_node(node.inputs[2])
                     and not node.get_input_node(node.inputs[3]))
+
+        # Only match if the scale is not 1s and the zero-point is not 0s
+        if is_match and node.get_input_variable() is not None:  # to make sure this is a quant node with inputs
+            input_shape = node.get_input_variable().shape
+            scale = np.broadcast_to(1 / node.get_attr("scale"), input_shape)
+            bias = np.broadcast_to(node.get_attr("zeropt"), input_shape)
+            is_match = is_match and (scale != np.ones_like(scale)).any()
+            is_match = is_match and (bias != np.zeros_like(bias)).any()
+        return is_match

+    def transform(self, model, node):
+        '''
+        Insert an ApplyAlpha layer to factorize the scales
+        '''
+        input_shape = node.get_input_variable().shape
+
+        scale = np.broadcast_to(1 / node.get_attr('scale'), input_shape)
+        bias = np.broadcast_to(node.get_attr('zeropt'), input_shape)
+        # Unset the scale and zero-point so we don't try to factorize again
+        node.set_attr('scale', 1)
+        node.set_attr('zeropt', 0)
+
+        # TODO derive these
+        scale_precision = FixedPrecisionType()
+        scale_quantizer = QuantNodeQuantizer(scale_precision)
+        bias_precision = FixedPrecisionType()
+
+        attrs = {
+            'name': node.get_attr('name') + '_alpha',
+            'class_name': 'Alpha',
+            'inputs': node.outputs,
+            'n_in': node.get_attr('n_out'),
+            'n_filt': node.get_attr('n_filt', -1),
+            'reuse_factor': node.get_attr('reuse_factor'),
+            'bias_t': bias_precision,
+            'scale_t': scale_precision,
+            'Trace': node.get_attr('Trace', False)
+        }
+        alpha_layer = model.make_node('ApplyAlpha', node.name + '_alpha', attrs, node.outputs)
+
+        alpha_layer.add_weights(scale, quantizer=scale_quantizer)
+        alpha_layer.add_bias(bias, quantizer=None)
+        model.insert_node(alpha_layer)
+
+        return True
+
+class QuantToActivation(OptimizerPass):
+    ''' Change Quant node to Activation input[0]'''
+    def match(self, node):
         # only matches after the other inputs are already folded
+        is_match = (node.__class__.__name__ == 'Quant'
+                    and not isinstance(node.get_input_node(), Constant)
+                    and not node.get_input_node(node.inputs[1])
+                    and not node.get_input_node(node.inputs[2])
+                    and not node.get_input_node(node.inputs[3]))
+
+        # Only match if the scale is 1s and the zero-point is 0s
+        if is_match:  # to make sure this is a quant node with inputs
+            input_shape = node.get_input_variable().shape
+            scale = np.broadcast_to(1 / node.get_attr("scale"), input_shape)
+            bias = np.broadcast_to(node.get_attr("zeropt"), input_shape)
+            is_match = is_match and (scale == np.ones_like(scale)).all()
+            is_match = is_match and (bias == np.zeros_like(bias)).all()
         return is_match

     def transform(self, model, node):
-        """
-        Change quant node to BatchNormalization
-        """
+        '''
+        Change quant node to Activation
+        '''
         input_shape = node.get_input_variable().shape

         n_in = np.prod(input_shape)

-        bn_scale = np.broadcast_to(1 / node.get_attr("scale"), input_shape)
-        bn_bias = np.broadcast_to(node.get_attr("zeropt"), input_shape)
-
         rounding_mode = node.get_attr("rounding_mode")
         if rounding_mode == "ROUND":
             bn_round = "AP_RND_CONV"
@@ -89,25 +151,52 @@ def transform(self, model, node):
             raise RuntimeError("Only scalar bitwidth values are supported by the Quant node")
         bitwidth = int(bitwidth)

-        bn_precision = FixedPrecisionType(bitwidth, bitwidth, node.get_attr("signed"), bn_round, bn_sat)
-        bn_quantizer = QuantNodeQuantizer(bn_precision)
+        precision = FixedPrecisionType(bitwidth, bitwidth, node.get_attr("signed"), bn_round, bn_sat)
+        quantizer = QuantNodeQuantizer(precision)

         attributes = {
-            "simple": True,
-            "scale": bn_scale,
-            "bias": bn_bias,
-            "quant_precision": bn_precision,
-            "quantizer": bn_quantizer,
-            "scale_precision": node.get_attr("scale_precision"),
-            "bias_precision": node.get_attr("bias_precision"),
-            "n_in": n_in,
-            "n_out": n_in,
-            "n_filt": -1
+            'activation': 'linear',
+            'precision': precision,
+            'n_in': n_in,
+            'n_out': n_in,
+            'n_filt': -1
         }

-        bn_layer = model.make_node("BatchNormalization", f"bn_{node.name}",
-                                   attributes,
-                                   [node.inputs[0]], node.outputs)
-        model.replace_node(node, bn_layer)
+        new_node = model.make_node('Activation', f'{node.name}_act',
+                                   attributes, [node.inputs[0]], node.outputs)
+        new_node.get_output_variable().type.precision = precision
+        model.replace_node(node, new_node)

         return True
+
+class QuantToConstant(OptimizerPass):
+    '''
+    Remove a Quant node that is quantizing a constant.
+    Update the attributes of the constant according to the quantization.
+    '''
+
+    def match(self, node):
+        is_match = (node.__class__.__name__ == 'Quant'
+                    and isinstance(node.get_input_node(node.inputs[0]), Constant))
+        return is_match
+
+    def transform(self, model, node):
+        const_node = node.get_input_node(node.inputs[0])
+
+        new_val = const_node.value * node.get_attr('scale') + node.get_attr('zeropt')
+        quantizer = node.get_attr('quantizer')  # None if not defined
+        if quantizer:
+            const_node.set_attr('quantizer', quantizer)
+        const_node.set_attr('value', new_val)
+
+        quant_precision = node.get_attr('quant_precision')
+        if quant_precision:
+            const_node.set_attr('quant_precision', quant_precision)
+
+        # reinitialize (which also runs quantization if quantizer exists)
+        const_node.initialize()
+
+        # remove the Quant node
+        model.remove_node(node, rewire=True)
+
+        return True
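
All of these passes follow hls4ml's OptimizerPass interface: match() selects a node in the graph and transform() rewrites the model, returning True when the graph changed. A minimal sketch of how the new passes could be wired into the optimizer is shown below; it assumes the register_pass helper exported by hls4ml.model.optimizer (alongside OptimizerPass), and the pass names are placeholders chosen for illustration, not names defined in this diff.

    from hls4ml.model.optimizer import register_pass  # assumed helper from the same module as OptimizerPass

    # The match() guards above encode the intended sequence: the Quant node's
    # scale/zeropt/bitwidth constant inputs are folded into attributes first,
    # after which the node is either folded into its Constant input, replaced
    # by a linear Activation (unit scale, zero offset), or has its scale
    # factorized out into an ApplyAlpha layer.
    register_pass('quant_constant_parameters', QuantConstantParameters)
    register_pass('quant_to_constant', QuantToConstant)
    register_pass('quant_to_activation', QuantToActivation)
    register_pass('quant_factorize_scale', QuantFactorizeScale)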