Make binary CNN match between Keras and hls4ml #804

Merged: 19 commits, Aug 16, 2023
22 changes: 22 additions & 0 deletions hls4ml/backends/fpga/passes/xnor_pooling.py
@@ -0,0 +1,22 @@
from hls4ml.model.layers import GlobalPooling1D, GlobalPooling2D, Pooling1D, Pooling2D
from hls4ml.model.optimizer import OptimizerPass
from hls4ml.model.types import XnorPrecisionType


class XnorPooling(OptimizerPass):
    '''
    For correct behavior of MaxPooling (and similar layers) on XnorPrecisionType
    inputs, the Xnor precision must be propagated to the output type.
    '''

    def match(self, node):
        if isinstance(node, (Pooling1D, Pooling2D, GlobalPooling1D, GlobalPooling2D)) and node.get_attr('pool_op') == 'Max':
            return isinstance(node.get_input_variable().type.precision, XnorPrecisionType) and not isinstance(
                node.get_output_variable().type.precision, XnorPrecisionType
            )
        return False

    def transform(self, model, node):
        outvar = node.get_output_variable()
        outvar.type.precision = XnorPrecisionType()
        return True
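A note on why this is sufficient: XnorPrecisionType encodes -1 as bit 0 and +1 as bit 1, a monotonic mapping, so max-pooling over the raw bits already computes the right answer; the pass only has to keep the output typed as Xnor so downstream layers decode the bits as ±1. A minimal sketch of the invariant (plain Python; the encode/decode helper is illustrative, not hls4ml API):

```python
# Illustrative only: check that max-pooling commutes with the XNOR bit encoding.
def decode(bit: int) -> int:
    """Map the XNOR encoding to the value it represents: 0 -> -1, 1 -> +1."""
    return 2 * bit - 1

pool_window = [0, 1, 0, 0]  # raw XNOR bits in one pooling window

# Max over raw bits, decoded afterwards ...
pooled_then_decoded = decode(max(pool_window))
# ... equals decoding first and then taking the max, because decode is monotonic.
decoded_then_pooled = max(decode(b) for b in pool_window)

assert pooled_then_decoded == decoded_then_pooled == 1
```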
1 change: 1 addition & 0 deletions hls4ml/backends/quartus/quartus_backend.py
@@ -62,6 +62,7 @@ def _register_flows(self):
'quartus:merge_batch_norm_quantized_tanh',
'quartus:quantize_dense_output',
'fuse_consecutive_batch_normalization',
+ 'quartus:xnor_pooling',
]
quantization_flow = register_flow('quantization', quantization_passes, requires=[init_flow], backend=self.name)

1 change: 1 addition & 0 deletions hls4ml/backends/vivado/vivado_backend.py
@@ -92,6 +92,7 @@ def _register_flows(self):
'vivado:merge_batch_norm_quantized_tanh',
'vivado:quantize_dense_output',
'fuse_consecutive_batch_normalization',
+ 'vivado:xnor_pooling',
]
quantization_flow = register_flow('quantization', quantization_passes, requires=[init_flow], backend=self.name)

hls4ml/templates/quartus/firmware/nnet_utils/nnet_batchnorm.h
@@ -71,7 +71,7 @@ void normalize_binary_tanh(data_T data[CONFIG_T::n_in], ac_int<1, false> res[CON
ac_int<1, false> cache;
data_T datareg = data[ii];
int norm_index = CONFIG_T::n_filt == -1 ? ii : ii % CONFIG_T::n_filt;
- if (datareg > threshold[norm_index])
+ if (datareg >= threshold[norm_index])
cache = 1;
else
cache = 0;
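The `>` to `>=` change fixes the boundary case: a value exactly equal to the threshold must binarize to 1, matching the QKeras binary quantizer, which sends sign(0) to +1. A small sketch of the convention (NumPy, zero threshold assumed for illustration):

```python
import numpy as np

def binary_tanh(x, threshold=0.0):
    """Binarize as the fixed HLS code does: values at or above the threshold map to 1."""
    return np.where(x >= threshold, 1, 0)

x = np.array([-0.5, 0.0, 0.5])
# With the old '>' comparison, the boundary value 0.0 binarized to 0 and
# diverged from Keras; with '>=' it binarizes to 1.
assert binary_tanh(x).tolist() == [0, 1, 1]
```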
hls4ml/templates/quartus/firmware/nnet_utils/nnet_batchnorm_stream.h
@@ -58,7 +58,13 @@ void normalize_binary_tanh(stream<data_T> &data, stream<nnet::array<ac_int<1, fa
BatchNormPack:
#pragma unroll
for (int j = 0; j < data_T::size; j++) {
- out_data[j] = (in_data[j] > threshold[i * data_T::size + j]) ? 1 : 0;
+ int norm_index;
+ if (CONFIG_T::n_filt == -1)
+     norm_index = i * data_T::size + j;
+ else
+     norm_index = j % CONFIG_T::n_filt;
+
+ out_data[j] = (in_data[j] >= threshold[norm_index]) ? 1 : 0;
}

res.write(out_data);
@@ -79,7 +85,12 @@ void normalize_ternary_tanh(stream<data_T> &data, stream<nnet::array<ac_int<2, t
BatchNormPack:
#pragma unroll
for (int j = 0; j < data_T::size; j++) {
- int norm_index = i * data_T::size + j;
+ int norm_index;
+ if (CONFIG_T::n_filt == -1)
+     norm_index = i * data_T::size + j;
+ else
+     norm_index = j % CONFIG_T::n_filt;

if (in_data[j] > threshold_hi[norm_index])
out_data[j] = 1;
else if (in_data[j] <= threshold_lo[norm_index])
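The indexing fix matters for conv layers in the streaming implementation: thresholds are stored once per filter, so when `CONFIG_T::n_filt` is set the index has to wrap over the channel dimension rather than run over the flattened tensor. A plain-Python mirror of the corrected logic:

```python
def norm_index(i, j, size, n_filt):
    """Mirror of the fixed C++ indexing: flat index when n_filt == -1,
    per-filter (wrapped) index otherwise."""
    if n_filt == -1:
        return i * size + j
    return j % n_filt

# With 4 filters, every stream beat reuses the same 4 per-filter thresholds.
assert [norm_index(2, j, size=4, n_filt=4) for j in range(4)] == [0, 1, 2, 3]
# With n_filt == -1 (e.g. dense layers) the flat index is still the right one.
assert [norm_index(2, j, size=4, n_filt=-1) for j in range(4)] == [8, 9, 10, 11]
```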
2 changes: 1 addition & 1 deletion hls4ml/templates/quartus/firmware/nnet_utils/nnet_mult.h
@@ -19,7 +19,7 @@ template <class x_T, class w_T> class both_binary : public Product {
public:
inline static x_T product(x_T a, w_T w) {
// specialisation for 1-bit weights and incoming data
- return a & w;
+ return a == w;
}
};

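The one-bit product fix: in the XNOR encoding, multiplying the represented values (-1 or +1) gives +1 exactly when the two bits agree, which is XNOR, i.e. `a == w`. The old `a & w` is a logical AND and gets the a = w = 0 case wrong, since (-1) * (-1) = +1. A quick truth-table check (plain Python):

```python
# Verify that XNOR (a == w) encodes the product of the +/-1 values the bits represent.
for a in (0, 1):
    for w in (0, 1):
        value_product = (2 * a - 1) * (2 * w - 1)  # product of represented values
        xnor = int(a == w)                         # the fixed implementation
        assert 2 * xnor - 1 == value_product
        # The old 'a & w' fails at a == w == 0: AND yields 0 (meaning -1),
        # but the true product (-1) * (-1) is +1.
```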
2 changes: 1 addition & 1 deletion hls4ml/templates/vivado/nnet_utils/nnet_batchnorm.h
@@ -88,7 +88,7 @@ void normalize_binary_tanh(data_T data[CONFIG_T::n_in], ap_uint<1> res[CONFIG_T:
for (int ii = 0; ii < CONFIG_T::n_in; ii++) {
datareg = data[ii];
int norm_index = CONFIG_T::n_filt == -1 ? ii : ii % CONFIG_T::n_filt;
- if (datareg > threshold[norm_index])
+ if (datareg >= threshold[norm_index])
cache = 1;
else
cache = 0;
15 changes: 13 additions & 2 deletions hls4ml/templates/vivado/nnet_utils/nnet_batchnorm_stream.h
@@ -66,7 +66,13 @@ void normalize_binary_tanh(hls::stream<data_T> &data, hls::stream<nnet::array<ap
BatchNormPack:
for (int j = 0; j < data_T::size; j++) {
#pragma HLS UNROLL
- out_data[j] = (in_data[j] > threshold[i * data_T::size + j]) ? 1 : 0;
+ int norm_index;
+ if (CONFIG_T::n_filt == -1) {
+     norm_index = i * data_T::size + j;
+ } else {
+     norm_index = j % CONFIG_T::n_filt;
+ }
+ out_data[j] = (in_data[j] >= threshold[norm_index]) ? 1 : 0;
}

res.write(out_data);
@@ -92,7 +98,12 @@ void normalize_ternary_tanh(hls::stream<data_T> &data, hls::stream<nnet::array<a
for (int j = 0; j < data_T::size; j++) {
#pragma HLS UNROLL

- int norm_index = i * data_T::size + j;
+ int norm_index;
+ if (CONFIG_T::n_filt == -1) {
+     norm_index = i * data_T::size + j;
+ } else {
+     norm_index = j % CONFIG_T::n_filt;
+ }

if (in_data[j] > threshold_hi[norm_index]) {
out_data[j] = 1;
7 changes: 3 additions & 4 deletions hls4ml/templates/vivado/nnet_utils/nnet_conv2d_resource.h
@@ -16,12 +16,11 @@ void conv_2d_resource_cl(
constexpr unsigned mult_n_out = CONFIG_T::n_filt;
constexpr unsigned block_factor = DIV_ROUNDUP(mult_n_in * mult_n_out, CONFIG_T::reuse_factor);

- constexpr unsigned multiplier_limit = DIV_ROUNDUP(mult_n_in * mult_n_out, CONFIG_T::reuse_factor);
- constexpr unsigned multscale = multiplier_limit / mult_n_out;
+ constexpr unsigned multscale = block_factor / mult_n_out;

- assert((multiplier_limit % mult_n_out == 0 || CONFIG_T::reuse_factor >= mult_n_in) &&
+ assert((block_factor % mult_n_out == 0 || CONFIG_T::reuse_factor >= mult_n_in) &&
"The current Reuse Factor is not allowed");
- assert((multiplier_limit == block_factor) &&
+ assert((CONFIG_T::reuse_factor <= CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan) &&
"This function is correct only for RF <= FILT_HEIGHT * FILT_WIDTH * N_CHAN");

data_T data_buf[CONFIG_T::n_pixels][mult_n_in];
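Since `block_factor` is defined as `DIV_ROUNDUP(mult_n_in * mult_n_out, CONFIG_T::reuse_factor)`, the separate `multiplier_limit` constant duplicated it and is dropped; the second assert now states the actual requirement (RF no larger than the filter volume) directly. A worked example of the arithmetic, with hypothetical layer sizes chosen only to exercise the formulas:

```python
from math import ceil

# Hypothetical layer: 3x3 kernel, 4 input channels, 4 filters, reuse factor 9.
filt_height, filt_width, n_chan, n_filt, reuse_factor = 3, 3, 4, 4, 9

mult_n_in = filt_height * filt_width * n_chan               # 36 multiplies per window
mult_n_out = n_filt                                         # 4
block_factor = ceil(mult_n_in * mult_n_out / reuse_factor)  # DIV_ROUNDUP -> 16
multscale = block_factor // mult_n_out                      # 4

# The two assertions from the fixed HLS code:
assert block_factor % mult_n_out == 0 or reuse_factor >= mult_n_in
assert reuse_factor <= filt_height * filt_width * n_chan
```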
58 changes: 41 additions & 17 deletions test/pytest/test_binary_cnn.py
@@ -12,12 +12,33 @@
test_root_path = Path(__file__).parent


- @pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])
- @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
- def test_model2(backend, io_type):
+ @pytest.mark.parametrize(
+     'backend,io_type,strategy',
+     [
+         ('Quartus', 'io_parallel', 'resource'),
+         ('Quartus', 'io_stream', 'resource'),
+         ('Vivado', 'io_parallel', 'resource'),
+         ('Vivado', 'io_parallel', 'latency'),
+         ('Vivado', 'io_stream', 'latency'),
+         ('Vivado', 'io_stream', 'resource'),
+         ('Vitis', 'io_parallel', 'resource'),
+         ('Vitis', 'io_parallel', 'latency'),
+         ('Vitis', 'io_stream', 'latency'),
+         ('Vitis', 'io_stream', 'resource'),
+     ],
+ )
+ def test_binary_cnn(backend, io_type, strategy):
x_in = Input(shape=(28, 28, 1))

- x = QConv2D(4, (3, 3), kernel_quantizer="binary", name="conv2d_1", kernel_regularizer=l2(0.0001), use_bias=False)(x_in)
+ x = QConv2D(
+     4,
+     (3, 3),
+     kernel_quantizer="binary",
+     name="conv2d_1",
+     kernel_regularizer=l2(0.0001),
+     use_bias=True,
+     bias_quantizer='quantized_bits(5,2)',
+ )(x_in)
x = QBatchNormalization()(x)
x = QActivation("binary", name="act1")(x)

@@ -37,28 +58,32 @@ def test_model2(backend, io_type):
x = QBatchNormalization()(x)
x = QActivation("binary_tanh", name="act4")(x)

x = QDense(10, kernel_quantizer="binary", activation="softmax", name="q_dense_7", use_bias=False)(x)
x = QDense(10, kernel_quantizer="binary", activation="linear", name="q_dense_7", use_bias=False)(x)

model2 = Model(inputs=x_in, outputs=x)

model2.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

model2.summary()

hls_config = hls4ml.utils.config_from_keras_model(model2, granularity="name")
hls_config["Model"]["Strategy"] = "Resource"
hls_config = hls4ml.utils.config_from_keras_model(model2, granularity="name", default_precision='fixed<32,12>')
hls_config["Model"]["Strategy"] = strategy

print(f"{hls_config['LayerName'].keys()=}")
for layer in hls_config['LayerName'].keys():
hls_config['LayerName'][layer]['Strategy'] = "Latency"
# hls_config["LayerName"]["q_dense_7_softmax"]["Implementation"] = "legacy"

hls_config["LayerName"]["conv2d_1"]["ReuseFactor"] = 36
hls_config["LayerName"]["conv2d_2"]["ReuseFactor"] = 288
hls_config["LayerName"]["conv2d_3"]["ReuseFactor"] = 576
hls_config["LayerName"]["conv2d_1"]["ReuseFactor"] = 9
hls_config["LayerName"]["conv2d_2"]["ReuseFactor"] = 36
hls_config["LayerName"]["conv2d_3"]["ReuseFactor"] = 72
hls_config["LayerName"]["q_dense_6"]["ReuseFactor"] = 2000
hls_config["LayerName"]["q_dense_7"]["ReuseFactor"] = 100

- output_dir = str(test_root_path / f"hls4mlprj_binary_cnn_{backend}_{io_type}")
+ if backend == 'Quartus' and io_type == 'io_parallel':
+     # Winograd implementation does not support binary
+     hls_config["LayerName"]["conv2d_1"]["Implementation"] = "im2col"
+     hls_config["LayerName"]["conv2d_2"]["Implementation"] = "im2col"
+     hls_config["LayerName"]["conv2d_3"]["Implementation"] = "im2col"
+
+ output_dir = str(test_root_path / f"hls4mlprj_binary_cnn_{backend}_{io_type}_{strategy}")
hls_model = hls4ml.converters.convert_from_keras_model(
model2,
hls_config=hls_config,
@@ -67,11 +92,10 @@
io_type=io_type,
)

- X = np.random.rand(1, 28, 28, 1)
+ X = np.random.rand(100, 28, 28, 1)

hls_model.compile()
y = model2.predict(X) # noqa: F841
y_hls = hls_model.predict(X) # noqa: F841

- # # TODO: enable the comparions after fixing the remaing issues
- # np.testing.assert_allclose(np.squeeze(y_hls), np.squeeze(y), rtol=1e-2, atol=0.01)
+ np.testing.assert_allclose(y_hls, y, rtol=1e-2, atol=0.01)
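Each parametrization compiles a full HLS model, so running one configuration at a time is usually enough for a quick check. A sketch using pytest's `-k` selector, assuming pytest derives the test id from the backend-io_type-strategy tuple and that it is run from the repository root:

```python
import pytest

# Select a single configuration; the id "Vivado-io_stream-latency" comes from
# the parametrize tuple above.
pytest.main(["-k", "Vivado-io_stream-latency", "test/pytest/test_binary_cnn.py"])
```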