Skip to content

Commit d035738

Browse files
qberthetQuentin Berthetvloncar
authored
Fix config structure name in pragma for SeparableConv1D (#884)
* Raise exception if Vivado command fail * Duplicate sepconv2d test for sepconv1d * Test that csynth is working for sepconv1d * Define multiplier_limit in nnet::conv1d_config (for sepconv1d) * Revert build test --------- Co-authored-by: Quentin Berthet <[email protected]> Co-authored-by: Vladimir Loncar <[email protected]>
1 parent 568f97f commit d035738

File tree

2 files changed

+65
-0
lines changed

2 files changed

+65
-0
lines changed

hls4ml/backends/vivado/passes/convolution_templates.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
static const unsigned out_width = {out_width};
4242
static const unsigned reuse_factor = {reuse};
4343
static const unsigned n_zeros = {nzeros};
44+
static const unsigned multiplier_limit =
45+
DIV_ROUNDUP(kernel_size * n_chan * n_filt, reuse_factor) - n_zeros / reuse_factor;
4446
static const bool store_weights_in_bram = false;
4547
static const unsigned strategy = nnet::{strategy};
4648
static const nnet::conv_implementation implementation = nnet::conv_implementation::{implementation};

test/pytest/test_sepconv1d.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from pathlib import Path
2+
3+
import numpy as np
4+
import pytest
5+
import tensorflow as tf
6+
from tensorflow.keras.layers import SeparableConv1D
7+
8+
import hls4ml
9+
10+
test_root_path = Path(__file__).parent
11+
12+
keras_conv1d = [SeparableConv1D]
13+
padds_options = ['same', 'valid']
14+
chans_options = ['channels_last']
15+
io_type_options = ['io_stream']
16+
strides_options = [(1), (2)]
17+
kernel_options = [(1), (3)]
18+
bias_options = [False]
19+
20+
21+
@pytest.mark.parametrize('conv1d', keras_conv1d)
22+
@pytest.mark.parametrize('chans', chans_options)
23+
@pytest.mark.parametrize('padds', padds_options)
24+
@pytest.mark.parametrize('strides', strides_options)
25+
@pytest.mark.parametrize('kernels', kernel_options)
26+
@pytest.mark.parametrize('bias', bias_options)
27+
@pytest.mark.parametrize('io_type', io_type_options)
28+
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis'])
29+
def test_sepconv1d(conv1d, chans, padds, strides, kernels, bias, io_type, backend):
30+
model = tf.keras.models.Sequential()
31+
input_shape = (28, 3)
32+
model.add(
33+
conv1d(
34+
filters=32,
35+
kernel_size=kernels,
36+
strides=strides,
37+
padding=padds,
38+
input_shape=input_shape,
39+
kernel_initializer='normal',
40+
use_bias=bias,
41+
data_format=chans,
42+
)
43+
)
44+
45+
model.compile(optimizer='adam', loss='mse')
46+
X_input = np.random.rand(100, *input_shape)
47+
keras_prediction = model.predict(X_input)
48+
config = hls4ml.utils.config_from_keras_model(model, default_precision='ap_fixed<32,16>')
49+
stride_cfg = str(strides).replace(', ', '_').replace('(', '').replace(')', '')
50+
kernel_cfg = str(kernels).replace(', ', '_').replace('(', '').replace(')', '')
51+
output_dir = str(
52+
test_root_path
53+
/ 'hls4mlprj_{}_{}_strides_{}_kernels_{}_{}_padding_{}_{}'.format(
54+
conv1d.__name__.lower(), chans, stride_cfg, kernel_cfg, padds, backend, io_type
55+
)
56+
)
57+
hls_model = hls4ml.converters.convert_from_keras_model(
58+
model, hls_config=config, output_dir=output_dir, io_type=io_type, backend=backend
59+
)
60+
hls_model.compile()
61+
hls_prediction = hls_model.predict(X_input).reshape(keras_prediction.shape)
62+
63+
np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=0.001)

0 commit comments

Comments
 (0)