
Commit 85b9531

katyagovorkova, jmitrevs, and vloncar authored
WIP Add custom KL loss layer HLS implementation (#606)
* add kl layer * separate hls part; clean up and add docs * creeate KL layer folder in contrib and move the files there * pass pre-commit check * README and fix pre-commit issue * update readme * fix formatting * add readme * Update README.md @jmitrevs readme updated! * Update README.md remove trailing whitespace * Update kl_layer.py * Rename nnet_distance.h to kl_layer.h * Update README.md * Update kl_layer.py * Update kl_layer.h * fix pre-commit * Fix KLLoss layer example --------- Co-authored-by: Jovan Mitrevski <[email protected]> Co-authored-by: Vladimir Loncar <[email protected]>
1 parent 107589f commit 85b9531

File tree

3 files changed: +290 -0 lines


contrib/kl_layer/README.md

Lines changed: 18 additions & 0 deletions
This folder contains the implementation of a custom KL divergence layer.
This is a custom implementation, not a built-in layer in any deep learning framework.
It was developed specifically for the [AD@L1 CMS paper](https://www.nature.com/articles/s42256-022-00441-3).

# Files

* `kl_layer.py`: contains the standalone implementation of the custom KL divergence layer
* `kl_layer.h`: contains the HLS implementation of the KL layer

# Usage

`kl_layer.py` contains an example of how to use the KL layer. To run it, do:

```
python kl_layer.py
```
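
For reference, the quantity this layer computes (in both the Keras `_merge_function` and the HLS `klloss` below) is the familiar VAE KL term against a unit Gaussian, averaged over the latent dimension:

$$
\mathrm{KL} = -\frac{1}{2N} \sum_{i=1}^{N} \left( 1 + \log\sigma_i^2 - \mu_i^2 - e^{\log\sigma_i^2} \right)
$$

where $\mu_i$ is `z_mean`, $\log\sigma_i^2$ is `z_log_var`, and $N$ is the latent dimension (`n_in` in the HLS config).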

contrib/kl_layer/kl_layer.h

Lines changed: 87 additions & 0 deletions
```cpp
#ifndef KL_LAYER_H_
#define KL_LAYER_H_

#include "nnet_activation.h"
#include "nnet_common.h"
#include <cmath>
#include <cstdlib>

namespace nnet {

struct distance_config {
    // IO size
    static const unsigned n_in = 10;
    static const unsigned n_out = 1;

    // Internal data type definitions
    typedef float accum_t;
    typedef float sum_t;
    typedef ap_fixed<18, 8> exp_table_t;

    // Internal info
    static const unsigned table_size = 1024;
    static constexpr unsigned exp_range = 8;
};

template <typename CONFIG_T, int N_TABLE> void init_klloss_exp_table(typename CONFIG_T::exp_table_t table_out[N_TABLE]) {
    for (int ii = 0; ii < N_TABLE; ii++) {
        // First, convert from table index to X-value (range -exp_range to +exp_range)
        float in_val = 2 * CONFIG_T::exp_range * (ii - float(N_TABLE) / 2.0) / float(N_TABLE);
        // Next, compute lookup table function
        typename CONFIG_T::exp_table_t real_val = exp_fcn_float(in_val);
        // std::cout << "Lookup table In Value: " << in_val << " Result: " << real_val << " Index: " << ii << std::endl;
        table_out[ii] = real_val;
    }
}

template <class data1_T, class data2_T, class res_T, typename CONFIG_T>
void klloss(data1_T mean[CONFIG_T::n_in], data2_T log_var[CONFIG_T::n_in], res_T res[CONFIG_T::n_out]) {
    #pragma HLS PIPELINE
    // Initialize the lookup tables
#ifdef __HLS_SYN__
    bool initialized = false;
    typename CONFIG_T::exp_table_t exp_table[CONFIG_T::table_size];
#else
    static bool initialized = false;
    static typename CONFIG_T::exp_table_t exp_table[CONFIG_T::table_size];
#endif
    if (!initialized) {
        init_klloss_exp_table<CONFIG_T, CONFIG_T::table_size>(exp_table);
        initialized = true;
    }
    typename CONFIG_T::accum_t kl[CONFIG_T::n_in];
    #pragma HLS ARRAY_PARTITION variable=kl complete
    typename CONFIG_T::accum_t mean_sq[CONFIG_T::n_in];
    #pragma HLS ARRAY_PARTITION variable=mean_sq complete
    typename CONFIG_T::accum_t kl_sum(0);
    for (unsigned i = 0; i < CONFIG_T::n_in; i++) {
        #pragma HLS UNROLL
        mean_sq[i] = mean[i] * mean[i];
        kl[i] = data2_T(1.) + log_var[i];
        // std::cout << "Log var: " << log_var[i] << " Result: " << kl[i] << std::endl;
    }
    constexpr unsigned table_scale = (unsigned)(CONFIG_T::table_size / (2 * CONFIG_T::exp_range));
    constexpr unsigned index_scale = (unsigned)(CONFIG_T::exp_range * table_scale);
    for (unsigned i = 0; i < CONFIG_T::n_in; i++) {
        #pragma HLS UNROLL
        auto data_round = log_var[i] * table_scale;
        auto index = data_round + index_scale;
        if (index < 0)
            index = 0;
        if (index > CONFIG_T::table_size - 1)
            index = CONFIG_T::table_size - 1;
        kl[i] -= exp_table[index];
        // std::cout << "Exp var: " << exp_table[index] << " Result: " << kl[i] << " Index: " << index << std::endl;
    }
    for (unsigned i = 0; i < CONFIG_T::n_in; i++) {
        #pragma HLS UNROLL
        kl[i] -= mean_sq[i];
    }
    Op_add<typename CONFIG_T::accum_t> op_add;
    kl_sum = reduce<typename CONFIG_T::accum_t, CONFIG_T::n_in, Op_add<typename CONFIG_T::accum_t>>(kl, op_add);
    // std::cout << "KL sum: " << kl_sum << std::endl;
    kl_sum *= typename CONFIG_T::accum_t(1. / CONFIG_T::n_in);
    res[0] = res_T(-0.5) * kl_sum;
}
} // namespace nnet

#endif
```
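
The lookup-table indexing above compresses `exp(log_var)` into a `table_size`-entry table covering `[-exp_range, +exp_range)`. A minimal Python sketch of the same scale/offset/clamp arithmetic (not part of this commit; plain float arithmetic stands in for the `ap_fixed` types):

```python
import math

TABLE_SIZE = 1024  # mirrors distance_config::table_size
EXP_RANGE = 8      # mirrors distance_config::exp_range

# Build the table the way init_klloss_exp_table does:
# index ii maps to an input value in [-EXP_RANGE, +EXP_RANGE)
exp_table = [math.exp(2 * EXP_RANGE * (ii - TABLE_SIZE / 2.0) / TABLE_SIZE) for ii in range(TABLE_SIZE)]


def exp_lookup(log_var):
    """Mirror the index arithmetic in klloss: scale, offset, clamp."""
    table_scale = TABLE_SIZE // (2 * EXP_RANGE)  # 64 with the defaults
    index = int(log_var * table_scale) + EXP_RANGE * table_scale
    index = max(0, min(TABLE_SIZE - 1, index))
    return exp_table[index]


# The lookup should track math.exp to within the table's step size (1/64 here)
for lv in (-2.0, 0.0, 0.5, 3.0):
    print(lv, exp_lookup(lv), math.exp(lv))
```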

contrib/kl_layer/kl_layer.py

Lines changed: 185 additions & 0 deletions
```python
"""
Usage example for a custom KL loss layer.

Takes as input two arrays, z_mean and z_log_var,
and computes the KL "distance" between the standard normal distribution
and a Gaussian with mu=z_mean and log(sigma^2)=z_log_var.

The HLS part is in contrib/kl_layer/kl_layer.h
"""
from pathlib import Path

import numpy as np
import tensorflow as tf

try:
    from keras.layers.merge import _Merge as Merge
except Exception:
    from keras.layers.merging.base_merge import _Merge as Merge

from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import math_ops

import hls4ml
from hls4ml.converters.keras_to_hls import parse_default_keras_layer
from hls4ml.model.attributes import ConfigurableAttribute, TypeAttribute
from hls4ml.model.types import FixedPrecisionType, RoundingMode, SaturationMode


# Keras implementation of a KL loss layer
class KLLoss(Merge):
    '''Keras implementation of a KL loss custom layer'''

    @tf_utils.shape_type_conversion
    def build(self, input_shape):
        super().build(input_shape)

    def _merge_function(self, inputs):
        mean = inputs[0]
        log_var = inputs[1]

        kl = 1.0 + log_var - math_ops.square(mean) - math_ops.exp(log_var)
        kl = -0.5 * math_ops.reduce_mean(kl, axis=-1, keepdims=True)

        return kl


# hls4ml implementation
class HKLLoss(hls4ml.model.layers.Layer):
    '''hls4ml implementation of a KL loss custom layer'''

    _expected_attributes = [
        ConfigurableAttribute('table_size', default=1024),
        ConfigurableAttribute('exp_range', default=8),
        TypeAttribute('accum'),
        TypeAttribute(
            'sum',
            default=FixedPrecisionType(18, 8, rounding_mode=RoundingMode.RND, saturation_mode=SaturationMode.SAT),
        ),
        TypeAttribute(
            'exp_table',
            default=FixedPrecisionType(18, 8, rounding_mode=RoundingMode.RND, saturation_mode=SaturationMode.SAT),
        ),
    ]

    def initialize(self):
        self.add_output_variable(shape=[1], dim_names=[f'KL_LOSS_{self.index}'])


# Templates
distance_config_template = """struct config{index} : nnet::distance_config {{
    static const unsigned n_in = {n_in};
    static const unsigned n_out = 1;
    typedef {accum_t.name} accum_t;
    typedef {sum_t.name} sum_t;
    typedef {exp_table_t.name} exp_table_t;
    static const unsigned table_size = {table_size};
    static constexpr float exp_range = {exp_range};
}};\n"""
distance_function_template = 'nnet::klloss<{input1_t}, {input2_t}, {output_t}, {config}>({input1}, {input2}, {output});'
distance_include_list = ['nnet_utils/kl_layer.h']


class HKLLossConfigTemplate(hls4ml.backends.template.LayerConfigTemplate):
    def __init__(self):
        super().__init__(HKLLoss)
        self.template = distance_config_template

    def format(self, node):
        params = self._default_config_params(node)
        params['n_in'] = node.get_input_variable(node.inputs[0]).shape[0]
        params['n_out'] = 1
        return self.template.format(**params)


class HKLLossFunctionTemplate(hls4ml.backends.template.FunctionCallTemplate):
    def __init__(self):
        super().__init__(HKLLoss, include_header=distance_include_list)
        self.template = distance_function_template

    def format(self, node):
        params = {}
        params['config'] = f'config{node.index}'
        params['input1_t'] = node.get_input_variable(node.inputs[0]).type.name
        params['input2_t'] = node.get_input_variable(node.inputs[1]).type.name
        params['output_t'] = node.get_output_variable().type.name
        params['input1'] = node.get_input_variable(node.inputs[0]).name
        params['input2'] = node.get_input_variable(node.inputs[1]).name
        params['output'] = node.get_output_variable().name

        return self.template.format(**params)


# Parser for converter
def parse_klloss_layer(keras_layer, input_names, input_shapes, data_reader):
    assert 'KLLoss' in keras_layer['class_name']

    layer = parse_default_keras_layer(keras_layer, input_names)

    output_shape = [input_shapes[0][0], 1]

    return layer, output_shape


def main():
    # Register the converter for custom Keras layer
    hls4ml.converters.register_keras_layer_handler('KLLoss', parse_klloss_layer)

    # Register the hls4ml's IR layer
    hls4ml.model.layers.register_layer('KLLoss', HKLLoss)

    # Register the optimization passes (if any)
    backend = hls4ml.backends.get_backend('Vivado')

    # Register template passes for the given backend
    backend.register_template(HKLLossConfigTemplate)
    backend.register_template(HKLLossFunctionTemplate)

    # Register HLS implementation
    p = Path(__file__).parent / 'kl_layer.h'
    backend.register_source(p)

    # Test if it works
    # Create a dummy Keras model with a KL loss layer
    inp = tf.keras.layers.Input(shape=(19, 3, 1))
    z_mean = tf.keras.layers.Dense(10)(inp)
    z_log_var = tf.keras.layers.Dense(10)(inp)
    custom_output = KLLoss()([z_mean, z_log_var])
    # create new model
    kmodel = tf.keras.models.Model(inputs=inp, outputs=custom_output)
    kmodel.summary()

    # test on random inputs
    x = np.random.randint(-5, 5, (1, 19, 3, 1), dtype='int32')
    kres = kmodel(x)

    # Create dummy config
    config = {}
    config['Model'] = {
        'Precision': 'ap_fixed<16,6>',
        'ReuseFactor': 1,
        'ParallelizationFactor': 1,
        'Strategy': 'Resource',
    }
    hmodel = hls4ml.converters.convert_from_keras_model(
        kmodel,
        output_dir='hls4mlprj_kl_layer',
        backend='Vivado',
        io_type='io_parallel',
        part='xcvu9p-flga2577-2-e',
        hls_config=config,
    )

    hmodel.compile()
    hres = hmodel.predict(x.astype('float32'))

    print('Compare prediction by hls4ml model to Keras one')
    print(kres - hres)

    print('Building model')
    report = hmodel.build(reset=True, csim=False, cosim=True, synth=True, vsynth=True)
    print(report)


if __name__ == '__main__':
    main()
```
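
For a quick cross-check outside of hls4ml, the layer's output can also be reproduced with plain NumPy. A minimal sketch (the helper name `kl_loss_ref` is illustrative, not part of the commit):

```python
import numpy as np


def kl_loss_ref(mean, log_var):
    # Same arithmetic as KLLoss._merge_function, in NumPy
    kl = 1.0 + log_var - np.square(mean) - np.exp(log_var)
    return -0.5 * np.mean(kl, axis=-1, keepdims=True)


# Compare against kres/hres from main() on matching inputs
mean = np.random.randn(1, 10).astype('float32')
log_var = np.random.randn(1, 10).astype('float32')
print(kl_loss_ref(mean, log_var))
```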
