From a98f3f98b12d96823262cc5bade6bd055fb6c601 Mon Sep 17 00:00:00 2001 From: Benjamin Ramhorst Date: Thu, 7 Apr 2022 16:31:42 +0100 Subject: [PATCH 1/9] Custom multiplication for Quartus --- .../quartus/firmware/nnet_utils/nnet_dense.h | 55 ++------ .../nnet_utils/nnet_dense_compressed.h | 3 +- .../quartus/firmware/nnet_utils/nnet_mult.h | 117 ++++++++++++++++++ 3 files changed, 129 insertions(+), 46 deletions(-) create mode 100644 hls4ml/templates/quartus/firmware/nnet_utils/nnet_mult.h diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_dense.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_dense.h index 8075a96714..b08eb309d2 100644 --- a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_dense.h +++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_dense.h @@ -21,6 +21,8 @@ #define NNET_DENSE_LARGE_H_ #include "nnet_common.h" +#include "nnet_helpers.h" +#include "nnet_mult.h" namespace nnet { @@ -49,50 +51,11 @@ struct dense_config static const bool store_weights_in_bram = false; static const unsigned n_zeros = 0; // partitioning arrays cyclically to go with roll factors? -}; - -inline ac_int<1, false> product(ac_int<1, false> a, ac_int<1, false> w) -{ - // specialisation for 1-bit weights and incoming data - return (a == w); -} - -template -auto product(data_T a, ac_int<1, false> w) -> decltype(-a) -{ - // Specialisation for 1-bit weights, arbitrary data - if (w == 0) return -a; - else return a; -} - -template -auto product(data_T a, ac_int<2, true> w) -> decltype(-a) -{ - // Specialisation for 2-bit weights, arbitrary data - if (w == 0) return 0; - else if(w == -1) return -a; - else return a; // if(w == 1) -} -template -auto product(data_T a, weight_T w) -> decltype(a*w) -{ - // 'Normal' product - return a * w; -} - -template -inline typename std::enable_if>::value - and std::is_same>::value, ac_int>::type -cast(typename CONFIG_T::accum_t x){ - return (ac_int) (x - CONFIG_T::n_in / 2) * 2; -} - -template -inline typename std::enable_if<(not std::is_same>::value), res_T>::type -cast(typename CONFIG_T::accum_t x){ - return (res_T) x; -} + // Default multiplication + template + using product = nnet::product::mult; +}; template void dense_rf_gt( @@ -134,7 +97,8 @@ void dense_rf_gt( uint32 w_index = ir + (CONFIG_T::reuse_factor_rounded) * im; if (w_index >= CONFIG_T::reuse_factor_rounded*CONFIG_T::block_factor_rounded) continue; int data_index = d_index[ir][im]; - tmp_acc[im] = product(data[data_index], weights[w_index]); + // Modified this + tmp_acc[im] = CONFIG_T::template product::product(data[data_index], weights[w_index]); } hls_register typename CONFIG_T::accum_t mult[CONFIG_T::multiplier_limit]; ResetMult: @@ -188,7 +152,8 @@ void dense_rf_lt( for (int im = 0, in_index = ir; im < CONFIG_T::block_factor; im++) { uint32 w_index = ir + (CONFIG_T::reuse_factor_rounded) * im; if (ir + CONFIG_T::reuse_factor * im >= CONFIG_T::n_in*CONFIG_T::n_out) continue; - mult[im] = product(data[in_index], weights[w_index]); + // Modified this + mult[im] = CONFIG_T::template product::product(data[in_index], weights[w_index]); in_index += CONFIG_T::reuse_factor; if (in_index >= CONFIG_T::n_in) in_index = ir; } diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_dense_compressed.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_dense_compressed.h index 75fbfba22d..6ca9108cca 100644 --- a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_dense_compressed.h +++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_dense_compressed.h @@ -63,7 +63,8 @@ void dense_compressed( 
for(int im = 0; im < CONFIG_T::compressed_block_factor; im++) { uint32 w = ir + CONFIG_T::reuse_factor * im; //if (w >= CONFIG_T::reuse_factor*CONFIG_T::compressed_block_factor) continue; - mult[im] = product(inputs[0][im], weights[w].weight); + typename CONFIG_T::accum_t prod = + mult[im] = CONFIG_T::template product::product(inputs[0][im], weights[w].weight); #pragma unroll for (int is = 0; is < CONFIG_T::reuse_factor-1; is++) { inputs[is][im] = inputs[is+1][im]; diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_mult.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_mult.h new file mode 100644 index 0000000000..8af62f6edd --- /dev/null +++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_mult.h @@ -0,0 +1,117 @@ +#ifndef NNET_MULT_H_ +#define NNET_MULT_H_ + +#include "nnet_helpers.h" +#include "nnet_common.h" +#include + +namespace nnet { + + // Different methods to perform the product of input and weight, depending on their types. + namespace product{ + + class Product{ + public: + static void limit(unsigned multiplier_limit) {} + }; + + template + class both_binary : public Product{ + public: + inline static x_T product(x_T a, w_T w){ + // specialisation for 1-bit weights and incoming data + return a & w; + } + }; + + template + class weight_binary : public Product{ + public: + inline static auto product(x_T a, w_T w) -> decltype(-a) + { + // Specialisation for 1-bit weights, arbitrary data + if (w == 0) return -a; + else return a; + } + }; + + template + class data_binary : public Product{ + public: + inline static auto product(x_T a, w_T w) -> decltype(-w) + { + // Specialisation for 1-bit data, arbitrary weight + if (a == 0) return -w; + else return w; + } + }; + + template + class weight_ternary : public Product{ + public: + inline static auto product(x_T a, w_T w) -> decltype(-a) + { + // Specialisation for 2-bit weights, arbitrary data + if (w == 0) return 0; + else if(w == -1) return -a; + else return a; // if(w == 1) + } + }; + + template + class mult : public Product{ + public: + inline static auto product(x_T a, w_T w) -> decltype(a*w) + { + // 'Normal' product + return a * w; + } + static void limit(unsigned multiplier_limit){ + // TODO: Implement for Quartus + // #pragma HLS ALLOCATION instances=mul limit=multiplier_limit operation > Vivado-only, replace with Intel HLS pragma + } + }; + + template + class weight_exponential : public Product{ + public: + // Construct the return type from the multiplication equivalent to the largest shifts + // ac_int is the type if the multiplicand equivalent to the largest lshift << + // ac_fixed is the type of the multiplicand equivalent to the largest rshift >> + using r_T = decltype(x_T(0) * (ac_int(1)+ac_fixed(1))); + inline static r_T product(x_T a, w_T w){ + // Shift product for exponential weights + // shift by the exponent. Negative weights shift right + r_T y = static_cast(a) << w.weight; + // negate or not depending on weight sign + return w.sign == 1 ? y : static_cast(-y); + } + }; + } // namespace product_type + + template + inline typename std::enable_if>::value + && std::is_same>::value, ac_int>::type + cast(typename CONFIG_T::accum_t x) + { + return (ac_int) (x - CONFIG_T::n_in / 2) * 2; + } + + template + inline typename std::enable_if>::value + && ! std::is_same>::value, res_T>::type + cast(typename CONFIG_T::accum_t x) + { + return (res_T) x; + } + + template + inline typename std::enable_if<(! 
std::is_same>::value), res_T>::type + cast(typename CONFIG_T::accum_t x) + { + return (res_T) x; + } + +} + +#endif From ff56b1dc8843b06ac0da57311d72b51a4f9b4cc8 Mon Sep 17 00:00:00 2001 From: Benjamin Ramhorst Date: Thu, 7 Apr 2022 16:32:25 +0100 Subject: [PATCH 2/9] Support for different product types on Quartus --- hls4ml/backends/quartus/passes/core_templates.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/hls4ml/backends/quartus/passes/core_templates.py b/hls4ml/backends/quartus/passes/core_templates.py index 63c3693b0b..6f550f94d0 100644 --- a/hls4ml/backends/quartus/passes/core_templates.py +++ b/hls4ml/backends/quartus/passes/core_templates.py @@ -29,6 +29,9 @@ typedef {bias_t.name} bias_t; typedef {weight_t.name} weight_t; typedef {index_t.name} index_t; + + template + using product = nnet::product::{product_type}; }};\n""" dense_function_template = 'nnet::dense_{strategy}<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});' @@ -44,7 +47,7 @@ def format(self, node): params = self._default_config_params(node) params['nzeros'] = node.get_weights('weight').nzeros params['nonzeros'] = node.get_weights('weight').nonzeros - params['product_type'] = get_backend('vivado').product_type(node.get_input_variable().type.precision, node.get_weights('weight').type.precision) + params['product_type'] = get_backend('quartus').product_type(node.get_input_variable().type.precision, node.get_weights('weight').type.precision) return self.template.format(**params) @@ -85,7 +88,7 @@ def __init__(self): def format(self, node): params = self._default_config_params(node) params['n_in'] = node.get_input_variable().size_cpp() - params['product_type'] = get_backend('vivado').product_type(node.get_input_variable().type.precision, node.get_weights('scale').type.precision) + params['product_type'] = get_backend('quartus').product_type(node.get_input_variable().type.precision, node.get_weights('scale').type.precision) return self.template.format(**params) From 95d4aaf1d87b84336945e865c231b4f9eae13fac Mon Sep 17 00:00:00 2001 From: Benjamin Ramhorst Date: Fri, 8 Apr 2022 10:10:51 +0100 Subject: [PATCH 3/9] Quartus custom multiplication for batch normalisation --- hls4ml/backends/quartus/passes/core_templates.py | 2 ++ .../quartus/firmware/nnet_utils/nnet_batchnorm.h | 14 ++++++++++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/hls4ml/backends/quartus/passes/core_templates.py b/hls4ml/backends/quartus/passes/core_templates.py index 6f550f94d0..88fed63994 100644 --- a/hls4ml/backends/quartus/passes/core_templates.py +++ b/hls4ml/backends/quartus/passes/core_templates.py @@ -74,6 +74,8 @@ def format(self, node): static const bool store_weights_in_bram = false; typedef {bias_t.name} bias_t; typedef {scale_t.name} scale_t; + template + using product = nnet::product::{product_type}; }};\n""" batchnorm_function_template = 'nnet::normalize<{input_t}, {output_t}, {config}>({input}, {output}, {scale}, {bias});' diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_batchnorm.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_batchnorm.h index c3d7065028..ab7d1f6ad8 100755 --- a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_batchnorm.h +++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_batchnorm.h @@ -21,6 +21,8 @@ #define NNET_BATCHNORM_H_ #include "nnet_common.h" +#include "nnet_helpers.h" +#include "nnet_mult.h" namespace nnet { @@ -40,6 +42,10 @@ struct batchnorm_config static const bool store_weights_in_bram = false; static const unsigned 
n_zeros = 0; // partitioning arrays cyclically to go with roll factors? + + // Default multiplication + template + using product = nnet::product::mult; }; template @@ -54,12 +60,12 @@ void normalize( Result: #pragma unroll for (int ires = 0; ires < CONFIG_T::n_in; ires++) { + // TODO - Explore MULADD instruction in HLS - less clock cycles if (CONFIG_T::n_filt==-1) { - res[ires] = data[ires] * scale[ires] + bias[ires]; - } - else { + res[ires] = CONFIG_T::template product::product(data[ires], scale[ires]) + bias[ires]; + } else { int norm_index = ires%CONFIG_T::n_filt; - res[ires] = data[ires] * scale[norm_index] + bias[norm_index]; + res[ires] = CONFIG_T::template product::product(data[ires], scale[norm_index]) + bias[norm_index]; } } } From 8a763684b9e96a535672089bcfb5f7d6508b6878 Mon Sep 17 00:00:00 2001 From: Benjamin Ramhorst Date: Fri, 8 Apr 2022 10:12:05 +0100 Subject: [PATCH 4/9] Added quantization template for Quartus --- .../quartus/passes/quantization_templates.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 hls4ml/backends/quartus/passes/quantization_templates.py diff --git a/hls4ml/backends/quartus/passes/quantization_templates.py b/hls4ml/backends/quartus/passes/quantization_templates.py new file mode 100644 index 0000000000..746c93e7d7 --- /dev/null +++ b/hls4ml/backends/quartus/passes/quantization_templates.py @@ -0,0 +1,29 @@ +from hls4ml.backends.backend import get_backend +from hls4ml.model.optimizer.passes.qkeras import ApplyAlpha +from hls4ml.backends.template import LayerConfigTemplate, FunctionCallTemplate +from hls4ml.backends.quartus.passes.core_templates import batchnorm_config_template, batchnorm_function_template, batchnorm_include_list + +class ApplyAlphaConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__(ApplyAlpha) + self.template = batchnorm_config_template + + def format(self, node): + params = self._default_config_params(node) + params['n_in'] = node.get_input_variable().size_cpp() + params['product_type'] = get_backend('quartus').product_type(node.get_input_variable().type.precision, node.get_weights('scale').type.precision) + + return self.template.format(**params) + +class ApplyAlphaFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__(ApplyAlpha, include_header=batchnorm_include_list) + self.template = batchnorm_function_template + + def format(self, node): + params = self._default_function_params(node) + params['scale'] = node.get_weights('scale').name + params['bias'] = node.get_weights('bias').name + + return self.template.format(**params) + From fbe7cb81731c4bdbec885587d76744802ec13cff Mon Sep 17 00:00:00 2001 From: Benjamin Ramhorst Date: Fri, 8 Apr 2022 14:12:14 +0100 Subject: [PATCH 5/9] Moved BatchNormalizationQuantizedTanh to common FPGA folder - used by both Quartus and Vivado --- hls4ml/backends/fpga/fpga_layers.py | 44 +++++++++++++++++++++++ hls4ml/backends/vivado/passes/bn_quant.py | 38 +------------------- 2 files changed, 45 insertions(+), 37 deletions(-) create mode 100644 hls4ml/backends/fpga/fpga_layers.py diff --git a/hls4ml/backends/fpga/fpga_layers.py b/hls4ml/backends/fpga/fpga_layers.py new file mode 100644 index 0000000000..5af9eca511 --- /dev/null +++ b/hls4ml/backends/fpga/fpga_layers.py @@ -0,0 +1,44 @@ +import numpy as np +import re + +from hls4ml.model.optimizer import OptimizerPass +from hls4ml.model.types import IntegerPrecisionType, NamedType, XnorPrecisionType +from hls4ml.model.layers import Layer, Activation, Dense, 
BatchNormalization, register_layer +from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate + +class BatchNormalizationQuantizedTanh(Layer): + ''' Merged Batch Normalization and quantized (binary or ternary) Tanh layer. + The mean, variance, beta, gamma parameters are folded into the threshold(s) at which the + sign of the input flips after the quantized (binary or ternary) Tanh activation. + ''' + + def initialize(self): + inp = self.get_input_variable() + shape = inp.shape + dims = inp.dim_names + if self.get_attr('quantize') == 2: + self.add_output_variable(shape, dims, precision=XnorPrecisionType()) + elif self.get_attr('quantize') == 3: + self.add_output_variable(shape, dims, precision=IntegerPrecisionType(width=2)) + else: + raise Exception('Unsupported quantize attribute for BatchNormalizationQuantizedTanh: {}'.format(self.get_attr('quantize'))) + + def set_thresholds(self, scale, bias, ternary_threshold=0.5): + inp = self.get_input_variable() + shape = inp.shape + dims = inp.dim_names + precision = self.model.config.backend.convert_precision_string(inp.type.precision) + W, I, F = precision.width, precision.integer, precision.fractional + threshold = - bias / scale + if self.get_attr('quantize') == 2: + self.add_output_variable(shape, dims, precision=XnorPrecisionType()) + threshold = np.floor(threshold * 2**F) / 2**F + self.add_weights_variable(name='threshold', var_name='t{index}', data=threshold, type_name='threshold{index}_t', precision=inp.type.precision) + elif self.get_attr('quantize') == 3: + self.add_output_variable(shape, dims, precision=IntegerPrecisionType(width=2)) + threshold_hi = ternary_threshold / scale + threshold + threshold_lo = -ternary_threshold / scale + threshold + threshold_hi = np.floor(threshold_hi * 2**F) / 2**F + threshold_lo = np.floor(threshold_lo * 2**F) / 2**F + self.add_weights_variable(name='threshold_hi', var_name='th{index}', data=threshold_hi, type_name='threshold_hi_{index}_t', precision=inp.type.precision) + self.add_weights_variable(name='threshold_lo', var_name='tl{index}', data=threshold_lo, type_name='threshold_lo_{index}_t', precision=inp.type.precision) diff --git a/hls4ml/backends/vivado/passes/bn_quant.py b/hls4ml/backends/vivado/passes/bn_quant.py index ad4c5fa135..aebd4dee8f 100644 --- a/hls4ml/backends/vivado/passes/bn_quant.py +++ b/hls4ml/backends/vivado/passes/bn_quant.py @@ -5,43 +5,7 @@ from hls4ml.model.types import IntegerPrecisionType, NamedType, XnorPrecisionType from hls4ml.model.layers import Layer, Activation, Dense, BatchNormalization, register_layer from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate - -class BatchNormalizationQuantizedTanh(Layer): - ''' Merged Batch Normalization and quantized (binary or ternary) Tanh layer. - The mean, variance, beta, gamma parameters are folded into the threshold(s) at which the - sign of the input flips after the quantized (binary or ternary) Tanh activation. 
- ''' - - def initialize(self): - inp = self.get_input_variable() - shape = inp.shape - dims = inp.dim_names - if self.get_attr('quantize') == 2: - self.add_output_variable(shape, dims, precision=XnorPrecisionType()) - elif self.get_attr('quantize') == 3: - self.add_output_variable(shape, dims, precision=IntegerPrecisionType(width=2)) - else: - raise Exception('Unsupported quantize attribute for BatchNormalizationQuantizedTanh: {}'.format(self.get_attr('quantize'))) - - def set_thresholds(self, scale, bias, ternary_threshold=0.5): - inp = self.get_input_variable() - shape = inp.shape - dims = inp.dim_names - precision = self.model.config.backend.convert_precision_string(inp.type.precision) - W, I, F = precision.width, precision.integer, precision.fractional - threshold = - bias / scale - if self.get_attr('quantize') == 2: - self.add_output_variable(shape, dims, precision=XnorPrecisionType()) - threshold = np.floor(threshold * 2**F) / 2**F - self.add_weights_variable(name='threshold', var_name='t{index}', data=threshold, type_name='threshold{index}_t', precision=inp.type.precision) - elif self.get_attr('quantize') == 3: - self.add_output_variable(shape, dims, precision=IntegerPrecisionType(width=2)) - threshold_hi = ternary_threshold / scale + threshold - threshold_lo = -ternary_threshold / scale + threshold - threshold_hi = np.floor(threshold_hi * 2**F) / 2**F - threshold_lo = np.floor(threshold_lo * 2**F) / 2**F - self.add_weights_variable(name='threshold_hi', var_name='th{index}', data=threshold_hi, type_name='threshold_hi_{index}_t', precision=inp.type.precision) - self.add_weights_variable(name='threshold_lo', var_name='tl{index}', data=threshold_lo, type_name='threshold_lo_{index}_t', precision=inp.type.precision) +from hls4ml.backends.fpga.fpga_layers import BatchNormalizationQuantizedTanh batchnorm_quantized_tanh_config_template = """struct config{index} : nnet::batchnorm_quantized_tanh_config {{ static const unsigned n_in = {n_in}; From 49975f9f0774ec72e168788c972bd8dd6dbb31ac Mon Sep 17 00:00:00 2001 From: Benjamin Ramhorst Date: Fri, 8 Apr 2022 14:12:36 +0100 Subject: [PATCH 6/9] Added quantization pass for Quartus --- hls4ml/backends/quartus/passes/bn_quant.py | 157 +++++++++++++++++++++ hls4ml/backends/quartus/quartus_backend.py | 10 +- 2 files changed, 166 insertions(+), 1 deletion(-) create mode 100644 hls4ml/backends/quartus/passes/bn_quant.py diff --git a/hls4ml/backends/quartus/passes/bn_quant.py b/hls4ml/backends/quartus/passes/bn_quant.py new file mode 100644 index 0000000000..91b242fd23 --- /dev/null +++ b/hls4ml/backends/quartus/passes/bn_quant.py @@ -0,0 +1,157 @@ +import numpy as np + +from hls4ml.model.optimizer import OptimizerPass +from hls4ml.model.types import IntegerPrecisionType, NamedType, XnorPrecisionType +from hls4ml.model.layers import BatchNormalization, register_layer +from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate +from hls4ml.backends.fpga.fpga_layers import BatchNormalizationQuantizedTanh + +batchnorm_quantized_tanh_config_template = """struct config{index} : nnet::batchnorm_quantized_tanh_config {{ + static const unsigned n_in = {n_in}; + static const unsigned n_filt = {n_filt}; + static const unsigned io_type = nnet::{iotype}; + static const unsigned reuse_factor = {reuse}; +}};\n""" + +batchnorm_quantized_tanh_function_template = 'nnet::normalize_{quantize}_tanh<{input_t}, {config}>({input}, {output}, {threshold});' + +bn_include_list = ['nnet_utils/nnet_batchnorm.h'] + +class 
BatchNormalizationQuantizedTanhConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__(BatchNormalizationQuantizedTanh) + self.template = batchnorm_quantized_tanh_config_template + + def format(self, node): + params = self._default_config_params(node) + params['n_in'] = node.get_input_variable().size_cpp() + + return self.template.format(**params) + +class BatchNormalizationQuantizedTanhFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__(BatchNormalizationQuantizedTanh, include_header=bn_include_list) + self.template = batchnorm_quantized_tanh_function_template + + def format(self, node): + params = self._default_function_params(node) + if node.get_attr('quantize') == 2: + params['quantize'] = 'binary' + params['threshold'] = node.get_weights('threshold').name + elif node.get_attr('quantize') == 3: + params['quantize'] = 'ternary' + params['threshold'] = node.get_weights('threshold_hi').name + ', ' + node.get_weights('threshold_lo').name + + return self.template.format(**params) + +def register_bn_quant(backend): + # Register the layer types to the layer map + register_layer('BatchNormalizationQuantizedTanh', BatchNormalizationQuantizedTanh) + + # Register the optimization passes + backend.register_pass('merge_batch_norm_quantized_tanh', MergeBatchNormAndQuantizedTanh) + backend.register_pass('quantize_dense_output', QuantizeDenseOutput) + + # Register template passes + backend.register_template(BatchNormalizationQuantizedTanhConfigTemplate) + backend.register_template(BatchNormalizationQuantizedTanhFunctionTemplate) + + +class MergeBatchNormAndQuantizedTanh(OptimizerPass): + def match(self, node): + is_match = (node.class_name == 'Activation' + and node.get_attr('activation') in ['binary', 'binary_tanh', 'ternary', 'ternary_tanh'] + or node.class_name == 'TernaryTanh') + is_match = is_match and isinstance(node.get_input_node(), BatchNormalization) + return is_match + + def transform(self, model, node): + bn_layer = node.get_input_node() + # Make a new layer with the new attributes + quantize = 0 + if 'binary' in node.get_attr('activation'): + quantize = 2 + if 'ternary' in node.get_attr('activation'): + quantize = 3 + attrs = { + 'name' : bn_layer.get_attr('name'), + 'original_name' : bn_layer.get_attr('name'), + 'class_name' : 'BatchNormalizationQuantizedTanh', + 'n_in' : bn_layer.get_attr('n_in'), + 'n_out' : bn_layer.get_attr('n_in'), + 'n_filt' : bn_layer.get_attr('n_filt'), + 'quantize' : quantize, + 'Trace' : bn_layer.get_attr('Trace') + } + bnbt_layer = model.make_node(BatchNormalizationQuantizedTanh, 'bnbt_' + bn_layer.name, attrs, bn_layer.inputs) + bnbt_layer.set_thresholds(bn_layer.get_weights('scale').data, bn_layer.get_weights('bias').data, node.get_attr('threshold',0.5)) + # Remove the BatchNormalization layer + model.remove_node(bn_layer, rewire=True) + # Replace the old Activation layer with this one + model.replace_node(node, bnbt_layer) + + return True + +class QuantizeDenseOutput(OptimizerPass): + def match(self, node): + is_dense = node.class_name == 'Dense' + input_node = node.get_input_node() + is_input_bnqt = input_node is not None and input_node.class_name == 'BatchNormalizationQuantizedTanh' + quantizer = node.get_attr('weight_quantizer') + is_binary_ternary = quantizer is not None and (quantizer.__class__.__name__ == 'BinaryQuantizer' or quantizer.__class__.__name__ == 'TernaryQuantizer') + return is_dense and is_input_bnqt and is_binary_ternary + + def transform(self, model, node): + # Compute the required 
precision and update the variables + # Number of bits for output is log2 of number of input nodes + # Since this is the number of uint<1>'s which are summed + nbits = int(np.ceil(np.log2(node.attributes['n_in'])) + 2) + out_type = IntegerPrecisionType(width=nbits) + accum_t = NamedType('layer{}_accum_t'.format(node.index), out_type) + node.set_attr('accum_t', accum_t) + out_var = node.get_output_variable() + out_var.type.precision = out_type + + quantized_data = None + quantized_precision = None + quantizer = node.get_attr('weight_quantizer') + if quantizer.__class__.__name__ == 'BinaryQuantizer': + quantized_precision = XnorPrecisionType() + elif quantizer.__class__.__name__ == 'TernaryQuantizer': + quantized_precision = IntegerPrecisionType(width=2) + else: + print('WARNING: Unknown quantizer - {}. Bailing out'.format(quantizer.__class__.__name__)) + return False + quantizer.bits = quantized_precision.width + quantizer.hls_type = quantized_precision + quantized_data = quantizer(node.weights['weight'].data) + + weights = node.weights['weight'] + weights.data = quantized_data + weights.type.name = 'weight{index}_t'.format(index=node.index) + weights.update_precision(quantized_precision) + + bias = node.weights['bias'] + bias.data = np.zeros(shape=(node.get_attr('n_out'))) + bias.type.name = 'bias{index}_t'.format(index=node.index) + bias.nzeros = 0 + bias.update_precision(quantized_precision) + + # If followed by the BatchNormalizationBinaryTanh, update its input + # Also requantise the weights + bd_out_nodes = node.get_output_nodes() + for out_node in bd_out_nodes: + if isinstance(out_node, BatchNormalizationQuantizedTanh): + var_names = [] + if quantizer.__class__.__name__ == 'BinaryQuantizer': + var_names.append('threshold') + elif quantizer.__class__.__name__ == 'TernaryQuantizer': + var_names.append('threshold_hi') + var_names.append('threshold_lo') + for var_name in var_names: + threshold_var = out_node.weights[var_name] + threshold_var.update_precision(out_type) + threshold_var.data = np.floor(threshold_var.data) + + return False + diff --git a/hls4ml/backends/quartus/quartus_backend.py b/hls4ml/backends/quartus/quartus_backend.py index 9cb4f44bf4..8f174f9043 100644 --- a/hls4ml/backends/quartus/quartus_backend.py +++ b/hls4ml/backends/quartus/quartus_backend.py @@ -39,6 +39,14 @@ def _register_flows(self): ] quartus_types_flow = register_flow('specific_types', quartus_types, requires=[init_flow], backend=self.name) + quantization_passes = [ + 'quartus:merge_batch_norm_quantized_tanh', + 'quartus:quantize_dense_output', + 'fuse_consecutive_batch_normalization', + ] + quantization_flow = register_flow('quantization', quantization_passes, requires=[init_flow], backend=self.name) + + templates = self._get_layer_templates() template_flow = register_flow('apply_templates', templates, requires=[init_flow], backend=self.name) @@ -60,7 +68,7 @@ def _register_flows(self): else: extras_flow = None - ip_flow_requirements = ['optimize', init_flow, quartus_types_flow, extras_flow, template_flow] + ip_flow_requirements = ['optimize', init_flow, quantization_flow, quartus_types_flow, extras_flow, template_flow] ip_flow_requirements = list(filter(None, ip_flow_requirements)) self._default_flow = register_flow('ip', None, requires=ip_flow_requirements, backend=self.name) From df206d23305f0cffd3c6e3148c6bfdf7370b35fc Mon Sep 17 00:00:00 2001 From: Benjamin Ramhorst Date: Fri, 8 Apr 2022 14:35:04 +0100 Subject: [PATCH 7/9] Parameterized qkeras tests for Quartus --- test/pytest/test_qkeras.py | 24 
+++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/test/pytest/test_qkeras.py b/test/pytest/test_qkeras.py index 1b16ae21a2..b38e0967b3 100644 --- a/test/pytest/test_qkeras.py +++ b/test/pytest/test_qkeras.py @@ -50,6 +50,7 @@ def load_jettagging_model(): model.load_weights(example_model_path / 'keras/qkeras_3layer_weights.h5') return model +# TODO - Paramaterize for Quartus (different strategies?) @pytest.fixture @pytest.mark.parametrize('strategy', ['latency', 'resource']) def convert(load_jettagging_model, strategy): @@ -111,7 +112,8 @@ def randX_100_16(): # https://github.com/fastmachinelearning/hls4ml/issues/381 #@pytest.mark.parametrize('bits', [4, 6, 8]) @pytest.mark.parametrize('bits,alpha', [(4, 1), (4, 'auto_po2')]) -def test_single_dense_activation_exact(randX_100_16, bits, alpha): +@pytest.mark.parametrize('backend', ['Vivado', 'Quartus']) +def test_single_dense_activation_exact(randX_100_16, bits, alpha, backend): ''' Test a single Dense -> Activation layer topology for bit exactness with number of bits parameter @@ -126,10 +128,11 @@ def test_single_dense_activation_exact(randX_100_16, bits, alpha): hls4ml.model.optimizer.get_optimizer('output_rounding_saturation_mode').configure(layers=['relu1'], rounding_mode='AP_RND_CONV', saturation_mode='AP_SAT') config = hls4ml.utils.config_from_keras_model(model, granularity='name') + output_dir = str(test_root_path / 'hls4mlprj_qkeras_single_dense_activation_exact_{}_{}_{}'.format(bits, alpha, backend)) hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=config, - output_dir=str(test_root_path / 'hls4mlprj_qkeras_single_dense_activation_exact_{}_{}'.format(bits, alpha)), - part='xcu250-figd2104-2L-e') + output_dir=output_dir, + backend=backend) hls4ml.model.optimizer.get_optimizer('output_rounding_saturation_mode').configure(layers=[]) hls_model.compile() @@ -164,11 +167,13 @@ def randX_100_10(): (5, 10, ternary(alpha='auto'), quantized_bits(5,2), ternary(threshold=0.2), True, False), (6, 10, ternary(alpha='auto'), quantized_bits(5,2), ternary(threshold=0.8), True, False), (7, 10, binary(), quantized_bits(5,2), binary(), False, True)]) -def test_btnn(make_btnn, randX_100_10): +@pytest.mark.parametrize('backend', ['Vivado', 'Quartus']) +def test_btnn(make_btnn, randX_100_10, backend): model, is_xnor, test_no = make_btnn X = randX_100_10 cfg = hls4ml.utils.config_from_keras_model(model, granularity='name') - hls_model = hls4ml.converters.convert_from_keras_model(model, output_dir=str(test_root_path / 'hls4mlprj_btnn_{}'.format(test_no)), hls_config=cfg) + output_dir = str(test_root_path / 'hls4mlprj_btnn_{}_{}'.format(test_no, backend)) + hls_model = hls4ml.converters.convert_from_keras_model(model, output_dir=output_dir, hls_config=cfg, backend=backend) hls_model.compile() y_hls = hls_model.predict(X) # hls4ml may return XNOR binary @@ -195,7 +200,8 @@ def randX_1000_1(): (quantized_relu(8,4)), (quantized_relu(10)), (quantized_relu(10,5))]) -def test_quantizer(randX_1000_1, quantizer): +@pytest.mark.parametrize('backend', ['Vivado', 'Quartus']) +def test_quantizer(randX_1000_1, quantizer, backend): ''' Test a single quantizer as an Activation function. 
Checks the type inference through the conversion is correct without just @@ -209,12 +215,12 @@ def test_quantizer(randX_1000_1, quantizer): hls4ml.model.optimizer.get_optimizer('output_rounding_saturation_mode').configure(layers=['quantizer'], rounding_mode='AP_RND_CONV', saturation_mode='AP_SAT') config = hls4ml.utils.config_from_keras_model(model, granularity='name') - output_dir = str(test_root_path / 'hls4mlprj_qkeras_quantizer_{}_{}_{}'.format(quantizer.__class__.__name__, - quantizer.bits, quantizer.integer)) + output_dir = str(test_root_path / 'hls4mlprj_qkeras_quantizer_{}_{}_{}_{}'.format(quantizer.__class__.__name__, + quantizer.bits, quantizer.integer, backend)) hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=config, output_dir=output_dir, - part='xcu250-figd2104-2L-e') + backend=backend) hls4ml.model.optimizer.get_optimizer('output_rounding_saturation_mode').configure(layers=[]) hls_model.compile() From 0b7eb3838b2d076c05760435d7a0fe5cf63aeee3 Mon Sep 17 00:00:00 2001 From: Benjamin Ramhorst Date: Fri, 8 Apr 2022 15:34:02 +0100 Subject: [PATCH 8/9] Quartus test for Softmax --- test/pytest/test_softmax.py | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/test/pytest/test_softmax.py b/test/pytest/test_softmax.py index bfd7ed65a9..44bfb9dc63 100644 --- a/test/pytest/test_softmax.py +++ b/test/pytest/test_softmax.py @@ -25,31 +25,39 @@ def generate_data(function, input_shape): # TODO: include latency strategy with flat_distribution when it can be made to pass -@pytest.mark.parametrize('strategy,function,input_shape,io_type', [#('latency', flat_distribution, (8,), 'io_parallel'), - #('latency', flat_distribution, (8, 8, 3), 'io_stream'), - ('stable', flat_distribution, (8,), 'io_parallel'), - ('stable', high_accuracy_distribution, (8,), 'io_parallel'), - ('stable', flat_distribution, (8,), 'io_stream'), - ('stable', high_accuracy_distribution, (8,), 'io_stream'), - # Multi-dimensional tests, only for io_stream for now - ('stable', flat_distribution, (8, 8, 3), 'io_stream'), - ('stable', high_accuracy_distribution, (8, 8, 3), 'io_stream')]) -def test_softmax(strategy, generate_data, input_shape, io_type): +@pytest.mark.parametrize('backend,strategy,function,input_shape,io_type', [ + #('latency', flat_distribution, (8,), 'io_parallel'), + #('latency', flat_distribution, (8, 8, 3), 'io_stream'), + ('Vivado', 'stable', flat_distribution, (8,), 'io_parallel'), + ('Vivado', 'stable', high_accuracy_distribution, (8,), 'io_parallel'), + ('Quartus', 'resource', flat_distribution, (8,), 'io_parallel'), + ('Quartus', 'resource', high_accuracy_distribution, (8,), 'io_parallel'), + ('Vivado', 'stable', flat_distribution, (8,), 'io_stream'), + ('Vivado', 'stable', high_accuracy_distribution, (8,), 'io_stream'), + # Multi-dimensional tests, only for io_stream for now + ('Vivado', 'stable', flat_distribution, (8, 8, 3), 'io_stream'), + ('Vivado', 'stable', high_accuracy_distribution, (8, 8, 3), 'io_stream') + + ]) +def test_softmax(backend, strategy, generate_data, input_shape, io_type): X = generate_data model = tf.keras.models.Sequential() model.add(tf.keras.layers.Activation(input_shape=input_shape, activation='softmax', name='softmax')) model.compile() + + f_type = 'ac_fixed<18,8,true,AC_RND,AC_SAT>' if backend == 'Quartus' else 'ap_fixed<18,8,AP_RND,AP_SAT>' cfg = hls4ml.utils.config_from_keras_model(model, granularity='name') cfg['LayerName']['softmax']['Strategy'] = strategy - 
cfg['LayerName']['softmax']['inv_table_t'] = 'ap_fixed<18,8,AP_RND,AP_SAT>' - cfg['LayerName']['softmax']['exp_table_t'] = 'ap_fixed<18,8,AP_RND,AP_SAT>' + cfg['LayerName']['softmax']['inv_table_t'] = f_type + cfg['LayerName']['softmax']['exp_table_t'] = f_type + odir = str(test_root_path / 'hls4mlprj_softmax_{}'.format(strategy)) hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=cfg, io_type=io_type, - output_dir=odir) + output_dir=odir, backend=backend) hls_model.compile() + y_keras = model.predict(X) y_hls4ml = hls_model.predict(X).reshape(y_keras.shape) - acc_hls4ml = accuracy_score(np.argmax(y_keras, axis=-1).ravel(), np.argmax(y_hls4ml, axis=-1).ravel()) print('Accuracy hls4ml relative to keras: {}'.format(acc_hls4ml)) From c7a8888d51938fc70361d9c2f1ae06abaa39c8a0 Mon Sep 17 00:00:00 2001 From: Benjamin Ramhorst Date: Thu, 21 Apr 2022 19:20:39 +0100 Subject: [PATCH 9/9] Common FPGA passes, for both Quartus and Vivado --- hls4ml/backends/backend.py | 5 + hls4ml/backends/fpga/passes/__init__.py | 0 .../{quartus => fpga}/passes/bn_quant.py | 2 +- hls4ml/backends/vivado/passes/bn_quant.py | 157 ------------------ .../nnet_utils/nnet_batchnorm_stream.h | 33 ++++ 5 files changed, 39 insertions(+), 158 deletions(-) create mode 100644 hls4ml/backends/fpga/passes/__init__.py rename hls4ml/backends/{quartus => fpga}/passes/bn_quant.py (98%) delete mode 100644 hls4ml/backends/vivado/passes/bn_quant.py create mode 100644 hls4ml/templates/quartus/firmware/nnet_utils/nnet_batchnorm_stream.h diff --git a/hls4ml/backends/backend.py b/hls4ml/backends/backend.py index 05385617e9..b121044629 100644 --- a/hls4ml/backends/backend.py +++ b/hls4ml/backends/backend.py @@ -26,6 +26,11 @@ def _init_file_optimizers(self): opt_path = os.path.dirname(inspect.getfile(self.__class__)) + '/passes' module_path = self.__module__[:self.__module__.rfind('.')] + '.passes' file_optimizers = extract_optimizers_from_path(opt_path, module_path, self) + for base in self.__class__.__bases__: + opt_path = os.path.dirname(inspect.getfile(base)) + '/passes' + module_path = base.__module__[:base.__module__.rfind('.')] + '.passes' + base_optimizers = extract_optimizers_from_path(opt_path, module_path, self) + file_optimizers.update(base_optimizers) return file_optimizers def _get_layer_initializers(self): diff --git a/hls4ml/backends/fpga/passes/__init__.py b/hls4ml/backends/fpga/passes/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/hls4ml/backends/quartus/passes/bn_quant.py b/hls4ml/backends/fpga/passes/bn_quant.py similarity index 98% rename from hls4ml/backends/quartus/passes/bn_quant.py rename to hls4ml/backends/fpga/passes/bn_quant.py index 91b242fd23..b51d7610f1 100644 --- a/hls4ml/backends/quartus/passes/bn_quant.py +++ b/hls4ml/backends/fpga/passes/bn_quant.py @@ -15,7 +15,7 @@ batchnorm_quantized_tanh_function_template = 'nnet::normalize_{quantize}_tanh<{input_t}, {config}>({input}, {output}, {threshold});' -bn_include_list = ['nnet_utils/nnet_batchnorm.h'] +bn_include_list = ['nnet_utils/nnet_batchnorm.h', 'nnet_utils/nnet_batchnorm_stream.h'] class BatchNormalizationQuantizedTanhConfigTemplate(LayerConfigTemplate): def __init__(self): diff --git a/hls4ml/backends/vivado/passes/bn_quant.py b/hls4ml/backends/vivado/passes/bn_quant.py deleted file mode 100644 index aebd4dee8f..0000000000 --- a/hls4ml/backends/vivado/passes/bn_quant.py +++ /dev/null @@ -1,157 +0,0 @@ -import numpy as np -import re - -from hls4ml.model.optimizer import OptimizerPass -from hls4ml.model.types 
import IntegerPrecisionType, NamedType, XnorPrecisionType -from hls4ml.model.layers import Layer, Activation, Dense, BatchNormalization, register_layer -from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate -from hls4ml.backends.fpga.fpga_layers import BatchNormalizationQuantizedTanh - -batchnorm_quantized_tanh_config_template = """struct config{index} : nnet::batchnorm_quantized_tanh_config {{ - static const unsigned n_in = {n_in}; - static const unsigned n_filt = {n_filt}; - static const unsigned io_type = nnet::{iotype}; - static const unsigned reuse_factor = {reuse}; -}};\n""" - -batchnorm_quantized_tanh_function_template = 'nnet::normalize_{quantize}_tanh<{input_t}, {config}>({input}, {output}, {threshold});' -bn_include_list = ['nnet_utils/nnet_batchnorm.h', 'nnet_utils/nnet_batchnorm_stream.h'] - -class BatchNormalizationQuantizedTanhConfigTemplate(LayerConfigTemplate): - def __init__(self): - super().__init__(BatchNormalizationQuantizedTanh) - self.template = batchnorm_quantized_tanh_config_template - - def format(self, node): - params = self._default_config_params(node) - params['n_in'] = node.get_input_variable().size_cpp() - - return self.template.format(**params) - -class BatchNormalizationQuantizedTanhFunctionTemplate(FunctionCallTemplate): - def __init__(self): - super().__init__(BatchNormalizationQuantizedTanh, include_header=bn_include_list) - self.template = batchnorm_quantized_tanh_function_template - - def format(self, node): - params = self._default_function_params(node) - if node.get_attr('quantize') == 2: - params['quantize'] = 'binary' - params['threshold'] = node.get_weights('threshold').name - elif node.get_attr('quantize') == 3: - params['quantize'] = 'ternary' - params['threshold'] = node.get_weights('threshold_hi').name + ', ' + node.get_weights('threshold_lo').name - - return self.template.format(**params) - -def register_bn_quant(backend): - # Register the layer types to the layer map - register_layer('BatchNormalizationQuantizedTanh', BatchNormalizationQuantizedTanh) - - # Register the optimization passes - backend.register_pass('merge_batch_norm_quantized_tanh', MergeBatchNormAndQuantizedTanh) - backend.register_pass('quantize_dense_output', QuantizeDenseOutput) - - # Register template passes - backend.register_template(BatchNormalizationQuantizedTanhConfigTemplate) - backend.register_template(BatchNormalizationQuantizedTanhFunctionTemplate) - - -class MergeBatchNormAndQuantizedTanh(OptimizerPass): - def match(self, node): - is_match = (node.class_name == 'Activation' - and node.get_attr('activation') in ['binary', 'binary_tanh', 'ternary', 'ternary_tanh'] - or node.class_name == 'TernaryTanh') - is_match = is_match and isinstance(node.get_input_node(), BatchNormalization) - return is_match - - def transform(self, model, node): - bn_layer = node.get_input_node() - # Make a new layer with the new attributes - quantize = 0 - if 'binary' in node.get_attr('activation'): - quantize = 2 - if 'ternary' in node.get_attr('activation'): - quantize = 3 - attrs = { - 'name' : bn_layer.get_attr('name'), - 'original_name' : bn_layer.get_attr('name'), - 'class_name' : 'BatchNormalizationQuantizedTanh', - 'n_in' : bn_layer.get_attr('n_in'), - 'n_out' : bn_layer.get_attr('n_in'), - 'n_filt' : bn_layer.get_attr('n_filt'), - 'quantize' : quantize, - 'Trace' : bn_layer.get_attr('Trace') - } - bnbt_layer = model.make_node(BatchNormalizationQuantizedTanh, 'bnbt_' + bn_layer.name, attrs, bn_layer.inputs) - 
bnbt_layer.set_thresholds(bn_layer.get_weights('scale').data, bn_layer.get_weights('bias').data, node.get_attr('threshold',0.5)) - # Remove the BatchNormalization layer - model.remove_node(bn_layer, rewire=True) - # Replace the old Activation layer with this one - model.replace_node(node, bnbt_layer) - - return True - -class QuantizeDenseOutput(OptimizerPass): - def match(self, node): - is_dense = node.class_name == 'Dense' - input_node = node.get_input_node() - is_input_bnqt = input_node is not None and input_node.class_name == 'BatchNormalizationQuantizedTanh' - quantizer = node.get_attr('weight_quantizer') - is_binary_ternary = quantizer is not None and (quantizer.__class__.__name__ == 'BinaryQuantizer' or quantizer.__class__.__name__ == 'TernaryQuantizer') - return is_dense and is_input_bnqt and is_binary_ternary - - def transform(self, model, node): - # Compute the required precision and update the variables - # Number of bits for output is log2 of number of input nodes - # Since this is the number of uint<1>'s which are summed - nbits = int(np.ceil(np.log2(node.attributes['n_in'])) + 2) - out_type = IntegerPrecisionType(width=nbits) - accum_t = NamedType('layer{}_accum_t'.format(node.index), out_type) - node.set_attr('accum_t', accum_t) - out_var = node.get_output_variable() - out_var.type.precision = out_type - - quantized_data = None - quantized_precision = None - quantizer = node.get_attr('weight_quantizer') - if quantizer.__class__.__name__ == 'BinaryQuantizer': - quantized_precision = XnorPrecisionType() - elif quantizer.__class__.__name__ == 'TernaryQuantizer': - quantized_precision = IntegerPrecisionType(width=2) - else: - print('WARNING: Unknown quantizer - {}. Bailing out'.format(quantizer.__class__.__name__)) - return False - quantizer.bits = quantized_precision.width - quantizer.hls_type = quantized_precision - quantized_data = quantizer(node.weights['weight'].data) - - weights = node.weights['weight'] - weights.data = quantized_data - weights.type.name = 'weight{index}_t'.format(index=node.index) - weights.update_precision(quantized_precision) - - bias = node.weights['bias'] - bias.data = np.zeros(shape=(node.get_attr('n_out'))) - bias.type.name = 'bias{index}_t'.format(index=node.index) - bias.nzeros = 0 - bias.update_precision(quantized_precision) - - # If followed by the BatchNormalizationBinaryTanh, update its input - # Also requantise the weights - bd_out_nodes = node.get_output_nodes() - for out_node in bd_out_nodes: - if isinstance(out_node, BatchNormalizationQuantizedTanh): - var_names = [] - if quantizer.__class__.__name__ == 'BinaryQuantizer': - var_names.append('threshold') - elif quantizer.__class__.__name__ == 'TernaryQuantizer': - var_names.append('threshold_hi') - var_names.append('threshold_lo') - for var_name in var_names: - threshold_var = out_node.weights[var_name] - threshold_var.update_precision(out_type) - threshold_var.data = np.floor(threshold_var.data) - - return False - diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_batchnorm_stream.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_batchnorm_stream.h new file mode 100644 index 0000000000..a87d813151 --- /dev/null +++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_batchnorm_stream.h @@ -0,0 +1,33 @@ +// +// rfnoc-hls-neuralnet: Vivado HLS code for neural-net building blocks +// +// Copyright (C) 2017 EJ Kreinar +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free 
Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . +// + +/* +* PLACEHOLDER - The common pass bn_quant.py includes both parallel and streaming BN; streaming is currently not supported in Quartus +*/ + +#ifndef NNET_BATCHNORM_STREAM_H_ +#define NNET_BATCHNORM_STREAM_H_ + +#include "nnet_common.h" +#include "nnet_helpers.h" +#include "nnet_mult.h" + +namespace nnet {} + +#endif
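As a quick illustration of the product-functor mechanism introduced in PATCH 1 (nnet_mult.h) and wired up by core_templates.py in PATCH 2, the following standalone C++ sketch mimics how a layer config selects a multiplication implementation at compile time and how the kernels then invoke it through CONFIG_T::template product<x_T, w_T>::product(...). This is not part of the patch: the toy_config struct, the mac() helper and the plain int types are illustrative assumptions; the real templates operate on ac_int/ac_fixed types and the configs are generated per layer by the backend.

// Standalone sketch (illustrative only): how the nnet::product functors
// added in nnet_mult.h are selected by a layer config and called by a kernel.
#include <iostream>

namespace nnet { namespace product {

// 'Normal' product, mirroring nnet::product::mult in the patch.
template<class x_T, class w_T>
struct mult {
    static auto product(x_T a, w_T w) -> decltype(a * w) { return a * w; }
};

// Ternary-weight specialisation, mirroring nnet::product::weight_ternary.
template<class x_T, class w_T>
struct weight_ternary {
    static auto product(x_T a, w_T w) -> decltype(-a) {
        if (w == 0) return 0;
        else if (w == -1) return -a;
        else return a; // w == 1
    }
};

}} // namespace nnet::product

// Hypothetical layer config: in the generated firmware the backend's config
// template emits the 'product' alias (filled in with {product_type} by
// core_templates.py); here it is chosen by hand.
struct toy_config {
    typedef int accum_t;
    template<class x_T, class w_T>
    using product = nnet::product::mult<x_T, w_T>; // swap for weight_ternary<x_T, w_T> with ternary weights
};

// Kernel-style call site, analogous to dense_rf_lt / normalize after the patch.
template<class data_T, class weight_T, typename CONFIG_T>
typename CONFIG_T::accum_t mac(data_T d, weight_T w, typename CONFIG_T::accum_t acc) {
    return acc + CONFIG_T::template product<data_T, weight_T>::product(d, w);
}

int main() {
    int acc = 0;
    acc = mac<int, int, toy_config>(3, 4, acc);  // acc = 12
    acc = mac<int, int, toy_config>(2, -1, acc); // acc = 10
    std::cout << acc << std::endl;               // prints 10
    return 0;
}

Because the alias lives on the config struct, each generated Dense or BatchNormalization config can pick both_binary, weight_binary, weight_ternary, weight_exponential or mult for its data/weight precisions without any change to the kernel code, which always calls CONFIG_T::template product<...>::product(...).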