Skip to content

Fix pooling layers when padding is applied from the left/top #757

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ sphinx>=3.2.1
sphinx_contributors
sphinx_github_changelog
sphinx_rtd_theme
tensorflow
toposort>=1.5.0
12 changes: 7 additions & 5 deletions hls4ml/templates/quartus/firmware/nnet_utils/nnet_pooling.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ void pooling1d_cl(data_T data[CONFIG_T::n_in * CONFIG_T::n_filt], res_T res[CONF
pool[pool_col] = pad_val<data_T, CONFIG_T::pool_op>();
} else {
// Current element is from input image
pool[pool_col] = data[(inp_col + pool_col) * CONFIG_T::n_filt + filt];
pool[pool_col] = data[(inp_col + pool_col - CONFIG_T::pad_left) * CONFIG_T::n_filt + filt];
img_overlap++;
}
}
Expand All @@ -160,7 +160,8 @@ void pooling1d_cl(data_T data[CONFIG_T::n_in * CONFIG_T::n_filt], res_T res[CONF

// If the pool op is Average, the zero-padding needs to be removed from the results
if (CONFIG_T::pool_op == Average)
res[(inp_col / CONFIG_T::stride_width) * CONFIG_T::n_filt + filt] *= (CONFIG_T::pool_width / img_overlap);
res[(inp_col / CONFIG_T::stride_width) * CONFIG_T::n_filt + filt] *=
(static_cast<data_T>(CONFIG_T::pool_width) / img_overlap);
}
}
}
Expand Down Expand Up @@ -258,8 +259,8 @@ void pooling2d_cl(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_
} else {
// Current element is from input image
pool[pool_col * CONFIG_T::stride_width + pool_row] =
data[(inp_col + pool_col) * CONFIG_T::in_width * CONFIG_T::n_filt +
(inp_width + pool_row) * CONFIG_T::n_filt + filt];
data[(inp_col + pool_col - CONFIG_T::pad_top) * CONFIG_T::in_width * CONFIG_T::n_filt +
(inp_width + pool_row - CONFIG_T::pad_left) * CONFIG_T::n_filt + filt];
img_overlap++;
}
}
Expand All @@ -275,7 +276,8 @@ void pooling2d_cl(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_
if (CONFIG_T::pool_op == Average)
res[(inp_col / CONFIG_T::stride_height) * CONFIG_T::out_width * CONFIG_T::n_filt +
(inp_width / CONFIG_T::stride_width) * CONFIG_T::n_filt + filt] *=
(CONFIG_T::pool_height * CONFIG_T::pool_width / img_overlap);
(static_cast<data_T>(CONFIG_T::pool_height) * static_cast<data_T>(CONFIG_T::pool_width) /
img_overlap);
}
}
}
Expand Down
13 changes: 8 additions & 5 deletions hls4ml/templates/vitis/nnet_utils/nnet_pooling.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ void pooling1d_cl(data_T data[CONFIG_T::n_in * CONFIG_T::n_filt], res_T res[CONF
// Add padding
pool[jj] = pad_val<data_T, CONFIG_T::pool_op>();
} else {
pool[jj] = data[(ii + jj) * CONFIG_T::n_filt + ff];
pool[jj] = data[(ii + jj - CONFIG_T::pad_left) * CONFIG_T::n_filt + ff];
img_overlap++;
}
}
Expand All @@ -134,7 +134,7 @@ void pooling1d_cl(data_T data[CONFIG_T::n_in * CONFIG_T::n_filt], res_T res[CONF
pool_op<data_T, CONFIG_T::pool_width, CONFIG_T::pool_op>(pool);
// If the pool op is Average, the zero-padding needs to be removed from the results
if (CONFIG_T::pool_op == Average) {
data_T rescale = CONFIG_T::pool_width / img_overlap;
data_T rescale = static_cast<data_T>(CONFIG_T::pool_width) / img_overlap;
res[(ii / CONFIG_T::stride_width) * CONFIG_T::n_filt + ff] *= rescale;
}
}
Expand Down Expand Up @@ -226,7 +226,8 @@ void pooling2d_cl(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_
pool[kk * CONFIG_T::stride_width + ll] = pad_val<data_T, CONFIG_T::pool_op>();
} else {
pool[kk * CONFIG_T::stride_width + ll] =
data[(ii + kk) * CONFIG_T::in_width * CONFIG_T::n_filt + (jj + ll) * CONFIG_T::n_filt + ff];
data[(ii + kk - CONFIG_T::pad_top) * CONFIG_T::in_width * CONFIG_T::n_filt +
(jj + ll - CONFIG_T::pad_left) * CONFIG_T::n_filt + ff];
img_overlap++;
}
}
Expand All @@ -239,7 +240,8 @@ void pooling2d_cl(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_
pool_op<data_T, CONFIG_T::pool_height * CONFIG_T::pool_width, CONFIG_T::pool_op>(pool);
// If the pool op is Average, the zero-padding needs to be removed from the results
if (CONFIG_T::pool_op == Average) {
data_T rescale = CONFIG_T::pool_height * CONFIG_T::pool_width / img_overlap;
data_T rescale =
static_cast<data_T>(CONFIG_T::pool_height) * static_cast<data_T>(CONFIG_T::pool_width) / img_overlap;
res[(ii / CONFIG_T::stride_height) * CONFIG_T::out_width * CONFIG_T::n_filt +
(jj / CONFIG_T::stride_width) * CONFIG_T::n_filt + ff] *= rescale;
}
Expand Down Expand Up @@ -297,7 +299,8 @@ void pooling2d_cf(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_
pool_op<data_T, CONFIG_T::pool_height * CONFIG_T::pool_width, CONFIG_T::pool_op>(pool);
// If the pool op is Average, the zero-padding needs to be removed from the results
if (CONFIG_T::pool_op == Average) {
data_T rescale = CONFIG_T::pool_height * CONFIG_T::pool_width / img_overlap;
data_T rescale =
static_cast<data_T>(CONFIG_T::pool_height) * static_cast<data_T>(CONFIG_T::pool_width) / img_overlap;
res[(ii / CONFIG_T::stride_height) * CONFIG_T::out_width + (jj / CONFIG_T::stride_width) +
ff * CONFIG_T::out_height * CONFIG_T::out_width] *= rescale;
}
Expand Down
17 changes: 10 additions & 7 deletions hls4ml/templates/vivado/nnet_utils/nnet_pooling.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ void pooling1d_cl(data_T data[CONFIG_T::n_in * CONFIG_T::n_filt], res_T res[CONF
// Add padding
pool[jj] = pad_val<data_T, CONFIG_T::pool_op>();
} else {
pool[jj] = data[(ii + jj) * CONFIG_T::n_filt + ff];
pool[jj] = data[(ii + jj - CONFIG_T::pad_left) * CONFIG_T::n_filt + ff];
img_overlap++;
}
}
Expand All @@ -134,7 +134,7 @@ void pooling1d_cl(data_T data[CONFIG_T::n_in * CONFIG_T::n_filt], res_T res[CONF
pool_op<data_T, CONFIG_T::pool_width, CONFIG_T::pool_op>(pool);
// If the pool op is Average, the zero-padding needs to be removed from the results
if (CONFIG_T::pool_op == Average) {
data_T rescale = CONFIG_T::pool_width / img_overlap;
data_T rescale = static_cast<data_T>(CONFIG_T::pool_width) / img_overlap;
res[(ii / CONFIG_T::stride_width) * CONFIG_T::n_filt + ff] *= rescale;
}
}
Expand Down Expand Up @@ -227,7 +227,8 @@ void pooling2d_cl(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_
pool[kk * CONFIG_T::stride_width + ll] = pad_val<data_T, CONFIG_T::pool_op>();
} else {
pool[kk * CONFIG_T::stride_width + ll] =
data[(ii + kk) * CONFIG_T::in_width * CONFIG_T::n_filt + (jj + ll) * CONFIG_T::n_filt + ff];
data[(ii + kk - CONFIG_T::pad_top) * CONFIG_T::in_width * CONFIG_T::n_filt +
(jj + ll - CONFIG_T::pad_left) * CONFIG_T::n_filt + ff];
img_overlap++;
}
}
Expand All @@ -240,7 +241,8 @@ void pooling2d_cl(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_
pool_op<data_T, CONFIG_T::pool_height * CONFIG_T::pool_width, CONFIG_T::pool_op>(pool);
// If the pool op is Average, the zero-padding needs to be removed from the results
if (CONFIG_T::pool_op == Average) {
data_T rescale = CONFIG_T::pool_height * CONFIG_T::pool_width / img_overlap;
data_T rescale =
static_cast<data_T>(CONFIG_T::pool_height) * static_cast<data_T>(CONFIG_T::pool_width) / img_overlap;
res[(ii / CONFIG_T::stride_height) * CONFIG_T::out_width * CONFIG_T::n_filt +
(jj / CONFIG_T::stride_width) * CONFIG_T::n_filt + ff] *= rescale;
}
Expand Down Expand Up @@ -284,8 +286,8 @@ void pooling2d_cf(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_
pool[kk * CONFIG_T::stride_width + ll] = pad_val<data_T, CONFIG_T::pool_op>();
} else {
pool[kk * CONFIG_T::stride_width + ll] =
data[(ii + kk) * CONFIG_T::in_width + ff * CONFIG_T::in_width * CONFIG_T::in_height + ll +
jj];
data[(ii + kk - CONFIG_T::pad_top) * CONFIG_T::in_width +
ff * CONFIG_T::in_width * CONFIG_T::in_height + ll + jj - CONFIG_T::pad_left];
img_overlap++;
}
}
Expand All @@ -298,7 +300,8 @@ void pooling2d_cf(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_
pool_op<data_T, CONFIG_T::pool_height * CONFIG_T::pool_width, CONFIG_T::pool_op>(pool);
// If the pool op is Average, the zero-padding needs to be removed from the results
if (CONFIG_T::pool_op == Average) {
data_T rescale = CONFIG_T::pool_height * CONFIG_T::pool_width / img_overlap;
data_T rescale =
static_cast<data_T>(CONFIG_T::pool_height) * static_cast<data_T>(CONFIG_T::pool_width) / img_overlap;
res[(ii / CONFIG_T::stride_height) * CONFIG_T::out_width + (jj / CONFIG_T::stride_width) +
ff * CONFIG_T::out_height * CONFIG_T::out_width] *= rescale;
}
Expand Down
126 changes: 126 additions & 0 deletions test/pytest/test_pooling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from pathlib import Path

import numpy as np
import pytest
from tensorflow.keras.layers import AveragePooling1D, AveragePooling2D, MaxPooling1D, MaxPooling2D
from tensorflow.keras.models import Sequential

import hls4ml

test_root_path = Path(__file__).parent

in_shape = 124
in_filt = 5
atol = 5e-3


@pytest.fixture(scope='module')
def data_1d():
return np.random.rand(100, in_shape, in_filt)


@pytest.fixture(scope='module')
def keras_model_1d(request):
model_type = request.param['model_type']
pads = request.param['padding']
model = Sequential()
if model_type == 'avg':
model.add(AveragePooling1D(pool_size=3, input_shape=(in_shape, in_filt), padding=pads))
elif model_type == 'max':
model.add(MaxPooling1D(pool_size=3, input_shape=(in_shape, in_filt), padding=pads))
model.compile()
return model, model_type, pads


@pytest.mark.parametrize('backend', ['Quartus', 'Vitis', 'Vivado'])
@pytest.mark.parametrize(
'keras_model_1d',
[
{'model_type': 'max', 'padding': 'valid'},
{'model_type': 'max', 'padding': 'same'},
{'model_type': 'avg', 'padding': 'valid'},
{'model_type': 'avg', 'padding': 'same'},
],
ids=[
'model_type-max-padding-valid',
'model_type-max-padding-same',
'model_type-avg-padding-valid',
'model_type-avg-padding-same',
],
indirect=True,
)
@pytest.mark.parametrize('io_type', ['io_parallel'])
def test_pool1d(backend, keras_model_1d, data_1d, io_type):

model, model_type, padding = keras_model_1d

config = hls4ml.utils.config_from_keras_model(model, default_precision='ap_fixed<32,9>', granularity='name')

hls_model = hls4ml.converters.convert_from_keras_model(
model,
hls_config=config,
io_type=io_type,
output_dir=str(test_root_path / f'hls4mlprj_globalplool1d_{backend}_{io_type}_{model_type}_padding_{padding}'),
backend=backend,
)
hls_model.compile()

y_keras = model.predict(data_1d)
y_hls = hls_model.predict(data_1d).reshape(y_keras.shape)
np.testing.assert_allclose(y_keras, y_hls, rtol=0, atol=atol, verbose=True)


@pytest.fixture(scope='module')
def data_2d():
return np.random.rand(100, in_shape, in_shape, in_filt)


@pytest.fixture(scope='module')
def keras_model_2d(request):
model_type = request.param['model_type']
pads = request.param['padding']
model = Sequential()
if model_type == 'avg':
model.add(AveragePooling2D(input_shape=(in_shape, in_shape, in_filt), padding=pads))
elif model_type == 'max':
model.add(MaxPooling2D(input_shape=(in_shape, in_shape, in_filt), padding=pads))
model.compile()
return model, model_type, pads


@pytest.mark.parametrize('backend', ['Quartus', 'Vitis', 'Vivado'])
@pytest.mark.parametrize(
'keras_model_2d',
[
{'model_type': 'max', 'padding': 'valid'},
{'model_type': 'max', 'padding': 'same'},
{'model_type': 'avg', 'padding': 'valid'},
{'model_type': 'avg', 'padding': 'same'},
],
ids=[
'model_type-max-padding-valid',
'model_type-max-padding-same',
'model_type-avg-padding-valid',
'model_type-avg-padding-same',
],
indirect=True,
)
@pytest.mark.parametrize('io_type', ['io_parallel'])
def test_pool2d(backend, keras_model_2d, data_2d, io_type):

model, model_type, padding = keras_model_2d

config = hls4ml.utils.config_from_keras_model(model, default_precision='ap_fixed<32,9>', granularity='name')

hls_model = hls4ml.converters.convert_from_keras_model(
model,
hls_config=config,
io_type=io_type,
output_dir=str(test_root_path / f'hls4mlprj_globalplool2d_{backend}_{io_type}_{model_type}_padding_{padding}'),
backend=backend,
)
hls_model.compile()

y_keras = model.predict(data_2d)
y_hls = hls_model.predict(data_2d).reshape(y_keras.shape)
np.testing.assert_allclose(y_keras, y_hls, rtol=0, atol=atol, verbose=True)