Skip to content

Commit 856e778

Browse files
authored
CNNs with binary inputs and weights need fixes (#749)
* fix cast in remaining places for binary CNNs * add pytest for binary cnn * attempted fix for streaming normalize_binary_tanh and normalize_ternary_tanh * make all compile, though test differences are still too large * update pytest, disable comparison for now * remove setting of precision in max pool * specify the full path out test output
1 parent 1002f3e commit 856e778

File tree

10 files changed

+122
-36
lines changed

10 files changed

+122
-36
lines changed

hls4ml/backends/fpga/passes/bn_quant.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
batchnorm_quantized_tanh_config_template = """struct config{index} : nnet::batchnorm_quantized_tanh_config {{
1010
static const unsigned n_in = {n_in};
1111
static const unsigned n_filt = {n_filt};
12+
static const unsigned n_scale_bias = (n_filt == -1) ? n_in : n_filt;
1213
static const unsigned io_type = nnet::{iotype};
1314
static const unsigned reuse_factor = {reuse};
1415
}};\n"""

hls4ml/templates/quartus/firmware/nnet_utils/nnet_batchnorm.h

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ struct batchnorm_config {
1515
// Layer Sizes
1616
static const unsigned n_in = 10;
1717
static const unsigned n_filt = -1;
18+
static const unsigned n_scale_bias = 10;
1819

1920
// Resource reuse info
2021
static const unsigned io_type = io_parallel;
@@ -29,8 +30,8 @@ struct batchnorm_config {
2930

3031
template <class data_T, class res_T, typename CONFIG_T>
3132
void normalize(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in],
32-
const typename CONFIG_T::scale_t scale[CONFIG_T::n_in],
33-
const typename CONFIG_T::bias_t bias[CONFIG_T::n_in]) {
33+
const typename CONFIG_T::scale_t scale[CONFIG_T::n_scale_bias],
34+
const typename CONFIG_T::bias_t bias[CONFIG_T::n_scale_bias]) {
3435
// Calcuate result
3536
Result:
3637
#pragma unroll
@@ -54,6 +55,7 @@ struct batchnorm_quantized_tanh_config {
5455
// Layer Sizes
5556
static const unsigned n_in = 10;
5657
static const unsigned n_filt = -1;
58+
static const unsigned n_scale_bias = 10;
5759

5860
// Resource reuse info
5961
static const unsigned io_type = io_parallel;
@@ -63,34 +65,37 @@ struct batchnorm_quantized_tanh_config {
6365

6466
template <class data_T, typename CONFIG_T>
6567
void normalize_binary_tanh(data_T data[CONFIG_T::n_in], ac_int<1, false> res[CONFIG_T::n_in],
66-
const data_T threshold[CONFIG_T::n_in]) {
68+
const data_T threshold[CONFIG_T::n_scale_bias]) {
6769
#pragma unroll
6870
for (int ii = 0; ii < CONFIG_T::n_in; ii++) {
6971
ac_int<1, false> cache;
7072
data_T datareg = data[ii];
71-
if (datareg > threshold[ii])
73+
int norm_index = CONFIG_T::n_filt == -1 ? ii : ii % CONFIG_T::n_filt;
74+
if (datareg > threshold[norm_index])
7275
cache = 1;
7376
else
7477
cache = 0;
7578

76-
res[ii] = (ac_int<1, false>)cache;
79+
res[ii] = cache;
7780
}
7881
}
7982

8083
template <class data_T, typename CONFIG_T>
8184
void normalize_ternary_tanh(data_T data[CONFIG_T::n_in], ac_int<2, true> res[CONFIG_T::n_in],
82-
const data_T threshold_hi[CONFIG_T::n_in], const data_T threshold_lo[CONFIG_T::n_in]) {
85+
const data_T threshold_hi[CONFIG_T::n_scale_bias],
86+
const data_T threshold_lo[CONFIG_T::n_scale_bias]) {
8387
#pragma unroll
8488
for (int ii = 0; ii < CONFIG_T::n_in; ii++) {
8589
ac_int<2, true> cache;
8690
data_T datareg = data[ii];
87-
if (datareg > threshold_hi[ii])
91+
int norm_index = CONFIG_T::n_filt == -1 ? ii : ii % CONFIG_T::n_filt;
92+
if (datareg > threshold_hi[norm_index])
8893
cache = 1;
89-
else if (datareg <= threshold_lo[ii])
94+
else if (datareg <= threshold_lo[norm_index])
9095
cache = -1;
9196
else
9297
cache = 0;
93-
res[ii] = (ac_int<2, true>)cache;
98+
res[ii] = cache;
9499
}
95100
}
96101

hls4ml/templates/quartus/firmware/nnet_utils/nnet_batchnorm_stream.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ namespace nnet {
1212
// Streaming Batch Normalization
1313
// ****************************************************
1414
template <class data_T, class res_T, typename CONFIG_T>
15-
void normalize(stream<data_T> &data, stream<res_T> &res, const typename CONFIG_T::scale_t scale[CONFIG_T::n_in],
16-
const typename CONFIG_T::bias_t bias[CONFIG_T::n_in]) {
15+
void normalize(stream<data_T> &data, stream<res_T> &res, const typename CONFIG_T::scale_t scale[CONFIG_T::n_scale_bias],
16+
const typename CONFIG_T::bias_t bias[CONFIG_T::n_scale_bias]) {
1717

1818
constexpr unsigned multiplier_limit = DIV_ROUNDUP(CONFIG_T::n_in, CONFIG_T::reuse_factor);
1919
constexpr unsigned pipeline = CONFIG_T::n_in / multiplier_limit;
@@ -46,14 +46,14 @@ void normalize(stream<data_T> &data, stream<res_T> &res, const typename CONFIG_T
4646
// Merged Batch Normalization and Quantized Tanh
4747
// ****************************************************
4848
template <class data_T, typename CONFIG_T>
49-
void normalize_binary_tanh(stream<data_T> &data, stream<nnet::array<ac_int<1, false>, CONFIG_T::n_in>> &res,
50-
const typename data_T::value_type threshold[CONFIG_T::n_in]) {
49+
void normalize_binary_tanh(stream<data_T> &data, stream<nnet::array<ac_int<1, false>, CONFIG_T::n_scale_bias>> &res,
50+
const typename data_T::value_type threshold[CONFIG_T::n_scale_bias]) {
5151

5252
BinaryNormLoop:
5353
#pragma ii 1
5454
for (int i = 0; i < CONFIG_T::n_in / data_T::size; i++) {
5555
data_T in_data = data.read();
56-
nnet::array<ac_int<1, false>, CONFIG_T::n_in> out_data;
56+
nnet::array<ac_int<1, false>, CONFIG_T::n_scale_bias> out_data;
5757

5858
BatchNormPack:
5959
#pragma unroll
@@ -66,15 +66,15 @@ void normalize_binary_tanh(stream<data_T> &data, stream<nnet::array<ac_int<1, fa
6666
}
6767

6868
template <class data_T, typename CONFIG_T>
69-
void normalize_ternary_tanh(stream<data_T> &data, stream<nnet::array<ac_int<2, true>, CONFIG_T::n_in>> &res,
70-
const typename data_T::value_type threshold_hi[CONFIG_T::n_in],
71-
const typename data_T::value_type threshold_lo[CONFIG_T::n_in]) {
69+
void normalize_ternary_tanh(stream<data_T> &data, stream<nnet::array<ac_int<2, true>, CONFIG_T::n_scale_bias>> &res,
70+
const typename data_T::value_type threshold_hi[CONFIG_T::n_scale_bias],
71+
const typename data_T::value_type threshold_lo[CONFIG_T::n_scale_bias]) {
7272

7373
TernaryNormLoop:
7474
#pragma ii 1
7575
for (int i = 0; i < CONFIG_T::n_in / data_T::size; i++) {
7676
data_T in_data = data.read();
77-
nnet::array<ac_int<2, true>, CONFIG_T::n_in> out_data;
77+
nnet::array<ac_int<2, true>, CONFIG_T::n_scale_bias> out_data;
7878

7979
BatchNormPack:
8080
#pragma unroll

hls4ml/templates/quartus/firmware/nnet_utils/nnet_mult.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,26 +89,27 @@ template <class x_T, class w_T> class weight_exponential : public Product {
8989
};
9090
} // namespace product
9191

92+
// TO-DO: These may need extra variants if ac_int types are used in more places
9293
template <class data_T, class res_T, typename CONFIG_T>
9394
inline typename std::enable_if<std::is_same<data_T, ac_int<1, false>>::value &&
9495
std::is_same<typename CONFIG_T::weight_t, ac_int<1, false>>::value,
9596
ac_int<nnet::ceillog2(CONFIG_T::n_in) + 2, true>>::type
9697
cast(typename CONFIG_T::accum_t x) {
97-
return (ac_int<nnet::ceillog2(CONFIG_T::n_in) + 2, true>)(x - CONFIG_T::n_in / 2) * 2;
98+
return static_cast<ac_int<nnet::ceillog2(CONFIG_T::n_in) + 2, true>>(((x - CONFIG_T::n_in / 2) * 2).to_ac_int());
9899
}
99100

100101
template <class data_T, class res_T, typename CONFIG_T>
101102
inline typename std::enable_if<std::is_same<data_T, ac_int<1, false>>::value &&
102103
!std::is_same<typename CONFIG_T::weight_t, ac_int<1, false>>::value,
103104
res_T>::type
104105
cast(typename CONFIG_T::accum_t x) {
105-
return (res_T)x;
106+
return static_cast<res_T>(x);
106107
}
107108

108109
template <class data_T, class res_T, typename CONFIG_T>
109110
inline typename std::enable_if<(!std::is_same<data_T, ac_int<1, false>>::value), res_T>::type
110111
cast(typename CONFIG_T::accum_t x) {
111-
return (res_T)x;
112+
return static_cast<res_T>(x);
112113
}
113114

114115
} // namespace nnet

hls4ml/templates/vivado/nnet_utils/nnet_batchnorm.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ struct batchnorm_quantized_tanh_config {
6969
// Layer Sizes
7070
static const unsigned n_in = 10;
7171
static const unsigned n_filt = -1;
72+
static const unsigned n_scale_bias = 10;
7273

7374
// Resource reuse info
7475
static const unsigned io_type = io_parallel;
@@ -77,7 +78,8 @@ struct batchnorm_quantized_tanh_config {
7778
};
7879

7980
template <class data_T, typename CONFIG_T>
80-
void normalize_binary_tanh(data_T data[CONFIG_T::n_in], ap_uint<1> res[CONFIG_T::n_in], data_T threshold[CONFIG_T::n_in]) {
81+
void normalize_binary_tanh(data_T data[CONFIG_T::n_in], ap_uint<1> res[CONFIG_T::n_in],
82+
data_T threshold[CONFIG_T::n_scale_bias]) {
8183
#pragma HLS PIPELINE
8284
#pragma HLS ARRAY_PARTITION variable=res complete
8385

@@ -91,13 +93,13 @@ void normalize_binary_tanh(data_T data[CONFIG_T::n_in], ap_uint<1> res[CONFIG_T:
9193
else
9294
cache = 0;
9395

94-
res[ii] = (ap_uint<1>)cache;
96+
res[ii] = cache;
9597
}
9698
}
9799

98100
template <class data_T, typename CONFIG_T>
99-
void normalize_ternary_tanh(data_T data[CONFIG_T::n_in], ap_int<2> res[CONFIG_T::n_in], data_T threshold_hi[CONFIG_T::n_in],
100-
data_T threshold_lo[CONFIG_T::n_in]) {
101+
void normalize_ternary_tanh(data_T data[CONFIG_T::n_in], ap_int<2> res[CONFIG_T::n_in],
102+
data_T threshold_hi[CONFIG_T::n_scale_bias], data_T threshold_lo[CONFIG_T::n_scale_bias]) {
101103
#pragma HLS PIPELINE
102104
#pragma HLS ARRAY_PARTITION variable=res complete
103105

@@ -113,7 +115,7 @@ void normalize_ternary_tanh(data_T data[CONFIG_T::n_in], ap_int<2> res[CONFIG_T:
113115
else
114116
cache = 0;
115117

116-
res[ii] = (ap_int<2>)cache;
118+
res[ii] = cache;
117119
}
118120
}
119121

hls4ml/templates/vivado/nnet_utils/nnet_batchnorm_stream.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,16 @@ void normalize(hls::stream<data_T> &data, hls::stream<res_T> &res, typename CONF
5151
// Merged Batch Normalization and Quantized Tanh
5252
// ****************************************************
5353
template <class data_T, typename CONFIG_T>
54-
void normalize_binary_tanh(hls::stream<data_T> &data, hls::stream<nnet::array<ap_uint<1>, CONFIG_T::n_in>> &res,
55-
typename data_T::value_type threshold[CONFIG_T::n_in]) {
54+
void normalize_binary_tanh(hls::stream<data_T> &data, hls::stream<nnet::array<ap_uint<1>, CONFIG_T::n_scale_bias>> &res,
55+
typename data_T::value_type threshold[CONFIG_T::n_scale_bias]) {
5656
#pragma HLS ARRAY_PARTITION variable=threshold complete
5757

5858
BinaryNormLoop:
5959
for (int i = 0; i < CONFIG_T::n_in / data_T::size; i++) {
6060
#pragma HLS PIPELINE
6161

6262
data_T in_data = data.read();
63-
nnet::array<ap_uint<1>, CONFIG_T::n_in> out_data;
63+
nnet::array<ap_uint<1>, CONFIG_T::n_scale_bias> out_data;
6464
PRAGMA_DATA_PACK(out_data)
6565

6666
BatchNormPack:
@@ -74,9 +74,9 @@ void normalize_binary_tanh(hls::stream<data_T> &data, hls::stream<nnet::array<ap
7474
}
7575

7676
template <class data_T, typename CONFIG_T>
77-
void normalize_ternary_tanh(hls::stream<data_T> &data, hls::stream<nnet::array<ap_int<2>, CONFIG_T::n_in>> &res,
78-
typename data_T::value_type threshold_hi[CONFIG_T::n_in],
79-
typename data_T::value_type threshold_lo[CONFIG_T::n_in]) {
77+
void normalize_ternary_tanh(hls::stream<data_T> &data, hls::stream<nnet::array<ap_int<2>, CONFIG_T::n_scale_bias>> &res,
78+
typename data_T::value_type threshold_hi[CONFIG_T::n_scale_bias],
79+
typename data_T::value_type threshold_lo[CONFIG_T::n_scale_bias]) {
8080
#pragma HLS ARRAY_PARTITION variable=threshold_hi complete
8181
#pragma HLS ARRAY_PARTITION variable=threshold_lo complete
8282

@@ -85,7 +85,7 @@ void normalize_ternary_tanh(hls::stream<data_T> &data, hls::stream<nnet::array<a
8585
#pragma HLS PIPELINE
8686

8787
data_T in_data = data.read();
88-
nnet::array<ap_int<2>, CONFIG_T::n_in> out_data;
88+
nnet::array<ap_int<2>, CONFIG_T::n_scale_bias> out_data;
8989
PRAGMA_DATA_PACK(out_data)
9090

9191
BatchNormPack:

hls4ml/templates/vivado/nnet_utils/nnet_conv1d_latency.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ void conv_1d_latency_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],
7272
// Cast to "res_t" type
7373
Result:
7474
for (int i_res = 0; i_res < mult_n_out; i_res++) {
75-
*(res++) = cast<data_T, res_T, CONFIG_T>(acc[i_res]);
75+
*(res++) = cast<data_T, res_T, typename CONFIG_T::mult_config>(acc[i_res]);
7676
}
7777
}
7878
}

hls4ml/templates/vivado/nnet_utils/nnet_conv2d_latency.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ void conv_2d_latency_cl(
7373
// Cast to "res_t" type
7474
Result:
7575
for (int i_res = 0; i_res < mult_n_out; i_res++) {
76-
*(res++) = cast<data_T, res_T, CONFIG_T>(acc[i_res]);
76+
*(res++) = cast<data_T, res_T, typename CONFIG_T::mult_config>(acc[i_res]);
7777
}
7878
}
7979
}

hls4ml/templates/vivado/nnet_utils/nnet_sepconv_stream.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ void depthwise_product(data_T data[CONFIG_T::kernel_size * CONFIG_T::n_chan], re
5454
Result:
5555
for (int ires = 0; ires < CONFIG_T::n_chan; ires++) {
5656
#pragma HLS UNROLL
57-
res[ires] = cast<data_T, res_T, CONFIG_T>(acc[ires]);
57+
res[ires] = cast<data_T, res_T, typename CONFIG_T::mult_config>(acc[ires]);
5858
}
5959
}
6060

test/pytest/test_binary_cnn.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from pathlib import Path
2+
3+
import numpy as np
4+
import pytest
5+
from qkeras import QActivation, QBatchNormalization, QConv2D, QDense
6+
from tensorflow.keras.layers import Flatten, Input, MaxPooling2D
7+
from tensorflow.keras.models import Model
8+
from tensorflow.keras.regularizers import l2
9+
10+
import hls4ml
11+
12+
test_root_path = Path(__file__).parent
13+
14+
15+
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])
16+
@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
17+
def test_model2(backend, io_type):
18+
x_in = Input(shape=(28, 28, 1))
19+
20+
x = QConv2D(4, (3, 3), kernel_quantizer="binary", name="conv2d_1", kernel_regularizer=l2(0.0001), use_bias=False)(x_in)
21+
x = QBatchNormalization()(x)
22+
x = QActivation("binary", name="act1")(x)
23+
24+
x = QConv2D(8, (3, 3), kernel_quantizer="binary", name="conv2d_2", kernel_regularizer=l2(0.0001), use_bias=False)(x)
25+
x = QBatchNormalization()(x)
26+
x = QActivation("binary", name="act2")(x)
27+
x = MaxPooling2D(pool_size=(2, 2))(x)
28+
29+
x = QConv2D(8, (3, 3), kernel_quantizer="binary", name="conv2d_3", kernel_regularizer=l2(0.0001), use_bias=False)(x)
30+
x = QBatchNormalization()(x)
31+
x = QActivation("binary", name="act3")(x)
32+
x = MaxPooling2D(pool_size=(2, 2))(x)
33+
34+
x = Flatten()(x)
35+
36+
x = QDense(10, kernel_quantizer="binary", name="q_dense_6", use_bias=False)(x)
37+
x = QBatchNormalization()(x)
38+
x = QActivation("binary_tanh", name="act4")(x)
39+
40+
x = QDense(10, kernel_quantizer="binary", activation="softmax", name="q_dense_7", use_bias=False)(x)
41+
42+
model2 = Model(inputs=x_in, outputs=x)
43+
44+
model2.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
45+
46+
model2.summary()
47+
48+
hls_config = hls4ml.utils.config_from_keras_model(model2, granularity="name")
49+
hls_config["Model"]["Strategy"] = "Resource"
50+
51+
print(f"{hls_config['LayerName'].keys()=}")
52+
for layer in hls_config['LayerName'].keys():
53+
hls_config['LayerName'][layer]['Strategy'] = "Latency"
54+
55+
hls_config["LayerName"]["conv2d_1"]["ReuseFactor"] = 36
56+
hls_config["LayerName"]["conv2d_2"]["ReuseFactor"] = 288
57+
hls_config["LayerName"]["conv2d_3"]["ReuseFactor"] = 576
58+
hls_config["LayerName"]["q_dense_6"]["ReuseFactor"] = 2000
59+
hls_config["LayerName"]["q_dense_7"]["ReuseFactor"] = 100
60+
61+
output_dir = str(test_root_path / f"hls4mlprj_binary_cnn_{backend}_{io_type}")
62+
hls_model = hls4ml.converters.convert_from_keras_model(
63+
model2,
64+
hls_config=hls_config,
65+
output_dir=output_dir,
66+
backend=backend,
67+
io_type=io_type,
68+
)
69+
70+
X = np.random.rand(1, 28, 28, 1)
71+
72+
hls_model.compile()
73+
y = model2.predict(X) # noqa: F841
74+
y_hls = hls_model.predict(X) # noqa: F841
75+
76+
# # TODO: enable the comparions after fixing the remaing issues
77+
# np.testing.assert_allclose(np.squeeze(y_hls), np.squeeze(y), rtol=1e-2, atol=0.01)

0 commit comments

Comments
 (0)