Skip to content

Commit 8eb29fc

Browse files
Fix pooling layers when padding is applied from the left/top (#757)
* fix pooling layers when padding is applied from the left/top * run pre-commit * add fixes for vitis, fix average pooling, add tests * remove changes from pytorch parser * diff clean * Update requirements.txt --------- Co-authored-by: Javier Duarte <[email protected]>
1 parent 6db5f3e commit 8eb29fc

File tree

5 files changed

+152
-17
lines changed

5 files changed

+152
-17
lines changed

docs/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,5 @@ sphinx>=3.2.1
1212
sphinx_contributors
1313
sphinx_github_changelog
1414
sphinx_rtd_theme
15+
tensorflow
1516
toposort>=1.5.0

hls4ml/templates/quartus/firmware/nnet_utils/nnet_pooling.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ void pooling1d_cl(data_T data[CONFIG_T::n_in * CONFIG_T::n_filt], res_T res[CONF
149149
pool[pool_col] = pad_val<data_T, CONFIG_T::pool_op>();
150150
} else {
151151
// Current element is from input image
152-
pool[pool_col] = data[(inp_col + pool_col) * CONFIG_T::n_filt + filt];
152+
pool[pool_col] = data[(inp_col + pool_col - CONFIG_T::pad_left) * CONFIG_T::n_filt + filt];
153153
img_overlap++;
154154
}
155155
}
@@ -160,7 +160,8 @@ void pooling1d_cl(data_T data[CONFIG_T::n_in * CONFIG_T::n_filt], res_T res[CONF
160160

161161
// If the pool op is Average, the zero-padding needs to be removed from the results
162162
if (CONFIG_T::pool_op == Average)
163-
res[(inp_col / CONFIG_T::stride_width) * CONFIG_T::n_filt + filt] *= (CONFIG_T::pool_width / img_overlap);
163+
res[(inp_col / CONFIG_T::stride_width) * CONFIG_T::n_filt + filt] *=
164+
(static_cast<data_T>(CONFIG_T::pool_width) / img_overlap);
164165
}
165166
}
166167
}
@@ -258,8 +259,8 @@ void pooling2d_cl(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_
258259
} else {
259260
// Current element is from input image
260261
pool[pool_col * CONFIG_T::stride_width + pool_row] =
261-
data[(inp_col + pool_col) * CONFIG_T::in_width * CONFIG_T::n_filt +
262-
(inp_width + pool_row) * CONFIG_T::n_filt + filt];
262+
data[(inp_col + pool_col - CONFIG_T::pad_top) * CONFIG_T::in_width * CONFIG_T::n_filt +
263+
(inp_width + pool_row - CONFIG_T::pad_left) * CONFIG_T::n_filt + filt];
263264
img_overlap++;
264265
}
265266
}
@@ -275,7 +276,8 @@ void pooling2d_cl(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_
275276
if (CONFIG_T::pool_op == Average)
276277
res[(inp_col / CONFIG_T::stride_height) * CONFIG_T::out_width * CONFIG_T::n_filt +
277278
(inp_width / CONFIG_T::stride_width) * CONFIG_T::n_filt + filt] *=
278-
(CONFIG_T::pool_height * CONFIG_T::pool_width / img_overlap);
279+
(static_cast<data_T>(CONFIG_T::pool_height) * static_cast<data_T>(CONFIG_T::pool_width) /
280+
img_overlap);
279281
}
280282
}
281283
}

