Skip to content

Commit 7186346

Browse files
committed
attempted fix for streaming normalize_binary_tanh and normalize_ternary_tanh
1 parent d3b3daa commit 7186346

File tree

3 files changed

+8
-14
lines changed

3 files changed

+8
-14
lines changed

hls4ml/backends/fpga/passes/bn_quant.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
batchnorm_quantized_tanh_config_template = """struct config{index} : nnet::batchnorm_quantized_tanh_config {{
1010
static const unsigned n_in = {n_in};
1111
static const unsigned n_filt = {n_filt};
12+
static const unsigned n_scale_bias = (n_filt == -1) ? n_in : n_filt;
1213
static const unsigned io_type = nnet::{iotype};
1314
static const unsigned reuse_factor = {reuse};
1415
}};\n"""

hls4ml/templates/vivado/nnet_utils/nnet_batchnorm_stream.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,16 @@ void normalize(hls::stream<data_T> &data, hls::stream<res_T> &res, typename CONF
5151
// Merged Batch Normalization and Quantized Tanh
5252
// ****************************************************
5353
template <class data_T, typename CONFIG_T>
54-
void normalize_binary_tanh(hls::stream<data_T> &data, hls::stream<nnet::array<ap_uint<1>, CONFIG_T::n_in>> &res,
55-
typename data_T::value_type threshold[CONFIG_T::n_in]) {
54+
void normalize_binary_tanh(hls::stream<data_T> &data, hls::stream<nnet::array<ap_uint<1>, CONFIG_T::n_scale_bias>> &res,
55+
typename data_T::value_type threshold[CONFIG_T::n_scale_bias]) {
5656
#pragma HLS ARRAY_PARTITION variable=threshold complete
5757

5858
BinaryNormLoop:
5959
for (int i = 0; i < CONFIG_T::n_in / data_T::size; i++) {
6060
#pragma HLS PIPELINE
6161

6262
data_T in_data = data.read();
63-
nnet::array<ap_uint<1>, CONFIG_T::n_in> out_data;
63+
nnet::array<ap_uint<1>, CONFIG_T::n_scale_bias> out_data;
6464
PRAGMA_DATA_PACK(out_data)
6565

6666
BatchNormPack:
@@ -74,9 +74,9 @@ void normalize_binary_tanh(hls::stream<data_T> &data, hls::stream<nnet::array<ap
7474
}
7575

7676
template <class data_T, typename CONFIG_T>
77-
void normalize_ternary_tanh(hls::stream<data_T> &data, hls::stream<nnet::array<ap_int<2>, CONFIG_T::n_in>> &res,
78-
typename data_T::value_type threshold_hi[CONFIG_T::n_in],
79-
typename data_T::value_type threshold_lo[CONFIG_T::n_in]) {
77+
void normalize_ternary_tanh(hls::stream<data_T> &data, hls::stream<nnet::array<ap_int<2>, CONFIG_T::n_scale_bias>> &res,
78+
typename data_T::value_type threshold_hi[CONFIG_T::n_scale_bias],
79+
typename data_T::value_type threshold_lo[CONFIG_T::n_scale_bias]) {
8080
#pragma HLS ARRAY_PARTITION variable=threshold_hi complete
8181
#pragma HLS ARRAY_PARTITION variable=threshold_lo complete
8282

@@ -85,7 +85,7 @@ void normalize_ternary_tanh(hls::stream<data_T> &data, hls::stream<nnet::array<a
8585
#pragma HLS PIPELINE
8686

8787
data_T in_data = data.read();
88-
nnet::array<ap_int<2>, CONFIG_T::n_in> out_data;
88+
nnet::array<ap_int<2>, CONFIG_T::n_scale_bias> out_data;
8989
PRAGMA_DATA_PACK(out_data)
9090

9191
BatchNormPack:

test/pytest/test_binary_cnn.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,4 @@ def test_model2(backend, io_type):
9090
y = model2.predict(np.zeros((1, 28, 28, 1)))
9191
y_hls = hls_model.predict(np.zeros((1, 28, 28, 1)))
9292

93-
print(f"{y_hls=}")
94-
print(f"{y=}")
95-
9693
np.testing.assert_allclose(np.squeeze(y_hls), np.squeeze(y), rtol=1e-2, atol=0.01)
97-
98-
99-
if __name__ == "__main__":
100-
test_model2("Vivado", "io_stream")

0 commit comments

Comments
 (0)