
Commit 4c5d89d

Merge pull request #344 from yiiyama/topic-config-garnet
GarNet and GarNetStack in config.py
2 parents 0ef2766 + 9dd4000

8 files changed, +428 -8 lines changed

contrib/__init__.py

Whitespace-only changes.

contrib/garnet.py

Lines changed: 310 additions & 0 deletions (new file)

"""
Excerpt from https://github.com/jkiesele/caloGraphNN/blob/6d1127d807bc0dbaefcf1ed804d626272f002404/caloGraphNN_keras.py
"""

import tensorflow.keras as keras

K = keras.backend

try:
    from qkeras import QDense, ternary, QActivation

    # Hack qkeras QDense to propagate the layer name into saved weights
    class NamedQDense(QDense):
        def add_weight(self, name=None, **kwargs):
            return super(NamedQDense, self).add_weight(name='%s_%s' % (self.name, name), **kwargs)

    def ternary_1_05():
        return ternary(alpha=1., threshold=0.5)

except ImportError:
    pass


# Hack keras Dense to propagate the layer name into saved weights
class NamedDense(keras.layers.Dense):
    def add_weight(self, name=None, **kwargs):
        return super(NamedDense, self).add_weight(name='%s_%s' % (self.name, name), **kwargs)


class GarNet(keras.layers.Layer):
    def __init__(self, n_aggregators, n_filters, n_propagate,
                 simplified=False,
                 collapse=None,
                 input_format='xn',
                 output_activation='tanh',
                 mean_by_nvert=False,
                 quantize_transforms=False,
                 total_bits=None,
                 int_bits=None,
                 **kwargs):
        super(GarNet, self).__init__(**kwargs)

        self._simplified = simplified
        self._output_activation = output_activation
        self._quantize_transforms = quantize_transforms
        self._total_bits = total_bits
        self._int_bits = int_bits
        self._setup_aux_params(collapse, input_format, mean_by_nvert)
        self._setup_transforms(n_aggregators, n_filters, n_propagate)

    def _setup_aux_params(self, collapse, input_format, mean_by_nvert):
        if collapse is None:
            self._collapse = None
        elif collapse in ['mean', 'sum', 'max']:
            self._collapse = collapse
        else:
            raise NotImplementedError('Unsupported collapse operation')

        self._input_format = input_format
        self._mean_by_nvert = mean_by_nvert

    def _setup_transforms(self, n_aggregators, n_filters, n_propagate):
        if self._quantize_transforms:
            self._input_feature_transform = NamedQDense(n_propagate,
                                                        kernel_quantizer="quantized_bits(%i,%i,0,alpha=1)" % (self._total_bits, self._int_bits),
                                                        bias_quantizer="quantized_bits(%i,%i,0,alpha=1)" % (self._total_bits, self._int_bits),
                                                        name='FLR')
            self._output_feature_transform = NamedQDense(n_filters,
                                                         kernel_quantizer="quantized_bits(%i,%i,0,alpha=1)" % (self._total_bits, self._int_bits),
                                                         name='Fout')
            if self._output_activation is None or self._output_activation == 'linear':
                self._output_activation_transform = QActivation("quantized_bits(%i, %i)" % (self._total_bits, self._int_bits))
            else:
                self._output_activation_transform = QActivation("quantized_%s(%i, %i)" % (self._output_activation, self._total_bits, self._int_bits))
        else:
            self._input_feature_transform = NamedDense(n_propagate, name='FLR')
            # the activation is applied exactly once, through _output_activation_transform below
            self._output_feature_transform = NamedDense(n_filters, name='Fout')
            self._output_activation_transform = keras.layers.Activation(self._output_activation)

        self._aggregator_distance = NamedDense(n_aggregators, name='S')

        self._sublayers = [self._input_feature_transform, self._aggregator_distance, self._output_feature_transform, self._output_activation_transform]

    def build(self, input_shape):
        super(GarNet, self).build(input_shape)

        if self._input_format == 'x':
            data_shape = input_shape
        elif self._input_format == 'xn':
            data_shape, _ = input_shape
        elif self._input_format == 'xen':
            data_shape, _, _ = input_shape
            data_shape = data_shape[:2] + (data_shape[2] + 1,)

        self._build_transforms(data_shape)

        for layer in self._sublayers:
            self._trainable_weights.extend(layer.trainable_weights)
            self._non_trainable_weights.extend(layer.non_trainable_weights)

    def _build_transforms(self, data_shape):
        self._input_feature_transform.build(data_shape)
        self._aggregator_distance.build(data_shape)
        # build the output transforms on the width of the (possibly concatenated) features
        if self._simplified:
            out_shape = data_shape[:2] + (self._aggregator_distance.units * self._input_feature_transform.units,)
        else:
            out_shape = data_shape[:2] + (data_shape[2] + self._aggregator_distance.units * self._input_feature_transform.units + self._aggregator_distance.units,)
        self._output_feature_transform.build(out_shape)
        self._output_activation_transform.build(out_shape)

    def call(self, x):
        data, num_vertex, vertex_mask = self._unpack_input(x)

        output = self._garnet(data, num_vertex, vertex_mask,
                              self._input_feature_transform,
                              self._aggregator_distance,
                              self._output_feature_transform,
                              self._output_activation_transform)

        output = self._collapse_output(output, num_vertex)

        return output

    def _unpack_input(self, x):
        if self._input_format == 'x':
            data = x

            vertex_mask = K.cast(K.not_equal(data[..., 3:4], 0.), 'float32')
            num_vertex = K.sum(vertex_mask)

        elif self._input_format in ['xn', 'xen']:
            if self._input_format == 'xn':
                data, num_vertex = x
            else:
                data_x, data_e, num_vertex = x
                data = K.concatenate((data_x, K.reshape(data_e, (-1, data_e.shape[1], 1))), axis=-1)

            data_shape = K.shape(data)
            B = data_shape[0]
            V = data_shape[1]
            vertex_indices = K.tile(K.expand_dims(K.arange(0, V), axis=0), (B, 1))  # (B, [0..V-1])
            vertex_mask = K.expand_dims(K.cast(K.less(vertex_indices, K.cast(num_vertex, 'int32')), 'float32'), axis=-1)  # (B, V, 1)
            num_vertex = K.cast(num_vertex, 'float32')

        return data, num_vertex, vertex_mask

    def _garnet(self, data, num_vertex, vertex_mask, in_transform, d_compute, out_transform, act_transform):
        features = in_transform(data)  # (B, V, F)
        distance = d_compute(data)  # (B, V, S)

        edge_weights = vertex_mask * K.exp(-K.square(distance))  # (B, V, S)

        if not self._simplified:
            features = K.concatenate([vertex_mask * features, edge_weights], axis=-1)

        if self._mean_by_nvert:
            def graph_mean(out, axis):
                s = K.sum(out, axis=axis)
                # reshape just to enable broadcasting against num_vertex
                # (NB: this feature width assumes simplified=True)
                s = K.reshape(s, (-1, d_compute.units * in_transform.units)) / num_vertex
                s = K.reshape(s, (-1, d_compute.units, in_transform.units))
                return s
        else:
            graph_mean = K.mean

        # vertices -> aggregators
        edge_weights_trans = K.permute_dimensions(edge_weights, (0, 2, 1))  # (B, S, V)

        aggregated_mean = self._apply_edge_weights(features, edge_weights_trans, aggregation=graph_mean)  # (B, S, F)

        if self._simplified:
            aggregated = aggregated_mean
        else:
            aggregated_max = self._apply_edge_weights(features, edge_weights_trans, aggregation=K.max)
            aggregated = K.concatenate([aggregated_max, aggregated_mean], axis=-1)

        # aggregators -> vertices
        updated_features = self._apply_edge_weights(aggregated, edge_weights)  # (B, V, S*F)

        if not self._simplified:
            updated_features = K.concatenate([data, updated_features, edge_weights], axis=-1)

        return vertex_mask * act_transform(out_transform(updated_features))

    def _collapse_output(self, output, num_vertex=None):
        if self._collapse == 'mean':
            if self._mean_by_nvert:
                # num_vertex is passed down from call(); normalize by the true vertex count
                output = K.sum(output, axis=1) / num_vertex
            else:
                output = K.mean(output, axis=1)
        elif self._collapse == 'sum':
            output = K.sum(output, axis=1)
        elif self._collapse == 'max':
            output = K.max(output, axis=1)

        return output

    def compute_output_shape(self, input_shape):
        # the dense Fout transform carries the .units that determine the output width
        return self._get_output_shape(input_shape, self._output_feature_transform)

    def _get_output_shape(self, input_shape, out_transform):
        if self._input_format == 'x':
            data_shape = input_shape
        elif self._input_format == 'xn':
            data_shape, _ = input_shape
        elif self._input_format == 'xen':
            data_shape, _, _ = input_shape

        if self._collapse is None:
            return data_shape[:2] + (out_transform.units,)
        else:
            return (data_shape[0], out_transform.units)

    def get_config(self):
        config = super(GarNet, self).get_config()

        config.update({
            'simplified': self._simplified,
            'collapse': self._collapse,
            'input_format': self._input_format,
            'output_activation': self._output_activation,
            'quantize_transforms': self._quantize_transforms,
            'mean_by_nvert': self._mean_by_nvert
        })

        self._add_transform_config(config)

        return config

    def _add_transform_config(self, config):
        config.update({
            'n_aggregators': self._aggregator_distance.units,
            'n_filters': self._output_feature_transform.units,
            'n_propagate': self._input_feature_transform.units
        })

    @staticmethod
    def _apply_edge_weights(features, edge_weights, aggregation=None):
        features = K.expand_dims(features, axis=1)  # (B, 1, v, f)
        edge_weights = K.expand_dims(edge_weights, axis=3)  # (B, u, v, 1)

        out = edge_weights * features  # (B, u, v, f)

        if aggregation:
            out = aggregation(out, axis=2)  # (B, u, f)
        else:
            try:
                out = K.reshape(out, (-1, edge_weights.shape[1].value, features.shape[-1].value * features.shape[-2].value))
            except AttributeError:  # TF 2: shapes are plain ints, not Dimension objects
                out = K.reshape(out, (-1, edge_weights.shape[1], features.shape[-1] * features.shape[-2]))

        return out


class GarNetStack(GarNet):
    """
    Stacked version of GarNet. The first three constructor arguments must be lists of integers
    of equal length, one entry per sublayer. Stacking offers essentially no performance
    advantage over separate GarNet layers, but it consolidates the configuration, which is
    useful e.g. when converting the layer to HLS.
    """

    def _setup_transforms(self, n_aggregators, n_filters, n_propagate):
        self._transform_layers = []
        # arguments are lists, one entry per sublayer
        for it, (p, a, f) in enumerate(zip(n_propagate, n_aggregators, n_filters)):
            if self._quantize_transforms:
                input_feature_transform = NamedQDense(p,
                                                      kernel_quantizer="quantized_bits(%i,%i,0,alpha=1)" % (self._total_bits, self._int_bits),
                                                      bias_quantizer="quantized_bits(%i,%i,0,alpha=1)" % (self._total_bits, self._int_bits),
                                                      name=('FLR%d' % it))
                output_feature_transform = NamedQDense(f,
                                                       kernel_quantizer="quantized_bits(%i,%i,0,alpha=1)" % (self._total_bits, self._int_bits),
                                                       name=('Fout%d' % it))
                if self._output_activation is None or self._output_activation == 'linear':
                    output_activation_transform = QActivation("quantized_bits(%i, %i)" % (self._total_bits, self._int_bits))
                else:
                    output_activation_transform = QActivation("quantized_%s(%i, %i)" % (self._output_activation, self._total_bits, self._int_bits))
            else:
                input_feature_transform = NamedDense(p, name=('FLR%d' % it))
                output_feature_transform = NamedDense(f, name=('Fout%d' % it))
                output_activation_transform = keras.layers.Activation(self._output_activation)

            aggregator_distance = NamedDense(a, name=('S%d' % it))

            # keep the activation together with the transforms; call() unpacks all four
            self._transform_layers.append((input_feature_transform, aggregator_distance, output_feature_transform, output_activation_transform))

        self._sublayers = sum((list(layers) for layers in self._transform_layers), [])

    def _build_transforms(self, data_shape):
        for in_transform, d_compute, out_transform, act_transform in self._transform_layers:
            in_transform.build(data_shape)
            d_compute.build(data_shape)
            if self._simplified:
                out_shape = data_shape[:2] + (d_compute.units * in_transform.units,)
            else:
                out_shape = data_shape[:2] + (data_shape[2] + d_compute.units * in_transform.units + d_compute.units,)
            out_transform.build(out_shape)
            act_transform.build(out_shape)

            # the output of this sublayer is the input of the next
            data_shape = data_shape[:2] + (out_transform.units,)

    def call(self, x):
        data, num_vertex, vertex_mask = self._unpack_input(x)

        for in_transform, d_compute, out_transform, act_transform in self._transform_layers:
            data = self._garnet(data, num_vertex, vertex_mask, in_transform, d_compute, out_transform, act_transform)

        output = self._collapse_output(data, num_vertex)

        return output

    def compute_output_shape(self, input_shape):
        return self._get_output_shape(input_shape, self._transform_layers[-1][2])

    def _add_transform_config(self, config):
        config.update({
            'n_propagate': list(ll[0].units for ll in self._transform_layers),
            'n_aggregators': list(ll[1].units for ll in self._transform_layers),
            'n_filters': list(ll[2].units for ll in self._transform_layers),
            'n_sublayers': len(self._transform_layers)
        })
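
For orientation, the two passes through _apply_edge_weights amount to a weighted aggregation over vertices followed by a broadcast back to the vertices. A NumPy sketch of the simplified (simplified=True) data flow, with illustrative shapes; this is not part of the commit:

import numpy as np

B, V, S, F = 2, 16, 4, 8              # batch, vertices, aggregators, propagated features
features = np.random.rand(B, V, F)    # output of the FLR transform
edge_weights = np.random.rand(B, V, S)

# vertices -> aggregators: _apply_edge_weights(features, edge_weights_trans, aggregation=K.mean)
w_trans = np.transpose(edge_weights, (0, 2, 1))                  # (B, S, V)
aggregated = np.einsum('bsv,bvf->bsf', w_trans, features) / V    # (B, S, F)

# aggregators -> vertices: _apply_edge_weights(aggregated, edge_weights)
updated = (edge_weights[..., None] * aggregated[:, None, :, :]).reshape(B, V, S * F)

And a minimal end-to-end sketch of constructing the layers in a Keras model, assuming TensorFlow 2.x and that contrib/ is importable; the shapes and hyperparameters are illustrative, not taken from the commit. simplified=True with a linear (None) output activation is the configuration the HLS converter accepts:

import tensorflow.keras as keras
from contrib.garnet import GarNet, GarNetStack

V, C = 128, 4   # padded vertex count and per-vertex features (illustrative)

x = keras.layers.Input(shape=(V, C))   # zero-padded vertex features
n = keras.layers.Input(shape=(1,))     # true number of vertices per sample

y = GarNet(4, 8, 8, simplified=True, collapse='mean',
           input_format='xn', output_activation=None, name='garnet')([x, n])
model = keras.Model(inputs=[x, n], outputs=y)

# Stacked variant: one list entry per internal sublayer,
# in the order (n_aggregators, n_filters, n_propagate).
stack = GarNetStack([4, 4, 8], [8, 8, 16], [8, 8, 16],
                    simplified=True, collapse='mean',
                    input_format='xn', output_activation=None)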

hls4ml/converters/keras/core.py

Lines changed: 8 additions & 2 deletions
@@ -13,9 +13,15 @@ def parse_input_layer(keras_layer, input_names, input_shapes, data_reader, config):
     layer = parse_default_keras_layer(keras_layer, input_names)

     layer['input_shape'] = keras_layer['config']['batch_input_shape'][1:]
-    if keras_layer['config']['dtype'] == 'int32':
+
+    dtype = keras_layer['config']['dtype']
+    if dtype.startswith('int') or dtype.startswith('uint'):
         layer['type_name'] = 'integer_input_t'
-        layer['precision'] = IntegerPrecisionType(width=32)
+        width = int(dtype[dtype.index('int') + 3:])
+        signed = (not dtype.startswith('u'))
+        layer['precision'] = IntegerPrecisionType(width=width, signed=signed)
+    # elif bool, q[u]int, ...
+

     output_shape = keras_layer['config']['batch_input_shape']

     return layer, output_shape
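
The width/sign extraction relies on the dtype string format; a standalone sketch of that logic (illustrative, not part of the commit):

def parse_int_dtype(dtype):
    # 'int32' -> (32, True); 'uint16' -> (16, False)
    width = int(dtype[dtype.index('int') + 3:])
    signed = not dtype.startswith('u')
    return width, signed

assert parse_int_dtype('int32') == (32, True)
assert parse_int_dtype('uint16') == (16, False)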

hls4ml/converters/keras/graph.py

Lines changed: 2 additions & 2 deletions
@@ -8,8 +8,8 @@ def parse_garnet_layer(keras_layer, input_names, input_shapes, data_reader, config):

     if not keras_layer['config']['simplified']:
         raise Exception('HLS GarNet is compatible only with keras GarNet with simplified=True')
-    if keras_layer['config']['output_activation'] is not None:
-        raise Exception('HLS GarNet cannot have output activation')
+    if keras_layer['config']['output_activation'] not in [None, 'linear']:
+        raise Exception('HLS GarNet cannot have nonlinear output activation')

     layer = parse_default_keras_layer(keras_layer, input_names)

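In effect the parser now admits simplified GarNet layers with a linear (or absent) output activation and rejects everything else; for example (illustrative):

accepted = [
    {'simplified': True, 'output_activation': None},
    {'simplified': True, 'output_activation': 'linear'},
]
rejected = [
    {'simplified': False, 'output_activation': None},    # full GarNet unsupported
    {'simplified': True, 'output_activation': 'tanh'},   # nonlinear output unsupported
]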
Lines changed: 15 additions & 0 deletions
#ifndef X_HLS_MATH_H
#define X_HLS_MATH_H

#include <cmath>
#include "ap_fixed.h"

namespace hls {

// Evaluate exp on an ap_fixed value by round-tripping through double.
template<class T>
static T exp(const T x) {
    return (T) std::exp(x.to_double());
}

}
#endif

hls4ml/templates/vivado/build_lib.sh

Lines changed: 1 addition & 1 deletion
@@ -14,4 +14,4 @@ LIB_STAMP=mystamp
 ${CC} ${CFLAGS} ${INCFLAGS} -c firmware/${PROJECT}.cpp -o ${PROJECT}.o
 ${CC} ${CFLAGS} ${INCFLAGS} -c ${PROJECT}_bridge.cpp -o ${PROJECT}_bridge.o
 ${CC} ${CFLAGS} ${INCFLAGS} -shared ${PROJECT}.o ${PROJECT}_bridge.o -o firmware/${PROJECT}-${LIB_STAMP}.so
-rm -f *.o
\ No newline at end of file
+rm -f *.o