hls4ml/templates/vitis/nnet_utils/nnet_pooling.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ void pooling1d_cl(data_T data[CONFIG_T::n_in * CONFIG_T::n_filt], res_T res[CONF
123123
// Add padding
124124
pool[jj] = pad_val<data_T, CONFIG_T::pool_op>();
125125
} else {
126-
pool[jj] = data[(ii + jj) * CONFIG_T::n_filt + ff];
126+
pool[jj] = data[(ii + jj - CONFIG_T::pad_left) * CONFIG_T::n_filt + ff];
127127
img_overlap++;
128128
}
129129
}
@@ -134,7 +134,7 @@ void pooling1d_cl(data_T data[CONFIG_T::n_in * CONFIG_T::n_filt], res_T res[CONF
134134
pool_op<data_T, CONFIG_T::pool_width, CONFIG_T::pool_op>(pool);
135135
// If the pool op is Average, the zero-padding needs to be removed from the results
136136
if (CONFIG_T::pool_op == Average) {
137-
data_T rescale = CONFIG_T::pool_width / img_overlap;
137+
data_T rescale = static_cast<data_T>(CONFIG_T::pool_width) / img_overlap;
138138
res[(ii / CONFIG_T::stride_width) * CONFIG_T::n_filt + ff] *= rescale;
139139
}
140140
}
@@ -226,7 +226,8 @@ void pooling2d_cl(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_
226226
pool[kk * CONFIG_T::stride_width + ll] = pad_val<data_T, CONFIG_T::pool_op>();
227227
} else {
228228
pool[kk * CONFIG_T::stride_width + ll] =
229-
data[(ii + kk) * CONFIG_T::in_width * CONFIG_T::n_filt + (jj + ll) * CONFIG_T::n_filt + ff];
229+
data[(ii + kk - CONFIG_T::pad_top) * CONFIG_T::in_width * CONFIG_T::n_filt +
230+
(jj + ll - CONFIG_T::pad_left) * CONFIG_T::n_filt + ff];
230231
img_overlap++;
231232
}
232233
}
@@ -239,7 +240,8 @@ void pooling2d_cl(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_
239240
pool_op<data_T, CONFIG_T::pool_height * CONFIG_T::pool_width, CONFIG_T::pool_op>(pool);
240241
// If the pool op is Average, the zero-padding needs to be removed from the results
241242
if (CONFIG_T::pool_op == Average) {
242-
data_T rescale = CONFIG_T::pool_height * CONFIG_T::pool_width / img_overlap;
243+
data_T rescale =
244+
static_cast<data_T>(CONFIG_T::pool_height) * static_cast<data_T>(CONFIG_T::pool_width) / img_overlap;
243245
res[(ii / CONFIG_T::stride_height) * CONFIG_T::out_width * CONFIG_T::n_filt +
244246
(jj / CONFIG_T::stride_width) * CONFIG_T::n_filt + ff] *= rescale;
245247
}
@@ -297,7 +299,8 @@ void pooling2d_cf(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_
297299
pool_op<data_T, CONFIG_T::pool_height * CONFIG_T::pool_width, CONFIG_T::pool_op>(pool);
298300
// If the pool op is Average, the zero-padding needs to be removed from the results
299301
if (CONFIG_T::pool_op == Average) {
300-
data_T rescale = CONFIG_T::pool_height * CONFIG_T::pool_width / img_overlap;
302+
data_T rescale =
303+
static_cast<data_T>(CONFIG_T::pool_height) * static_cast<data_T>(CONFIG_T::pool_width) / img_overlap;
301304
res[(ii / CONFIG_T::stride_height) * CONFIG_T::out_width + (jj / CONFIG_T::stride_width) +
302305
ff * CONFIG_T::out_height * CONFIG_T::out_width] *= rescale;
303306
}

hls4ml/templates/vivado/nnet_utils/nnet_pooling.h

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ void pooling1d_cl(data_T data[CONFIG_T::n_in * CONFIG_T::n_filt], res_T res[CONF
123123
// Add padding
124124
pool[jj] = pad_val<data_T, CONFIG_T::pool_op>();
125125
} else {
126-
pool[jj] = data[(ii + jj) * CONFIG_T::n_filt + ff];
126+
pool[jj] = data[(ii + jj - CONFIG_T::pad_left) * CONFIG_T::n_filt + ff];
127127
img_overlap++;
128128
}
129129
}
@@ -134,7 +134,7 @@ void pooling1d_cl(data_T data[CONFIG_T::n_in * CONFIG_T::n_filt], res_T res[CONF
134134
pool_op<data_T, CONFIG_T::pool_width, CONFIG_T::pool_op>(pool);
135135
// If the pool op is Average, the zero-padding needs to be removed from the results
136136
if (CONFIG_T::pool_op == Average) {
137-
data_T rescale = CONFIG_T::pool_width / img_overlap;
137+
data_T rescale = static_cast<data_T>(CONFIG_T::pool_width) / img_overlap;
138138
res[(ii / CONFIG_T::stride_width) * CONFIG_T::n_filt + ff] *= rescale;
139139
}
140140
}
@@ -227,7 +227,8 @@ void pooling2d_cl(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_
227227
pool[kk * CONFIG_T::stride_width + ll] = pad_val<data_T, CONFIG_T::pool_op>();
228228
} else {
229229
pool[kk * CONFIG_T::stride_width + ll] =
230-
data[(ii + kk) * CONFIG_T::in_width * CONFIG_T::n_filt + (jj + ll) * CONFIG_T::n_filt + ff];
230+
data[(ii + kk - CONFIG_T::pad_top) * CONFIG_T::in_width * CONFIG_T::n_filt +
231+
(jj + ll - CONFIG_T::pad_left) * CONFIG_T::n_filt + ff];
231232
img_overlap++;
232233
}
233234
}
@@ -240,7 +241,8 @@ void pooling2d_cl(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_
240241
pool_op<data_T, CONFIG_T::pool_height * CONFIG_T::pool_width, CONFIG_T::pool_op>(pool);
241242
// If the pool op is Average, the zero-padding needs to be removed from the results
242243
if (CONFIG_T::pool_op == Average) {
243-
data_T rescale = CONFIG_T::pool_height * CONFIG_T::pool_width / img_overlap;
244+
data_T rescale =
245+
static_cast<data_T>(CONFIG_T::pool_height) * static_cast<data_T>(CONFIG_T::pool_width) / img_overlap;
244246
res[(ii / CONFIG_T::stride_height) * CONFIG_T::out_width * CONFIG_T::n_filt +
245247
(jj / CONFIG_T::stride_width) * CONFIG_T::n_filt + ff] *= rescale;
246248
}
@@ -284,8 +286,8 @@ void pooling2d_cf(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_
284286
pool[kk * CONFIG_T::stride_width + ll] = pad_val<data_T, CONFIG_T::pool_op>();
285287
} else {
286288
pool[kk * CONFIG_T::stride_width + ll] =
287-
data[(ii + kk) * CONFIG_T::in_width + ff * CONFIG_T::in_width * CONFIG_T::in_height + ll +
288-
jj];
289+
data[(ii + kk - CONFIG_T::pad_top) * CONFIG_T::in_width +
290+
ff * CONFIG_T::in_width * CONFIG_T::in_height + ll + jj - CONFIG_T::pad_left];
289291
img_overlap++;
290292
}
291293
}
@@ -298,7 +300,8 @@ void pooling2d_cf(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_
298300
pool_op<data_T, CONFIG_T::pool_height * CONFIG_T::pool_width, CONFIG_T::pool_op>(pool);
299301
// If the pool op is Average, the zero-padding needs to be removed from the results
300302
if (CONFIG_T::pool_op == Average) {
301-
data_T rescale = CONFIG_T::pool_height * CONFIG_T::pool_width / img_overlap;
303+
data_T rescale =
304+
static_cast<data_T>(CONFIG_T::pool_height) * static_cast<data_T>(CONFIG_T::pool_width) / img_overlap;
302305
res[(ii / CONFIG_T::stride_height) * CONFIG_T::out_width + (jj / CONFIG_T::stride_width) +
303306
ff * CONFIG_T::out_height * CONFIG_T::out_width] *= rescale;
304307
}

test/pytest/test_pooling.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
from pathlib import Path
2+
3+
import numpy as np
4+
import pytest
5+
from tensorflow.keras.layers import AveragePooling1D, AveragePooling2D, MaxPooling1D, MaxPooling2D
6+
from tensorflow.keras.models import Sequential
7+
8+
import hls4ml
9+
10+
test_root_path = Path(__file__).parent
11+
12+
in_shape = 124
13+
in_filt = 5
14+
atol = 5e-3
15+
16+
17+
@pytest.fixture(scope='module')
18+
def data_1d():
19+
return np.random.rand(100, in_shape, in_filt)
20+
21+
22+
@pytest.fixture(scope='module')
23+
def keras_model_1d(request):
24+
model_type = request.param['model_type']
25+
pads = request.param['padding']
26+
model = Sequential()
27+
if model_type == 'avg':
28+
model.add(AveragePooling1D(pool_size=3, input_shape=(in_shape, in_filt), padding=pads))
29+
elif model_type == 'max':
30+
model.add(MaxPooling1D(pool_size=3, input_shape=(in_shape, in_filt), padding=pads))
31+
model.compile()
32+
return model, model_type, pads
33+
34+
35+
@pytest.mark.parametrize('backend', ['Quartus', 'Vitis', 'Vivado'])
36+
@pytest.mark.parametrize(
37+
'keras_model_1d',
38+
[
39+
{'model_type': 'max', 'padding': 'valid'},
40+
{'model_type': 'max', 'padding': 'same'},
41+
{'model_type': 'avg', 'padding': 'valid'},
42+
{'model_type': 'avg', 'padding': 'same'},
43+
],
44+
ids=[
45+
'model_type-max-padding-valid',
46+
'model_type-max-padding-same',
47+
'model_type-avg-padding-valid',
48+
'model_type-avg-padding-same',
49+
],
50+
indirect=True,
51+
)
52+
@pytest.mark.parametrize('io_type', ['io_parallel'])
53+
def test_pool1d(backend, keras_model_1d, data_1d, io_type):
54+
55+
model, model_type, padding = keras_model_1d
56+
57+
config = hls4ml.utils.config_from_keras_model(model, default_precision='ap_fixed<32,9>', granularity='name')
58+
59+
hls_model = hls4ml.converters.convert_from_keras_model(
60+
model,
61+
hls_config=config,
62+
io_type=io_type,
63+
output_dir=str(test_root_path / f'hls4mlprj_globalplool1d_{backend}_{io_type}_{model_type}_padding_{padding}'),
64+
backend=backend,
65+
)
66+
hls_model.compile()
67+
68+
y_keras = model.predict(data_1d)
69+
y_hls = hls_model.predict(data_1d).reshape(y_keras.shape)
70+
np.testing.assert_allclose(y_keras, y_hls, rtol=0, atol=atol, verbose=True)
71+
72+
73+
@pytest.fixture(scope='module')
74+
def data_2d():
75+
return np.random.rand(100, in_shape, in_shape, in_filt)
76+
77+
78+
@pytest.fixture(scope='module')
79+
def keras_model_2d(request):
80+
model_type = request.param['model_type']
81+
pads = request.param['padding']
82+
model = Sequential()
83+
if model_type == 'avg':
84+
model.add(AveragePooling2D(input_shape=(in_shape, in_shape, in_filt), padding=pads))
85+
elif model_type == 'max':
86+
model.add(MaxPooling2D(input_shape=(in_shape, in_shape, in_filt), padding=pads))
87+
model.compile()
88+
return model, model_type, pads
89+
90+
91+
@pytest.mark.parametrize('backend', ['Quartus', 'Vitis', 'Vivado'])
92+
@pytest.mark.parametrize(
93+
'keras_model_2d',
94+
[
95+
{'model_type': 'max', 'padding': 'valid'},
96+
{'model_type': 'max', 'padding': 'same'},
97+
{'model_type': 'avg', 'padding': 'valid'},
98+
{'model_type': 'avg', 'padding': 'same'},
99+
],
100+
ids=[
101+
'model_type-max-padding-valid',
102+
'model_type-max-padding-same',
103+
'model_type-avg-padding-valid',
104+
'model_type-avg-padding-same',
105+
],
106+
indirect=True,
107+
)
108+
@pytest.mark.parametrize('io_type', ['io_parallel'])
109+
def test_pool2d(backend, keras_model_2d, data_2d, io_type):
110+
111+
model, model_type, padding = keras_model_2d
112+
113+
config = hls4ml.utils.config_from_keras_model(model, default_precision='ap_fixed<32,9>', granularity='name')
114+
115+
hls_model = hls4ml.converters.convert_from_keras_model(
116+
model,
117+
hls_config=config,
118+
io_type=io_type,
119+
output_dir=str(test_root_path / f'hls4mlprj_globalplool2d_{backend}_{io_type}_{model_type}_padding_{padding}'),
120+
backend=backend,
121+
)
122+
hls_model.compile()
123+
124+
y_keras = model.predict(data_2d)
125+
y_hls = hls_model.predict(data_2d).reshape(y_keras.shape)
126+
np.testing.assert_allclose(y_keras, y_hls, rtol=0, atol=atol, verbose=True)

0 commit comments

Comments
 (0)