Commit cecf3e7
Make binary CNN match between Keras and hls4ml (#804)
* make normalize_binary_tanh match qkeras * make normalize_binary_tanh match qkeras (stream) * Make maxpooling copy over XnorPrecision * enable actual checking * Fix for io-parallel latency for Vivado and Vitis backends * make same > to >= change for quartus as for vivado * fix normalize_tanh for io_stream * remove softmax from test since that complicates the test and it's not the focus * fix assertion limit for RF for im2col implementation * fix both_binary product for Quartus * turn off winograd in tests since it doesn't support binary * remove TODO that is done * increase size of accumulator in pytest * update test_binary with suggestions, without adding bias * Add bias to CNN where both the input and weights are not binary * change the default precision of the test

10 files changed: +97 −28 lines
New file (adds the XnorPooling optimizer pass)

Lines changed: 22 additions & 0 deletions

@@ -0,0 +1,22 @@
+from hls4ml.model.layers import GlobalPooling1D, GlobalPooling2D, Pooling1D, Pooling2D
+from hls4ml.model.optimizer import OptimizerPass
+from hls4ml.model.types import XnorPrecisionType
+
+
+class XnorPooling(OptimizerPass):
+    '''
+    For correct behavior, for MaxPooling and similar, for XnorPrecisionType, have to propagate
+    the type to the output.
+    '''
+
+    def match(self, node):
+        if isinstance(node, (Pooling1D, Pooling2D, GlobalPooling1D, GlobalPooling2D)) and node.get_attr('pool_op') == 'Max':
+            return isinstance(node.get_input_variable().type.precision, XnorPrecisionType) and not isinstance(
+                node.get_output_variable().type.precision, XnorPrecisionType
+            )
+        return False
+
+    def transform(self, model, node):
+        outvar = node.get_output_variable()
+        outvar.type.precision = XnorPrecisionType()
+        return True
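
The pass above only copies the type; correctness rests on the fact that XnorPrecisionType encodes -1/+1 as the single bits 0/1, and that encoding is order-preserving, so a max pool can run directly on the 1-bit values. A minimal numpy sketch of that property, assuming the 0 -> -1, 1 -> +1 convention:

import numpy as np

# Assumed convention: bit 0 encodes -1, bit 1 encodes +1 (XnorPrecisionType).
def decode(bits):
    return np.where(bits == 1, 1, -1)

bits = np.array([0, 1, 0, 0])  # encodes [-1, +1, -1, -1]
# The encoding is monotonic, so max pooling commutes with decoding:
assert decode(bits.max()) == decode(bits).max()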

hls4ml/backends/quartus/quartus_backend.py

Lines changed: 1 addition & 0 deletions

@@ -62,6 +62,7 @@ def _register_flows(self):
             'quartus:merge_batch_norm_quantized_tanh',
             'quartus:quantize_dense_output',
             'fuse_consecutive_batch_normalization',
+            'quartus:xnor_pooling',
         ]
         quantization_flow = register_flow('quantization', quantization_passes, requires=[init_flow], backend=self.name)

hls4ml/backends/vivado/vivado_backend.py

Lines changed: 1 addition & 0 deletions

@@ -92,6 +92,7 @@ def _register_flows(self):
             'vivado:merge_batch_norm_quantized_tanh',
             'vivado:quantize_dense_output',
             'fuse_consecutive_batch_normalization',
+            'vivado:xnor_pooling',
         ]
         quantization_flow = register_flow('quantization', quantization_passes, requires=[init_flow], backend=self.name)
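
Both backends append the new pass to their quantization flow under a backend-qualified name. As a rough sketch (this is not hls4ml's actual flow runner, just the general shape of an optimizer flow), the registered passes are applied to every node of the model graph until nothing matches:

# Simplified sketch of an optimizer flow; hls4ml's real runner differs.
def run_passes(model, passes):
    changed = True
    while changed:
        changed = False
        for node in list(model.get_layers()):
            for opt in passes:
                if opt.match(node):
                    # transform() returns True when it modified the graph
                    changed = opt.transform(model, node) or changed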

hls4ml/templates/quartus/firmware/nnet_utils/nnet_batchnorm.h

Lines changed: 1 addition & 1 deletion

@@ -71,7 +71,7 @@ void normalize_binary_tanh(data_T data[CONFIG_T::n_in], ac_int<1, false> res[CON
     ac_int<1, false> cache;
     data_T datareg = data[ii];
     int norm_index = CONFIG_T::n_filt == -1 ? ii : ii % CONFIG_T::n_filt;
-    if (datareg > threshold[norm_index])
+    if (datareg >= threshold[norm_index])
         cache = 1;
     else
         cache = 0;
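
QKeras maps an input exactly at the threshold to +1, so the strict comparison disagreed with Keras on the boundary. A tiny numpy illustration of the one change this hunk makes:

import numpy as np

threshold = 0.25  # example value; real thresholds come from the folded batch norm
x = np.array([0.0, 0.25, 0.5])
old = (x > threshold).astype(int)   # [0, 0, 1] -- boundary value became 0
new = (x >= threshold).astype(int)  # [0, 1, 1] -- boundary value is 1, matching QKeras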

hls4ml/templates/quartus/firmware/nnet_utils/nnet_batchnorm_stream.h

Lines changed: 13 additions & 2 deletions

@@ -58,7 +58,13 @@ void normalize_binary_tanh(stream<data_T> &data, stream<nnet::array<ac_int<1, fa
     BatchNormPack:
         #pragma unroll
         for (int j = 0; j < data_T::size; j++) {
-            out_data[j] = (in_data[j] > threshold[i * data_T::size + j]) ? 1 : 0;
+            int norm_index;
+            if (CONFIG_T::n_filt == -1)
+                norm_index = i * data_T::size + j;
+            else
+                norm_index = j % CONFIG_T::n_filt;
+
+            out_data[j] = (in_data[j] >= threshold[norm_index]) ? 1 : 0;
         }

         res.write(out_data);
@@ -79,7 +85,12 @@ void normalize_ternary_tanh(stream<data_T> &data, stream<nnet::array<ac_int<2, t
     BatchNormPack:
         #pragma unroll
         for (int j = 0; j < data_T::size; j++) {
-            int norm_index = i * data_T::size + j;
+            int norm_index;
+            if (CONFIG_T::n_filt == -1)
+                norm_index = i * data_T::size + j;
+            else
+                norm_index = j % CONFIG_T::n_filt;
+
             if (in_data[j] > threshold_hi[norm_index])
                 out_data[j] = 1;
             else if (in_data[j] <= threshold_lo[norm_index])
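
The stream variants previously indexed the thresholds per element in all cases; for convolutional layers (n_filt != -1) the thresholds repeat per filter, so the index has to wrap. The selection logic, mirrored as a Python sketch:

def norm_index(i, j, pack_size, n_filt):
    # n_filt == -1: one threshold per element (non-convolutional layers)
    if n_filt == -1:
        return i * pack_size + j
    # Otherwise thresholds repeat per filter/channel
    return j % n_filt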

hls4ml/templates/quartus/firmware/nnet_utils/nnet_mult.h

Lines changed: 1 addition & 1 deletion

@@ -19,7 +19,7 @@ template <class x_T, class w_T> class both_binary : public Product {
   public:
     inline static x_T product(x_T a, w_T w) {
         // specialisation for 1-bit weights and incoming data
-        return a & w;
+        return a == w;
     }
 };
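
With 1-bit operands encoded as 0 -> -1, 1 -> +1, the product of the decoded values is +1 exactly when the bits agree, which is XNOR, not AND. A numpy check of the fixed semantics (the decode convention is an assumption, matching the sketch above):

import numpy as np

def decode(b):
    return np.where(b == 1, 1, -1)

a = np.array([0, 0, 1, 1])
w = np.array([0, 1, 0, 1])
# New code: XNOR, i.e. a == w, reproduces the signed product exactly
assert np.array_equal(decode(a) * decode(w), decode((a == w).astype(int)))
# Old code: a & w is AND, which maps (-1) * (-1) to -1 instead of +1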

hls4ml/templates/vivado/nnet_utils/nnet_batchnorm.h

Lines changed: 1 addition & 1 deletion

@@ -88,7 +88,7 @@ void normalize_binary_tanh(data_T data[CONFIG_T::n_in], ap_uint<1> res[CONFIG_T:
     for (int ii = 0; ii < CONFIG_T::n_in; ii++) {
         datareg = data[ii];
         int norm_index = CONFIG_T::n_filt == -1 ? ii : ii % CONFIG_T::n_filt;
-        if (datareg > threshold[norm_index])
+        if (datareg >= threshold[norm_index])
             cache = 1;
         else
             cache = 0;

hls4ml/templates/vivado/nnet_utils/nnet_batchnorm_stream.h

Lines changed: 13 additions & 2 deletions

@@ -66,7 +66,13 @@ void normalize_binary_tanh(hls::stream<data_T> &data, hls::stream<nnet::array<ap
     BatchNormPack:
         for (int j = 0; j < data_T::size; j++) {
             #pragma HLS UNROLL
-            out_data[j] = (in_data[j] > threshold[i * data_T::size + j]) ? 1 : 0;
+            int norm_index;
+            if (CONFIG_T::n_filt == -1) {
+                norm_index = i * data_T::size + j;
+            } else {
+                norm_index = j % CONFIG_T::n_filt;
+            }
+            out_data[j] = (in_data[j] >= threshold[norm_index]) ? 1 : 0;
         }

         res.write(out_data);
@@ -92,7 +98,12 @@ void normalize_ternary_tanh(hls::stream<data_T> &data, hls::stream<nnet::array<a
         for (int j = 0; j < data_T::size; j++) {
             #pragma HLS UNROLL

-            int norm_index = i * data_T::size + j;
+            int norm_index;
+            if (CONFIG_T::n_filt == -1) {
+                norm_index = i * data_T::size + j;
+            } else {
+                norm_index = j % CONFIG_T::n_filt;
+            }

             if (in_data[j] > threshold_hi[norm_index]) {
                 out_data[j] = 1;

hls4ml/templates/vivado/nnet_utils/nnet_conv2d_resource.h

Lines changed: 3 additions & 4 deletions

@@ -16,12 +16,11 @@ void conv_2d_resource_cl(
     constexpr unsigned mult_n_out = CONFIG_T::n_filt;
     constexpr unsigned block_factor = DIV_ROUNDUP(mult_n_in * mult_n_out, CONFIG_T::reuse_factor);

-    constexpr unsigned multiplier_limit = DIV_ROUNDUP(mult_n_in * mult_n_out, CONFIG_T::reuse_factor);
-    constexpr unsigned multscale = multiplier_limit / mult_n_out;
+    constexpr unsigned multscale = block_factor / mult_n_out;

-    assert((multiplier_limit % mult_n_out == 0 || CONFIG_T::reuse_factor >= mult_n_in) &&
+    assert((block_factor % mult_n_out == 0 || CONFIG_T::reuse_factor >= mult_n_in) &&
            "The current Reuse Factor is not allowed");
-    assert((multiplier_limit == block_factor) &&
+    assert((CONFIG_T::reuse_factor <= CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan) &&
            "This function is correct only for RF <= FILT_HEIGHT * FILT_WIDTH * N_CHAN");

     data_T data_buf[CONFIG_T::n_pixels][mult_n_in];
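
The removed second assertion compared two identically computed constants and so always held; the new one states the intended limit directly. Restated as a Python sketch, with names mirroring the CONFIG_T fields:

from math import ceil

def check_reuse_factor(filt_height, filt_width, n_chan, n_filt, reuse_factor):
    mult_n_in = filt_height * filt_width * n_chan
    mult_n_out = n_filt
    block_factor = ceil(mult_n_in * mult_n_out / reuse_factor)
    assert block_factor % mult_n_out == 0 or reuse_factor >= mult_n_in, \
        "The current Reuse Factor is not allowed"
    assert reuse_factor <= mult_n_in, \
        "This function is correct only for RF <= FILT_HEIGHT * FILT_WIDTH * N_CHAN"

check_reuse_factor(3, 3, 1, 4, 9)  # e.g. conv2d_1 in the test below passes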

test/pytest/test_binary_cnn.py

Lines changed: 41 additions & 17 deletions

@@ -12,12 +12,33 @@
 test_root_path = Path(__file__).parent


-@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])
-@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
-def test_model2(backend, io_type):
+@pytest.mark.parametrize(
+    'backend,io_type,strategy',
+    [
+        ('Quartus', 'io_parallel', 'resource'),
+        ('Quartus', 'io_stream', 'resource'),
+        ('Vivado', 'io_parallel', 'resource'),
+        ('Vivado', 'io_parallel', 'latency'),
+        ('Vivado', 'io_stream', 'latency'),
+        ('Vivado', 'io_stream', 'resource'),
+        ('Vitis', 'io_parallel', 'resource'),
+        ('Vitis', 'io_parallel', 'latency'),
+        ('Vitis', 'io_stream', 'latency'),
+        ('Vitis', 'io_stream', 'resource'),
+    ],
+)
+def test_binary_cnn(backend, io_type, strategy):
     x_in = Input(shape=(28, 28, 1))

-    x = QConv2D(4, (3, 3), kernel_quantizer="binary", name="conv2d_1", kernel_regularizer=l2(0.0001), use_bias=False)(x_in)
+    x = QConv2D(
+        4,
+        (3, 3),
+        kernel_quantizer="binary",
+        name="conv2d_1",
+        kernel_regularizer=l2(0.0001),
+        use_bias=True,
+        bias_quantizer='quantized_bits(5,2)',
+    )(x_in)
     x = QBatchNormalization()(x)
     x = QActivation("binary", name="act1")(x)

@@ -37,28 +58,32 @@ def test_model2(backend, io_type):
     x = QBatchNormalization()(x)
     x = QActivation("binary_tanh", name="act4")(x)

-    x = QDense(10, kernel_quantizer="binary", activation="softmax", name="q_dense_7", use_bias=False)(x)
+    x = QDense(10, kernel_quantizer="binary", activation="linear", name="q_dense_7", use_bias=False)(x)

     model2 = Model(inputs=x_in, outputs=x)

     model2.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

     model2.summary()

-    hls_config = hls4ml.utils.config_from_keras_model(model2, granularity="name")
-    hls_config["Model"]["Strategy"] = "Resource"
+    hls_config = hls4ml.utils.config_from_keras_model(model2, granularity="name", default_precision='fixed<32,12>')
+    hls_config["Model"]["Strategy"] = strategy

-    print(f"{hls_config['LayerName'].keys()=}")
-    for layer in hls_config['LayerName'].keys():
-        hls_config['LayerName'][layer]['Strategy'] = "Latency"
+    # hls_config["LayerName"]["q_dense_7_softmax"]["Implementation"] = "legacy"

-    hls_config["LayerName"]["conv2d_1"]["ReuseFactor"] = 36
-    hls_config["LayerName"]["conv2d_2"]["ReuseFactor"] = 288
-    hls_config["LayerName"]["conv2d_3"]["ReuseFactor"] = 576
+    hls_config["LayerName"]["conv2d_1"]["ReuseFactor"] = 9
+    hls_config["LayerName"]["conv2d_2"]["ReuseFactor"] = 36
+    hls_config["LayerName"]["conv2d_3"]["ReuseFactor"] = 72
     hls_config["LayerName"]["q_dense_6"]["ReuseFactor"] = 2000
     hls_config["LayerName"]["q_dense_7"]["ReuseFactor"] = 100

-    output_dir = str(test_root_path / f"hls4mlprj_binary_cnn_{backend}_{io_type}")
+    if backend == 'Quartus' and io_type == 'io_parallel':
+        # Winograd implementation does not support binary
+        hls_config["LayerName"]["conv2d_1"]["Implementation"] = "im2col"
+        hls_config["LayerName"]["conv2d_2"]["Implementation"] = "im2col"
+        hls_config["LayerName"]["conv2d_3"]["Implementation"] = "im2col"
+
+    output_dir = str(test_root_path / f"hls4mlprj_binary_cnn_{backend}_{io_type}_{strategy}")
     hls_model = hls4ml.converters.convert_from_keras_model(
         model2,
         hls_config=hls_config,
@@ -67,11 +92,10 @@ def test_model2(backend, io_type):
         io_type=io_type,
     )

-    X = np.random.rand(1, 28, 28, 1)
+    X = np.random.rand(100, 28, 28, 1)

     hls_model.compile()
     y = model2.predict(X)  # noqa: F841
     y_hls = hls_model.predict(X)  # noqa: F841

-    # # TODO: enable the comparions after fixing the remaing issues
-    # np.testing.assert_allclose(np.squeeze(y_hls), np.squeeze(y), rtol=1e-2, atol=0.01)
+    np.testing.assert_allclose(y_hls, y, rtol=1e-2, atol=0.01)
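
The smaller ReuseFactors follow from the tightened assertion in nnet_conv2d_resource.h: RF must not exceed filt_height * filt_width * n_chan. A quick arithmetic check (the input channel counts of conv2d_2 and conv2d_3 are inferred from the RF values, since those layer definitions fall outside the diff):

# RF <= filt_height * filt_width * n_chan, per the new assertion
assert 9 <= 3 * 3 * 1   # conv2d_1: 1 input channel
assert 36 <= 3 * 3 * 4  # conv2d_2: 4 input channels (conv2d_1 has 4 filters)
assert 72 <= 3 * 3 * 8  # conv2d_3: 8 input channels (inferred, not shown)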
