Quartus Custom Matrix Multiplication & Quantization #523

Merged · 9 commits · Apr 29, 2022
5 changes: 5 additions & 0 deletions hls4ml/backends/backend.py
@@ -26,6 +26,11 @@ def _init_file_optimizers(self):
opt_path = os.path.dirname(inspect.getfile(self.__class__)) + '/passes'
module_path = self.__module__[:self.__module__.rfind('.')] + '.passes'
file_optimizers = extract_optimizers_from_path(opt_path, module_path, self)
for base in self.__class__.__bases__:
opt_path = os.path.dirname(inspect.getfile(base)) + '/passes'
module_path = base.__module__[:base.__module__.rfind('.')] + '.passes'
base_optimizers = extract_optimizers_from_path(opt_path, module_path, self)
file_optimizers.update(base_optimizers)
return file_optimizers

def _get_layer_initializers(self):
44 changes: 44 additions & 0 deletions hls4ml/backends/fpga/fpga_layers.py
@@ -0,0 +1,44 @@
import numpy as np
import re

from hls4ml.model.optimizer import OptimizerPass
from hls4ml.model.types import IntegerPrecisionType, NamedType, XnorPrecisionType
from hls4ml.model.layers import Layer, Activation, Dense, BatchNormalization, register_layer
from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate

class BatchNormalizationQuantizedTanh(Layer):
''' Merged Batch Normalization and quantized (binary or ternary) Tanh layer.
The mean, variance, beta, gamma parameters are folded into the threshold(s) at which the
sign of the input flips after the quantized (binary or ternary) Tanh activation.
'''

def initialize(self):
inp = self.get_input_variable()
shape = inp.shape
dims = inp.dim_names
if self.get_attr('quantize') == 2:
self.add_output_variable(shape, dims, precision=XnorPrecisionType())
elif self.get_attr('quantize') == 3:
self.add_output_variable(shape, dims, precision=IntegerPrecisionType(width=2))
else:
raise Exception('Unsupported quantize attribute for BatchNormalizationQuantizedTanh: {}'.format(self.get_attr('quantize')))

def set_thresholds(self, scale, bias, ternary_threshold=0.5):
inp = self.get_input_variable()
shape = inp.shape
dims = inp.dim_names
precision = self.model.config.backend.convert_precision_string(inp.type.precision)
W, I, F = precision.width, precision.integer, precision.fractional
threshold = - bias / scale
if self.get_attr('quantize') == 2:
self.add_output_variable(shape, dims, precision=XnorPrecisionType())
threshold = np.floor(threshold * 2**F) / 2**F
self.add_weights_variable(name='threshold', var_name='t{index}', data=threshold, type_name='threshold{index}_t', precision=inp.type.precision)
elif self.get_attr('quantize') == 3:
self.add_output_variable(shape, dims, precision=IntegerPrecisionType(width=2))
threshold_hi = ternary_threshold / scale + threshold
threshold_lo = -ternary_threshold / scale + threshold
threshold_hi = np.floor(threshold_hi * 2**F) / 2**F
threshold_lo = np.floor(threshold_lo * 2**F) / 2**F
self.add_weights_variable(name='threshold_hi', var_name='th{index}', data=threshold_hi, type_name='threshold_hi_{index}_t', precision=inp.type.precision)
self.add_weights_variable(name='threshold_lo', var_name='tl{index}', data=threshold_lo, type_name='threshold_lo_{index}_t', precision=inp.type.precision)
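For reference, the threshold folding done by set_thresholds above amounts to the following (a sketch of the math implied by the code; scale and bias are the folded batch-norm parameters, F is the fractional bit width of the input precision, and a positive scale is assumed):

t    = -bias / scale                          (binary: the output flips sign at x = t)
t_hi = (ternary_threshold - bias) / scale     (ternary: +1 for x > t_hi)
t_lo = (-ternary_threshold - bias) / scale    (ternary: -1 for x < t_lo, 0 in between)
t_q  = floor(t * 2^F) / 2^F                   (each threshold is rounded down to the input's LSB)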
Empty file.
@@ -1,47 +1,10 @@
import numpy as np
import re

from hls4ml.model.optimizer import OptimizerPass
from hls4ml.model.types import IntegerPrecisionType, NamedType, XnorPrecisionType
from hls4ml.model.layers import Layer, Activation, Dense, BatchNormalization, register_layer
from hls4ml.model.layers import BatchNormalization, register_layer
from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate

class BatchNormalizationQuantizedTanh(Layer):
''' Merged Batch Normalization and quantized (binary or ternary) Tanh layer.
The mean, variance, beta, gamma parameters are folded into the threshold(s) at which the
sign of the input flips after the quantized (binary or ternary) Tanh activation.
'''

def initialize(self):
inp = self.get_input_variable()
shape = inp.shape
dims = inp.dim_names
if self.get_attr('quantize') == 2:
self.add_output_variable(shape, dims, precision=XnorPrecisionType())
elif self.get_attr('quantize') == 3:
self.add_output_variable(shape, dims, precision=IntegerPrecisionType(width=2))
else:
raise Exception('Unsupported quantize attribute for BatchNormalizationQuantizedTanh: {}'.format(self.get_attr('quantize')))

def set_thresholds(self, scale, bias, ternary_threshold=0.5):
inp = self.get_input_variable()
shape = inp.shape
dims = inp.dim_names
precision = self.model.config.backend.convert_precision_string(inp.type.precision)
W, I, F = precision.width, precision.integer, precision.fractional
threshold = - bias / scale
if self.get_attr('quantize') == 2:
self.add_output_variable(shape, dims, precision=XnorPrecisionType())
threshold = np.floor(threshold * 2**F) / 2**F
self.add_weights_variable(name='threshold', var_name='t{index}', data=threshold, type_name='threshold{index}_t', precision=inp.type.precision)
elif self.get_attr('quantize') == 3:
self.add_output_variable(shape, dims, precision=IntegerPrecisionType(width=2))
threshold_hi = ternary_threshold / scale + threshold
threshold_lo = -ternary_threshold / scale + threshold
threshold_hi = np.floor(threshold_hi * 2**F) / 2**F
threshold_lo = np.floor(threshold_lo * 2**F) / 2**F
self.add_weights_variable(name='threshold_hi', var_name='th{index}', data=threshold_hi, type_name='threshold_hi_{index}_t', precision=inp.type.precision)
self.add_weights_variable(name='threshold_lo', var_name='tl{index}', data=threshold_lo, type_name='threshold_lo_{index}_t', precision=inp.type.precision)
from hls4ml.backends.fpga.fpga_layers import BatchNormalizationQuantizedTanh

batchnorm_quantized_tanh_config_template = """struct config{index} : nnet::batchnorm_quantized_tanh_config {{
static const unsigned n_in = {n_in};
@@ -51,6 +14,7 @@ def set_thresholds(self, scale, bias, ternary_threshold=0.5):
}};\n"""

batchnorm_quantized_tanh_function_template = 'nnet::normalize_{quantize}_tanh<{input_t}, {config}>({input}, {output}, {threshold});'

bn_include_list = ['nnet_utils/nnet_batchnorm.h', 'nnet_utils/nnet_batchnorm_stream.h']

class BatchNormalizationQuantizedTanhConfigTemplate(LayerConfigTemplate):
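For illustration, the two templates above would expand to something like the following generated HLS snippet for a binary-tanh layer (the index, size and variable names are hypothetical, and the config fields hidden by the truncated template are omitted):

// Hypothetical expansion of the config and function templates (made-up values)
struct config3 : nnet::batchnorm_quantized_tanh_config {
    static const unsigned n_in = 64;
    // ... remaining fields come from the part of the template not shown in this diff
};

// assuming {quantize} resolves to 'binary' for the 1-bit case, 'ternary' for 2-bit
nnet::normalize_binary_tanh<layer2_t, config3>(layer2_out, layer3_out, t3);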
9 changes: 7 additions & 2 deletions hls4ml/backends/quartus/passes/core_templates.py
@@ -29,6 +29,9 @@
typedef {bias_t.name} bias_t;
typedef {weight_t.name} weight_t;
typedef {index_t.name} index_t;

template<class x_T, class y_T>
using product = nnet::product::{product_type}<x_T, y_T>;
}};\n"""

dense_function_template = 'nnet::dense_{strategy}<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});'
@@ -44,7 +47,7 @@ def format(self, node):
params = self._default_config_params(node)
params['nzeros'] = node.get_weights('weight').nzeros
params['nonzeros'] = node.get_weights('weight').nonzeros
params['product_type'] = get_backend('vivado').product_type(node.get_input_variable().type.precision, node.get_weights('weight').type.precision)
params['product_type'] = get_backend('quartus').product_type(node.get_input_variable().type.precision, node.get_weights('weight').type.precision)

return self.template.format(**params)

@@ -71,6 +74,8 @@ def format(self, node):
static const bool store_weights_in_bram = false;
typedef {bias_t.name} bias_t;
typedef {scale_t.name} scale_t;
template<class x_T, class y_T>
using product = nnet::product::{product_type}<x_T, y_T>;
}};\n"""

batchnorm_function_template = 'nnet::normalize<{input_t}, {output_t}, {config}>({input}, {output}, {scale}, {bias});'
@@ -85,7 +90,7 @@ def __init__(self):
def format(self, node):
params = self._default_config_params(node)
params['n_in'] = node.get_input_variable().size_cpp()
params['product_type'] = get_backend('vivado').product_type(node.get_input_variable().type.precision, node.get_weights('scale').type.precision)
params['product_type'] = get_backend('quartus').product_type(node.get_input_variable().type.precision, node.get_weights('scale').type.precision)

return self.template.format(**params)

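The net effect of the two template additions above is that every generated dense and batchnorm config now carries a product alias, with the implementation picked per layer from the input and weight precisions via get_backend('quartus').product_type(...). A sketch of what a generated dense config could look like (widths, index and the selected functor are made-up examples; the template's hidden upper part adds more fields):

// Hypothetical generated config (fields and widths are illustrative only)
struct config2 : nnet::dense_config {
    typedef ac_fixed<16, 6, true> bias_t;
    typedef ac_fixed<16, 6, true> weight_t;
    typedef ac_int<16, false> index_t;

    // Filled in from {product_type}; nnet::product::mult is the plain a*w case
    template<class x_T, class y_T>
    using product = nnet::product::mult<x_T, y_T>;
};

The HLS kernels then resolve each multiplication as CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::product(a, w), as shown in the nnet_dense.h and nnet_batchnorm.h changes below.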
29 changes: 29 additions & 0 deletions hls4ml/backends/quartus/passes/quantization_templates.py
@@ -0,0 +1,29 @@
from hls4ml.backends.backend import get_backend
from hls4ml.model.optimizer.passes.qkeras import ApplyAlpha
from hls4ml.backends.template import LayerConfigTemplate, FunctionCallTemplate
from hls4ml.backends.quartus.passes.core_templates import batchnorm_config_template, batchnorm_function_template, batchnorm_include_list

class ApplyAlphaConfigTemplate(LayerConfigTemplate):
def __init__(self):
super().__init__(ApplyAlpha)
self.template = batchnorm_config_template

def format(self, node):
params = self._default_config_params(node)
params['n_in'] = node.get_input_variable().size_cpp()
params['product_type'] = get_backend('quartus').product_type(node.get_input_variable().type.precision, node.get_weights('scale').type.precision)

return self.template.format(**params)

class ApplyAlphaFunctionTemplate(FunctionCallTemplate):
def __init__(self):
super().__init__(ApplyAlpha, include_header=batchnorm_include_list)
self.template = batchnorm_function_template

def format(self, node):
params = self._default_function_params(node)
params['scale'] = node.get_weights('scale').name
params['bias'] = node.get_weights('bias').name

return self.template.format(**params)
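ApplyAlpha (the QKeras scale/shift node) re-uses the Quartus batchnorm config and function templates imported above, so a generated call would look like a plain normalize call, e.g. (hypothetical layer index and variable names):

// Hypothetical expansion of batchnorm_function_template for an ApplyAlpha node
nnet::normalize<layer4_t, result_t, config5>(layer4_out, layer5_out, s5, b5);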

10 changes: 9 additions & 1 deletion hls4ml/backends/quartus/quartus_backend.py
@@ -39,6 +39,14 @@ def _register_flows(self):
]
quartus_types_flow = register_flow('specific_types', quartus_types, requires=[init_flow], backend=self.name)

quantization_passes = [
'quartus:merge_batch_norm_quantized_tanh',
'quartus:quantize_dense_output',
'fuse_consecutive_batch_normalization',
]
quantization_flow = register_flow('quantization', quantization_passes, requires=[init_flow], backend=self.name)


templates = self._get_layer_templates()
template_flow = register_flow('apply_templates', templates, requires=[init_flow], backend=self.name)

@@ -60,7 +68,7 @@ def _register_flows(self):
else:
extras_flow = None

ip_flow_requirements = ['optimize', init_flow, quartus_types_flow, extras_flow, template_flow]
ip_flow_requirements = ['optimize', init_flow, quantization_flow, quartus_types_flow, extras_flow, template_flow]
ip_flow_requirements = list(filter(None, ip_flow_requirements))

self._default_flow = register_flow('ip', None, requires=ip_flow_requirements, backend=self.name)
14 changes: 10 additions & 4 deletions hls4ml/templates/quartus/firmware/nnet_utils/nnet_batchnorm.h
@@ -21,6 +21,8 @@
#define NNET_BATCHNORM_H_

#include "nnet_common.h"
#include "nnet_helpers.h"
#include "nnet_mult.h"

namespace nnet {

@@ -40,6 +42,10 @@ struct batchnorm_config
static const bool store_weights_in_bram = false;
static const unsigned n_zeros = 0;
// partitioning arrays cyclically to go with roll factors?

// Default multiplication
template<class x_T, class y_T>
using product = nnet::product::mult<x_T, y_T>;
};

template<class data_T, class res_T, typename CONFIG_T>
Expand All @@ -54,12 +60,12 @@ void normalize(
Result:
#pragma unroll
for (int ires = 0; ires < CONFIG_T::n_in; ires++) {
// TODO - Explore MULADD instruction in HLS - less clock cycles
if (CONFIG_T::n_filt==-1) {
res[ires] = data[ires] * scale[ires] + bias[ires];
}
else {
res[ires] = CONFIG_T::template product<data_T, typename CONFIG_T::scale_t>::product(data[ires], scale[ires]) + bias[ires];
} else {
int norm_index = ires%CONFIG_T::n_filt;
res[ires] = data[ires] * scale[norm_index] + bias[norm_index];
res[ires] = CONFIG_T::template product<data_T, typename CONFIG_T::scale_t>::product(data[ires], scale[norm_index]) + bias[norm_index];
}
}
}
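The default multiplication that batchnorm_config (and dense_config below) now points at is nnet::product::mult from the newly included nnet_mult.h. That header is not part of this diff, so the following is only a minimal sketch of the assumed functor shape, matching the CONFIG_T::template product<...>::product(...) call sites above:

// Sketch, assuming nnet_mult.h defines functors with a static product() method
namespace nnet {
namespace product {

template<class x_T, class w_T>
class mult {
  public:
    // Plain multiplication; the result type is whatever x_T * w_T promotes to
    static auto product(x_T a, w_T w) -> decltype(a * w) {
        return a * w;
    }
};

} // namespace product
} // namespace nnet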
@@ -0,0 +1,33 @@
//
// rfnoc-hls-neuralnet: Vivado HLS code for neural-net building blocks
//
// Copyright (C) 2017 EJ Kreinar
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
//

/*
* PLACEHOLDER - The common pass bn_quant.py includes both parallel and streaming BN; streaming is currently not supported in Quartus
*/

#ifndef NNET_BATCHNORM_STREAM_H_
#define NNET_BATCHNORM_STREAM_H_

#include "nnet_common.h"
#include "nnet_helpers.h"
#include "nnet_mult.h"

namespace nnet {}

#endif
55 changes: 10 additions & 45 deletions hls4ml/templates/quartus/firmware/nnet_utils/nnet_dense.h
@@ -21,6 +21,8 @@
#define NNET_DENSE_LARGE_H_

#include "nnet_common.h"
#include "nnet_helpers.h"
#include "nnet_mult.h"

namespace nnet {

@@ -49,50 +51,11 @@ struct dense_config
static const bool store_weights_in_bram = false;
static const unsigned n_zeros = 0;
// partitioning arrays cyclically to go with roll factors?
};

inline ac_int<1, false> product(ac_int<1, false> a, ac_int<1, false> w)
{
// specialisation for 1-bit weights and incoming data
return (a == w);
}

template<class data_T>
auto product(data_T a, ac_int<1, false> w) -> decltype(-a)
{
// Specialisation for 1-bit weights, arbitrary data
if (w == 0) return -a;
else return a;
}

template<class data_T>
auto product(data_T a, ac_int<2, true> w) -> decltype(-a)
{
// Specialisation for 2-bit weights, arbitrary data
if (w == 0) return 0;
else if(w == -1) return -a;
else return a; // if(w == 1)
}

template<class data_T, class weight_T>
auto product(data_T a, weight_T w) -> decltype(a*w)
{
// 'Normal' product
return a * w;
}

template<class data_T, class res_T, typename CONFIG_T>
inline typename std::enable_if<std::is_same<data_T, ac_int<1, false>>::value
and std::is_same<typename CONFIG_T::weight_t, ac_int<1, false>>::value, ac_int<nnet::ceillog2(CONFIG_T::n_in) + 2, true>>::type
cast(typename CONFIG_T::accum_t x){
return (ac_int<nnet::ceillog2(CONFIG_T::n_in) + 2, true>) (x - CONFIG_T::n_in / 2) * 2;
}

template<class data_T, class res_T, typename CONFIG_T>
inline typename std::enable_if<(not std::is_same<data_T, ac_int<1, false>>::value), res_T>::type
cast(typename CONFIG_T::accum_t x){
return (res_T) x;
}
// Default multiplication
template<class x_T, class y_T>
using product = nnet::product::mult<x_T, y_T>;
};

template<class data_T, class res_T, typename CONFIG_T>
void dense_rf_gt(
@@ -134,7 +97,8 @@ void dense_rf_gt(
uint32 w_index = ir + (CONFIG_T::reuse_factor_rounded) * im;
if (w_index >= CONFIG_T::reuse_factor_rounded*CONFIG_T::block_factor_rounded) continue;
int data_index = d_index[ir][im];
tmp_acc[im] = product(data[data_index], weights[w_index]);
// Multiply through the product functor selected in the layer config
tmp_acc[im] = CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::product(data[data_index], weights[w_index]);
}
hls_register typename CONFIG_T::accum_t mult[CONFIG_T::multiplier_limit];
ResetMult:
@@ -188,7 +152,8 @@ void dense_rf_lt(
for (int im = 0, in_index = ir; im < CONFIG_T::block_factor; im++) {
uint32 w_index = ir + (CONFIG_T::reuse_factor_rounded) * im;
if (ir + CONFIG_T::reuse_factor * im >= CONFIG_T::n_in*CONFIG_T::n_out) continue;
mult[im] = product(data[in_index], weights[w_index]);
// Multiply through the product functor selected in the layer config
mult[im] = CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::product(data[in_index], weights[w_index]);
in_index += CONFIG_T::reuse_factor;
if (in_index >= CONFIG_T::n_in) in_index = ir;
}
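The free product() specialisations removed above (XNOR-style 1-bit weights, 2-bit ternary weights) move into the same functor scheme: the backend's product_type(...) picks a functor class and the config's product alias binds it. nnet_mult.h itself is not shown in this diff, so the sketch below only illustrates how those removed specialisations could be re-expressed (class names are assumed):

// Sketch only, alongside the default mult functor shown earlier
namespace nnet {
namespace product {

// 1-bit data and 1-bit weights: XNOR-style product
template<class x_T, class w_T>
class both_binary {
  public:
    static x_T product(x_T a, w_T w) {
        return a == w;
    }
};

// Arbitrary data, 1-bit unsigned weights: select +a or -a
template<class x_T, class w_T>
class weight_binary {
  public:
    static auto product(x_T a, w_T w) -> decltype(-a) {
        if (w == 0) return -a;
        return a;
    }
};

// Arbitrary data, 2-bit signed weights: -a, 0 or +a
template<class x_T, class w_T>
class weight_ternary {
  public:
    static auto product(x_T a, w_T w) -> decltype(-a) {
        if (w == 0) return 0;
        if (w == -1) return -a;
        return a;
    }
};

} // namespace product
} // namespace nnet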
@@ -63,7 +63,8 @@ void dense_compressed(
for(int im = 0; im < CONFIG_T::compressed_block_factor; im++) {
uint32 w = ir + CONFIG_T::reuse_factor * im;
//if (w >= CONFIG_T::reuse_factor*CONFIG_T::compressed_block_factor) continue;
mult[im] = product<data_T, decltype(weights[w].weight), typename CONFIG_T::accum_t>(inputs[0][im], weights[w].weight);
mult[im] = CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::product(inputs[0][im], weights[w].weight);
#pragma unroll
for (int is = 0; is < CONFIG_T::reuse_factor-1; is++) {
inputs[is][im] = inputs[is+1][im];