"""
Excerpt from https://github.com/jkiesele/caloGraphNN/blob/6d1127d807bc0dbaefcf1ed804d626272f002404/caloGraphNN_keras.py
"""

import tensorflow.keras as keras
K = keras.backend

try:
    from qkeras import QDense, ternary, QActivation

    # qkeras counterpart of NamedDense below: propagate the layer name into saved weights
    class NamedQDense(QDense):
        def add_weight(self, name=None, **kwargs):
            return super(NamedQDense, self).add_weight(name='%s_%s' % (self.name, name), **kwargs)

    def ternary_1_05():
        return ternary(alpha=1., threshold=0.5)

except ImportError:
    # qkeras is optional; it is only needed when quantize_transforms=True
    pass

# Hack keras Dense to propagate the layer name into saved weights
class NamedDense(keras.layers.Dense):
    def add_weight(self, name=None, **kwargs):
        return super(NamedDense, self).add_weight(name='%s_%s' % (self.name, name), **kwargs)
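
# Illustrative effect (assumption, not from the source): a NamedDense(8, name='S')
# stores its weights as 'S_kernel'/'S_bias' rather than the generic 'kernel'/'bias',
# so they stay identifiable in a flat weight file.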

class GarNet(keras.layers.Layer):
    def __init__(self, n_aggregators, n_filters, n_propagate,
                 simplified=False,
                 collapse=None,
                 input_format='xn',
                 output_activation='tanh',
                 mean_by_nvert=False,
                 quantize_transforms=False,
                 total_bits=None,
                 int_bits=None,
                 **kwargs):
        super(GarNet, self).__init__(**kwargs)

        self._simplified = simplified
        self._output_activation = output_activation
        self._quantize_transforms = quantize_transforms
        self._total_bits = total_bits
        self._int_bits = int_bits
        self._setup_aux_params(collapse, input_format, mean_by_nvert)
        self._setup_transforms(n_aggregators, n_filters, n_propagate)

    def _setup_aux_params(self, collapse, input_format, mean_by_nvert):
        if collapse is None:
            self._collapse = None
        elif collapse in ['mean', 'sum', 'max']:
            self._collapse = collapse
        else:
            raise NotImplementedError('Unsupported collapse operation: %s' % collapse)

        self._input_format = input_format
        self._mean_by_nvert = mean_by_nvert

    def _setup_transforms(self, n_aggregators, n_filters, n_propagate):
        if self._quantize_transforms:
            quantizer = 'quantized_bits(%i,%i,0,alpha=1)' % (self._total_bits, self._int_bits)
            self._input_feature_transform = NamedQDense(n_propagate,
                                                        kernel_quantizer=quantizer,
                                                        bias_quantizer=quantizer,
                                                        name='FLR')
            self._output_feature_transform = NamedQDense(n_filters,
                                                         kernel_quantizer=quantizer,
                                                         name='Fout')
            if self._output_activation is None or self._output_activation == 'linear':
                self._output_activation_transform = QActivation('quantized_bits(%i, %i)' % (self._total_bits, self._int_bits))
            else:
                self._output_activation_transform = QActivation('quantized_%s(%i, %i)' % (self._output_activation, self._total_bits, self._int_bits))
        else:
            self._input_feature_transform = NamedDense(n_propagate, name='FLR')
            # keep Fout itself linear; the activation is applied exactly once,
            # through _output_activation_transform in _garnet
            self._output_feature_transform = NamedDense(n_filters, name='Fout')
            self._output_activation_transform = keras.layers.Activation(self._output_activation)

        self._aggregator_distance = NamedDense(n_aggregators, name='S')

        self._sublayers = [self._input_feature_transform, self._aggregator_distance,
                           self._output_feature_transform, self._output_activation_transform]

    def build(self, input_shape):
        super(GarNet, self).build(input_shape)

        if self._input_format == 'x':
            data_shape = input_shape
        elif self._input_format == 'xn':
            data_shape, _ = input_shape
        elif self._input_format == 'xen':
            data_shape, _, _ = input_shape
            data_shape = data_shape[:2] + (data_shape[2] + 1,)

        self._build_transforms(data_shape)

        for layer in self._sublayers:
            self._trainable_weights.extend(layer.trainable_weights)
            self._non_trainable_weights.extend(layer.non_trainable_weights)

    def _build_transforms(self, data_shape):
        self._input_feature_transform.build(data_shape)
        self._aggregator_distance.build(data_shape)
        if self._simplified:
            # Fout sees the aggregated features sent back to each vertex (S * P)
            fout_shape = data_shape[:2] + (self._aggregator_distance.units * self._input_feature_transform.units,)
        else:
            # Fout sees raw features (F) + max- and mean-aggregated features incl. edge weights (2 * S * (P + S)) + edge weights (S)
            fout_shape = data_shape[:2] + (data_shape[2] + 2 * self._aggregator_distance.units * (self._input_feature_transform.units + self._aggregator_distance.units) + self._aggregator_distance.units,)
        self._output_feature_transform.build(fout_shape)
        self._output_activation_transform.build(fout_shape[:2] + (self._output_feature_transform.units,))

    def call(self, x):
        data, num_vertex, vertex_mask = self._unpack_input(x)

        output = self._garnet(data, num_vertex, vertex_mask,
                              self._input_feature_transform,
                              self._aggregator_distance,
                              self._output_feature_transform,
                              self._output_activation_transform)

        output = self._collapse_output(output, num_vertex)

        return output

    def _unpack_input(self, x):
        if self._input_format == 'x':
            data = x

            # no explicit vertex count given: infer the mask from the fourth
            # feature column, assumed nonzero for real vertices
            vertex_mask = K.cast(K.not_equal(data[..., 3:4], 0.), 'float32')
            num_vertex = K.sum(vertex_mask)

        elif self._input_format in ['xn', 'xen']:
            if self._input_format == 'xn':
                data, num_vertex = x
            else:
                data_x, data_e, num_vertex = x
                data = K.concatenate((data_x, K.reshape(data_e, (-1, data_e.shape[1], 1))), axis=-1)

            data_shape = K.shape(data)
            B = data_shape[0]
            V = data_shape[1]
            vertex_indices = K.tile(K.expand_dims(K.arange(0, V), axis=0), (B, 1))  # (B, [0..V-1])
            vertex_mask = K.expand_dims(K.cast(K.less(vertex_indices, K.cast(num_vertex, 'int32')), 'float32'), axis=-1)  # (B, V, 1)
            num_vertex = K.cast(num_vertex, 'float32')

        return data, num_vertex, vertex_mask

    def _garnet(self, data, num_vertex, vertex_mask, in_transform, d_compute, out_transform, act_transform):
        features = in_transform(data)  # (B, V, F)
        distance = d_compute(data)  # (B, V, S)

        edge_weights = vertex_mask * K.exp(-K.square(distance))  # (B, V, S)

        if not self._simplified:
            features = K.concatenate([vertex_mask * features, edge_weights], axis=-1)

        if self._mean_by_nvert:
            def graph_mean(out, axis):
                s = K.sum(out, axis=axis)  # (B, S, F)
                # divide by the true vertex count; reshape num_vertex to (B, 1, 1)
                # to enable broadcasting
                return s / K.reshape(num_vertex, (-1, 1, 1))
        else:
            graph_mean = K.mean

        # vertices -> aggregators
        edge_weights_trans = K.permute_dimensions(edge_weights, (0, 2, 1))  # (B, S, V)

        aggregated_mean = self._apply_edge_weights(features, edge_weights_trans, aggregation=graph_mean)  # (B, S, F)

        if self._simplified:
            aggregated = aggregated_mean
        else:
            aggregated_max = self._apply_edge_weights(features, edge_weights_trans, aggregation=K.max)
            aggregated = K.concatenate([aggregated_max, aggregated_mean], axis=-1)

        # aggregators -> vertices
        updated_features = self._apply_edge_weights(aggregated, edge_weights)  # (B, V, S*F)

        if not self._simplified:
            updated_features = K.concatenate([data, updated_features, edge_weights], axis=-1)

        return vertex_mask * act_transform(out_transform(updated_features))

    def _collapse_output(self, output, num_vertex):
        if self._collapse == 'mean':
            if self._mean_by_nvert:
                output = K.sum(output, axis=1) / num_vertex
            else:
                output = K.mean(output, axis=1)
        elif self._collapse == 'sum':
            output = K.sum(output, axis=1)
        elif self._collapse == 'max':
            output = K.max(output, axis=1)

        return output

    def compute_output_shape(self, input_shape):
        # use Fout for the shape: the activation transform has no units attribute
        return self._get_output_shape(input_shape, self._output_feature_transform)

    def _get_output_shape(self, input_shape, out_transform):
        if self._input_format == 'x':
            data_shape = input_shape
        elif self._input_format == 'xn':
            data_shape, _ = input_shape
        elif self._input_format == 'xen':
            data_shape, _, _ = input_shape

        if self._collapse is None:
            return data_shape[:2] + (out_transform.units,)
        else:
            return (data_shape[0], out_transform.units)

    def get_config(self):
        config = super(GarNet, self).get_config()

        config.update({
            'simplified': self._simplified,
            'collapse': self._collapse,
            'input_format': self._input_format,
            'output_activation': self._output_activation,
            'quantize_transforms': self._quantize_transforms,
            'total_bits': self._total_bits,
            'int_bits': self._int_bits,
            'mean_by_nvert': self._mean_by_nvert
        })

        self._add_transform_config(config)

        return config

    def _add_transform_config(self, config):
        config.update({
            'n_aggregators': self._aggregator_distance.units,
            'n_filters': self._output_feature_transform.units,
            'n_propagate': self._input_feature_transform.units
        })

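    # Shape walkthrough for _apply_edge_weights (B = batch, V = vertices,
    # S = aggregators, F = features per entry; notation is illustrative):
    #   vertices -> aggregators: features (B, V, F), edge_weights (B, S, V)
    #     -> broadcast product (B, S, V, F) -> aggregate over V -> (B, S, F)
    #   aggregators -> vertices: features (B, S, F), edge_weights (B, V, S)
    #     -> broadcast product (B, V, S, F) -> flatten -> (B, V, S*F)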
    @staticmethod
    def _apply_edge_weights(features, edge_weights, aggregation=None):
        features = K.expand_dims(features, axis=1)  # (B, 1, v, f)
        edge_weights = K.expand_dims(edge_weights, axis=3)  # (B, u, v, 1)

        out = edge_weights * features  # (B, u, v, f)

        if aggregation:
            out = aggregation(out, axis=2)  # (B, u, f)
        else:
            try:
                out = K.reshape(out, (-1, edge_weights.shape[1].value, features.shape[-1].value * features.shape[-2].value))
            except AttributeError:  # TF 2: shape entries are plain ints, not Dimension objects
                out = K.reshape(out, (-1, edge_weights.shape[1], features.shape[-1] * features.shape[-2]))

        return out


class GarNetStack(GarNet):
    """
    Stacked version of GarNet. The first three constructor arguments must be lists
    of integers of equal length, one entry per sublayer. Stacking offers no
    performance advantage over separate GarNet layers, but it consolidates the
    configuration (useful e.g. when converting the layer to HLS).
    """

    def _setup_transforms(self, n_aggregators, n_filters, n_propagate):
        self._transform_layers = []
        # inputs are lists; build one transform block per sublayer
        for it, (p, a, f) in enumerate(zip(n_propagate, n_aggregators, n_filters)):
            if self._quantize_transforms:
                quantizer = 'quantized_bits(%i,%i,0,alpha=1)' % (self._total_bits, self._int_bits)
                input_feature_transform = NamedQDense(p,
                                                      kernel_quantizer=quantizer,
                                                      bias_quantizer=quantizer,
                                                      name=('FLR%d' % it))
                output_feature_transform = NamedQDense(f,
                                                       kernel_quantizer=quantizer,
                                                       name=('Fout%d' % it))
                if self._output_activation is None or self._output_activation == 'linear':
                    output_activation_transform = QActivation('quantized_bits(%i, %i)' % (self._total_bits, self._int_bits))
                else:
                    output_activation_transform = QActivation('quantized_%s(%i, %i)' % (self._output_activation, self._total_bits, self._int_bits))
            else:
                input_feature_transform = NamedDense(p, name=('FLR%d' % it))
                output_feature_transform = NamedDense(f, name=('Fout%d' % it))
                output_activation_transform = keras.layers.Activation(self._output_activation)

            aggregator_distance = NamedDense(a, name=('S%d' % it))

            # keep the activation transform in the tuple; call() unpacks all four
            self._transform_layers.append((input_feature_transform, aggregator_distance, output_feature_transform, output_activation_transform))

        self._sublayers = sum((list(layers) for layers in self._transform_layers), [])

    def _build_transforms(self, data_shape):
        for in_transform, d_compute, out_transform, act_transform in self._transform_layers:
            in_transform.build(data_shape)
            d_compute.build(data_shape)
            if self._simplified:
                fout_shape = data_shape[:2] + (d_compute.units * in_transform.units,)
            else:
                fout_shape = data_shape[:2] + (data_shape[2] + 2 * d_compute.units * (in_transform.units + d_compute.units) + d_compute.units,)
            out_transform.build(fout_shape)
            act_transform.build(fout_shape[:2] + (out_transform.units,))

            # the output of this sublayer is the input of the next
            data_shape = data_shape[:2] + (out_transform.units,)

    def call(self, x):
        data, num_vertex, vertex_mask = self._unpack_input(x)

        for in_transform, d_compute, out_transform, act_transform in self._transform_layers:
            data = self._garnet(data, num_vertex, vertex_mask, in_transform, d_compute, out_transform, act_transform)

        output = self._collapse_output(data, num_vertex)

        return output

    def compute_output_shape(self, input_shape):
        return self._get_output_shape(input_shape, self._transform_layers[-1][2])

    def _add_transform_config(self, config):
        config.update({
            'n_propagate': list(ll[0].units for ll in self._transform_layers),
            'n_aggregators': list(ll[1].units for ll in self._transform_layers),
            'n_filters': list(ll[2].units for ll in self._transform_layers),
            'n_sublayers': len(self._transform_layers)
        })