Skip to content

Commit 3b6b2b1

Browse files
jmduartejmitrevsthesps
authored
Fix GlobalPooling1D Layers (fastmachinelearning#399)
* add global_pooling1d_cl * io_parallel global_pooling1d_cl * add globalpooling1d testing; add missing include in nnet_conv_stream.h * update test * Update test_globalpooling1d.py * Change project directory name Co-authored-by: Jovan Mitrevski <[email protected]> Co-authored-by: Sioni Summers <[email protected]>
1 parent 9fbb85d commit 3b6b2b1

File tree

7 files changed

+149
-16
lines changed

7 files changed

+149
-16
lines changed

hls4ml/converters/keras/pooling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ def parse_pooling_layer(keras_layer, input_names, input_shapes, data_reader, con
7272

7373
return layer, output_shape
7474

75-
pooling_layers = ['GlobalMaxPooling1D', 'GlobalMaxPooling2D', 'GlobalAveragePooling1D', 'GlobalAveragePooling2D']
76-
@keras_handler(*pooling_layers)
75+
global_pooling_layers = ['GlobalMaxPooling1D', 'GlobalMaxPooling2D', 'GlobalAveragePooling1D', 'GlobalAveragePooling2D']
76+
@keras_handler(*global_pooling_layers)
7777
def parse_global_pooling_layer(keras_layer, input_names, input_shapes, data_reader, config):
7878
assert('Pooling' in keras_layer['class_name'])
7979

hls4ml/model/hls_layers.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1166,7 +1166,7 @@ def config_cpp(self):
11661166
params['n_filt'] = self.get_output_variable().dim_names[1]
11671167
else:
11681168
params['n_in'] = self.get_input_variable().dim_names[1]
1169-
params['n_out'] = self.get_input_variable().dim_names[1]
1169+
params['n_out'] = self.get_output_variable().dim_names[1]
11701170
params['n_filt'] = self.get_output_variable().dim_names[0]
11711171

11721172
return self._config_template.format(**params)
@@ -1208,19 +1208,24 @@ def config_cpp(self):
12081208

12091209
class GlobalPooling1D(Layer):
12101210
def initialize(self):
1211-
shape = [self.attributes['n_out'], self.attributes['n_filt']]
1212-
dims = ['N_OUTPUTS_{}'.format(self.index), 'N_FILT_{}'.format(self.index)]
1211+
shape = [self.attributes['n_filt']]
1212+
dims = ['N_FILT_{}'.format(self.index)]
12131213
self.add_output_variable(shape, dims)
12141214
self.set_attr('pool_op', self.get_attr('class_name').split('Pooling')[0].replace('Global', ''))
12151215

12161216
def function_cpp(self):
12171217
params = self._default_function_params()
1218-
1218+
params['data_format'] = 'cf' if self.get_attr('data_format') == 'channels_first' else 'cl'
12191219
return [self._function_template.format(**params)]
12201220

12211221
def config_cpp(self):
12221222
params = self._default_config_params()
1223-
params['n_in'] = self.get_input_variable().size_cpp()
1223+
if self.get_attr('data_format') == 'channels_last':
1224+
params['n_in'] = self.get_input_variable().dim_names[0]
1225+
params['n_filt'] = self.get_input_variable().dim_names[1]
1226+
else:
1227+
params['n_in'] = self.get_input_variable().dim_names[1]
1228+
params['n_filt'] = self.get_input_variable().dim_names[0]
12241229

12251230
return self._config_template.format(**params)
12261231

hls4ml/templates/vivado/nnet_utils/nnet_conv_stream.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "ap_shift_reg.h"
55
#include "nnet_common.h"
66
#include "hls_stream.h"
7+
#include "nnet_dense.h"
78

89
namespace nnet {
910

hls4ml/templates/vivado/nnet_utils/nnet_pooling.h

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ struct pooling1d_config{
8787
// IO size
8888
static const unsigned n_in = 10;
8989
static const unsigned pool_width = 2;
90-
static const unsigned n_out = n_in / pool_width;
90+
static const unsigned stride_width = 2;
91+
static const unsigned n_out = (n_in - pool_width) / stride_width + 1;
9192
static const unsigned pad_left = 0;
9293
static const unsigned pad_right = 0;
9394
// Pooling function
@@ -141,6 +142,25 @@ void pooling1d_cl(data_T data[CONFIG_T::n_in * CONFIG_T::n_filt], res_T res[CONF
141142
}
142143
}
143144

145+
template<class data_T, class res_T, typename CONFIG_T>
146+
void global_pooling1d_cl(data_T data[CONFIG_T::n_in * CONFIG_T::n_filt], res_T res[CONFIG_T::n_filt]) {
147+
assert(CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0);
148+
assert(CONFIG_T::pool_width == CONFIG_T::stride_width);
149+
150+
// TODO partition the arrays according to the reuse factor
151+
const int limit = pool_op_limit_1d<CONFIG_T>();
152+
#pragma HLS ALLOCATION instances=pool_op limit=limit function
153+
154+
for(int ff = 0; ff < CONFIG_T::n_filt; ff++) {
155+
data_T pool[CONFIG_T::n_in];
156+
for(int jj = 0; jj < CONFIG_T::n_in; jj++) {
157+
pool[jj] = data[jj * CONFIG_T::n_filt + ff];
158+
}
159+
// do the pooling
160+
res[ff] = pool_op<data_T, CONFIG_T::n_in, CONFIG_T::pool_op>(pool);
161+
}
162+
}
163+
144164
struct pooling2d_config{
145165
// IO size
146166
static const unsigned in_height = 10;

hls4ml/templates/vivado/nnet_utils/nnet_pooling_stream.h

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -473,8 +473,6 @@ T reduce_global_pool(T x, T y[N]) {
473473

474474
template<class data_T, class res_T, typename CONFIG_T>
475475
void compute_global_pool(
476-
const unsigned h_idx,
477-
const unsigned w_idx,
478476
const data_T& in_elem,
479477
typename CONFIG_T::accum_t data_window[CONFIG_T::n_filt]
480478
) {
@@ -516,7 +514,7 @@ void global_pooling2d_cl(
516514
ReadInputHeight: for (unsigned i_ih = 0; i_ih < CONFIG_T::in_height; i_ih++) {
517515
ReadInputWidth: for (unsigned i_iw = 0; i_iw < CONFIG_T::in_width / (data_T::size / CONFIG_T::n_filt); i_iw++) {
518516
#pragma HLS LOOP_FLATTEN
519-
compute_global_pool<data_T, res_T, CONFIG_T>(i_ih, i_iw, data.read(), data_window);
517+
compute_global_pool<data_T, res_T, CONFIG_T>(data.read(), data_window);
520518
}
521519
}
522520

@@ -548,6 +546,60 @@ void global_pooling2d_cl(
548546

549547
}
550548

549+
template<class data_T, class res_T, typename CONFIG_T>
550+
void global_pooling1d_cl(
551+
hls::stream<data_T> &data,
552+
hls::stream<res_T> &res
553+
) {
554+
assert(CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0);
555+
assert(CONFIG_T::pool_width == CONFIG_T::stride_width);
556+
557+
typename CONFIG_T::accum_t data_window[CONFIG_T::n_filt];
558+
#pragma HLS ARRAY_PARTITION variable=data_window complete
559+
560+
typename CONFIG_T::accum_t init = 0;
561+
if (CONFIG_T::pool_op == Max) {
562+
init = hls::numeric_limits<typename CONFIG_T::accum_t>::min();
563+
}
564+
565+
PoolInitLoop: for (unsigned i_init = 0; i_init < CONFIG_T::n_filt; i_init++) {
566+
#pragma HLS UNROLL
567+
data_window[i_init] = init;
568+
}
569+
570+
ReadInput: for (unsigned i_iw = 0; i_iw < CONFIG_T::n_in / (data_T::size / CONFIG_T::n_filt); i_iw++) {
571+
#pragma HLS LOOP_FLATTEN
572+
compute_global_pool<data_T, res_T, CONFIG_T>(data.read(), data_window);
573+
}
574+
575+
if (CONFIG_T::pool_op == Max) {
576+
MaxPoolRes: for (unsigned i_res = 0; i_res < CONFIG_T::n_filt / res_T::size; i_res++) {
577+
#pragma HLS PIPELINE
578+
579+
res_T res_pack;
580+
#pragma HLS DATA_PACK variable=res_pack
581+
MaxPoolPack: for (unsigned i_pack = 0; i_pack < res_T::size; i_pack++) {
582+
#pragma HLS UNROLL
583+
res_pack[i_pack] = data_window[i_pack];
584+
}
585+
res.write(res_pack);
586+
}
587+
} else {
588+
AvgPoolRes: for (unsigned i_res = 0; i_res < CONFIG_T::n_filt / res_T::size; i_res++) {
589+
#pragma HLS PIPELINE
590+
591+
res_T res_pack;
592+
#pragma HLS DATA_PACK variable=res_pack
593+
AvgPoolPack: for (unsigned i_pack = 0; i_pack < res_T::size; i_pack++) {
594+
#pragma HLS UNROLL
595+
res_pack[i_pack] = data_window[i_pack] / CONFIG_T::n_in;
596+
}
597+
res.write(res_pack);
598+
}
599+
}
600+
601+
}
602+
551603
}
552604

553605
#endif

hls4ml/templates/vivado_template.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -169,11 +169,9 @@
169169

170170
global_pooling1d_config_template = """struct config{index} : nnet::pooling1d_config {{
171171
static const unsigned n_in = {n_in};
172-
static const unsigned n_out = {n_out};
173-
static const unsigned pad_left = {pad_left};
174-
static const unsigned pad_right = {pad_right};
175-
static const unsigned stride = {stride};
172+
static const unsigned n_filt = {n_filt};
176173
static const nnet::Pool_Op pool_op = nnet::{pool_op};
174+
static const unsigned reuse = {reuse};
177175
typedef {accum_t} accum_t;
178176
}};\n"""
179177

@@ -358,7 +356,7 @@
358356
param_activ_function_template = 'nnet::{activation}<{input_t}, {output_t}, {config}>({input}, {param}, {output});'
359357
pooling1d_function_template = 'nnet::pooling1d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output});'
360358
pooling2d_function_template = 'nnet::pooling2d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output});'
361-
global_pooling1d_function_template = 'nnet::global_pooling1d<{input_t}, {config}>({input}, {output});'
359+
global_pooling1d_function_template = 'nnet::global_pooling1d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output});'
362360
global_pooling2d_function_template = 'nnet::global_pooling2d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output});'
363361
zeropad1d_function_template = 'nnet::zeropad1d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output});'
364362
zeropad2d_function_template = 'nnet::zeropad2d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output});'

test/pytest/test_globalpooling1d.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import pytest
2+
from tensorflow.keras.models import Sequential
3+
from tensorflow.keras.layers import GlobalAveragePooling1D, GlobalMaxPooling1D
4+
import numpy as np
5+
import hls4ml
6+
7+
8+
in_shape = 8
9+
in_feat = 4
10+
atol = 5e-3
11+
12+
@pytest.fixture(scope='module')
13+
def data():
14+
X = np.random.rand(100, in_shape, in_feat)
15+
return X
16+
17+
18+
@pytest.fixture(scope='module')
19+
def keras_model_max():
20+
model = Sequential()
21+
model.add(GlobalMaxPooling1D(input_shape=(in_shape, in_feat)))
22+
model.compile()
23+
return model
24+
25+
@pytest.fixture(scope='module')
26+
def keras_model_ave():
27+
model = Sequential()
28+
model.add(GlobalAveragePooling1D(input_shape=(in_shape, in_feat)))
29+
model.compile()
30+
return model
31+
32+
33+
@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
34+
@pytest.mark.parametrize('model_type', ['max', 'ave'])
35+
def test_global_pool1d(keras_model_max, keras_model_ave, data, model_type, io_type):
36+
if model_type == 'ave':
37+
model = keras_model_ave
38+
else:
39+
model = keras_model_max
40+
config = hls4ml.utils.config_from_keras_model(model,
41+
default_precision='ap_fixed<32,1>',
42+
granularity='name')
43+
if model_type == 'ave':
44+
config['LayerName']['global_average_pooling1d']['accum_t'] = 'ap_fixed<32,6>'
45+
46+
hls_model = hls4ml.converters.convert_from_keras_model(model,
47+
hls_config=config,
48+
io_type=io_type,
49+
output_dir=f'hls4mlprj_globalplool1d_{model_type}_{io_type}',
50+
part='xcvu9p-flgb2104-2-i')
51+
hls_model.compile()
52+
53+
54+
# Predict
55+
y_keras = np.squeeze(model.predict(data))
56+
y_hls = hls_model.predict(data)
57+
np.testing.assert_allclose(y_keras, y_hls, rtol=0, atol=atol, verbose=True)

0 commit comments

Comments
 (0)