diff --git a/hls4ml/backends/quartus/passes/convolution_templates.py b/hls4ml/backends/quartus/passes/convolution_templates.py
index deae5407a0..5630ad507e 100644
--- a/hls4ml/backends/quartus/passes/convolution_templates.py
+++ b/hls4ml/backends/quartus/passes/convolution_templates.py
@@ -59,7 +59,7 @@
 """

 conv1d_function_template = 'nnet::conv_1d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});'
-conv1d_include_list = ['nnet_utils/nnet_conv1d.h']
+conv1d_include_list = ['nnet_utils/nnet_conv1d.h', 'nnet_utils/nnet_conv1d_stream.h']

 class Conv1DConfigTemplate(LayerConfigTemplate):
     def __init__(self):
@@ -134,7 +134,7 @@ def format(self, node):
 }};\n"""

 conv2d_function_template = 'nnet::conv_2d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});'
-conv2d_include_list = ['nnet_utils/nnet_conv2d.h']
+conv2d_include_list = ['nnet_utils/nnet_conv2d.h', 'nnet_utils/nnet_conv2d_stream.h']

 class Conv2DConfigTemplate(LayerConfigTemplate):
     def __init__(self):
diff --git a/hls4ml/backends/quartus/passes/convolution_winograd.py b/hls4ml/backends/quartus/passes/convolution_winograd.py
index 2259dc02f6..f7c67499a5 100644
--- a/hls4ml/backends/quartus/passes/convolution_winograd.py
+++ b/hls4ml/backends/quartus/passes/convolution_winograd.py
@@ -15,7 +15,9 @@ def match(self, node):
         weights_transformed = node.get_attr('_weights_transposed', False) == True

         # User opted for Winograd
-        implementation_is_winograd = node.get_attr('implementation', 'combination') == 'combination' or node.get_attr('implementation', 'combination') == 'winograd'
+        implementation_is_winograd = node.get_attr('implementation', 'combination') == 'combination' or node.get_attr('implementation', 'combination') == 'winograd'
+
+        parallel_io_type = node.model.config.get_config_value('IOType') == 'io_parallel'

         # Winograd algorithm-specific conditions
         if isinstance(node, Conv1D):
@@ -29,7 +31,7 @@ def match(self, node):
             # HLS Compiler fails to pipeline the entire component if Winograd loop only executes once
             loop_itr_gt_one = node.get_attr('out_width') > 2

-            winograd_conditions = filter_size_matches and stride_is_one and loop_itr_gt_one
+            winograd_conditions = filter_size_matches and stride_is_one and loop_itr_gt_one and parallel_io_type

         elif isinstance(node, (Conv2D)):
             # Winograd only applies to specific kernel sizes
@@ -44,7 +46,7 @@ def match(self, node):

             padding_is_equal = node.get_attr('pad_top', 0) == node.get_attr('pad_bottom', 0) and node.get_attr('pad_left', 0) == node.get_attr('pad_right', 0)

-            winograd_conditions = filter_size_matches and stride_is_one and padding_is_equal and loop_itr_gt_one
+            winograd_conditions = filter_size_matches and stride_is_one and padding_is_equal and loop_itr_gt_one and parallel_io_type

         else:
             winograd_conditions = False
diff --git a/hls4ml/backends/quartus/passes/pointwise.py b/hls4ml/backends/quartus/passes/pointwise.py
index a233168d78..bc30abc128 100644
--- a/hls4ml/backends/quartus/passes/pointwise.py
+++ b/hls4ml/backends/quartus/passes/pointwise.py
@@ -58,7 +58,8 @@ class OptimizePointwiseConv(OptimizerPass):
     def match(self, node):
         return node.class_name in ('Conv1D', 'Conv2D') and \
             node.get_attr('filt_height', 1) == 1 and \
-            node.get_attr('filt_width') == 1
+            node.get_attr('filt_width') == 1 and \
+            node.model.config.get_config_value('IOType') == 'io_parallel'

     def transform(self, model, node):
         dim = node.__class__.__name__[-2:] # '1D' or '2D'
@@ -66,8 +67,6 @@ def transform(self, model, node):
         if len(node.weights['weight'].data.shape) == 2:
             # This can happen if we assign weights of Dense layer to 1x1 Conv2D
             pw_node.weights['weight'].data = np.expand_dims(node.weights['weight'].data, axis=(0,1))
         pw_node.weights['bias'].data = node.weights['bias'].data
-        # pw_node.weights['bias'].data = node.weights['bias'].data
-        print("Here")
         model.replace_node(node, pw_node)

-        return True
\ No newline at end of file
+        return True
diff --git a/hls4ml/backends/quartus/passes/pooling_templates.py b/hls4ml/backends/quartus/passes/pooling_templates.py
index 3bd9ed1025..308e5f9bb0 100644
--- a/hls4ml/backends/quartus/passes/pooling_templates.py
+++ b/hls4ml/backends/quartus/passes/pooling_templates.py
@@ -9,13 +9,18 @@
     static const unsigned n_in = {n_in};
     static const unsigned n_out = {n_out};
+    static const unsigned filt_width = {pool_width};

     static const unsigned n_filt = {n_filt};
+    static const unsigned n_chan = {n_filt};
+
+    static const unsigned in_width = {n_in};

     static const unsigned pad_left = {pad_left};
     static const unsigned pad_right = {pad_right};

     static const nnet::Pool_Op pool_op = nnet::{pool_op};
+    typedef {accum_t.name} accum_t;
 }};\n"""

 pooling2d_config_template = """struct config{index} : nnet::pooling2d_config {{
@@ -24,6 +29,8 @@
     static const unsigned pool_height = {pool_height};
     static const unsigned pool_width = {pool_width};
+    static const unsigned filt_height = {pool_height};
+    static const unsigned filt_width = {pool_width};

     static const unsigned in_height = {in_height};
     static const unsigned in_width = {in_width};
@@ -31,6 +38,7 @@
     static const unsigned out_width = {out_width};

     static const unsigned n_filt = {n_filt};
+    static const unsigned n_chan = {n_filt};

     static const unsigned pad_top = {pad_top};
     static const unsigned pad_bottom = {pad_bottom};
@@ -38,12 +46,14 @@
     static const unsigned pad_right = {pad_right};

     static const nnet::Pool_Op pool_op = nnet::{pool_op};
+    typedef {accum_t.name} accum_t;
 }};\n"""

 global_pooling1d_config_template = """struct config{index} : nnet::pooling1d_config {{
     static const unsigned n_in = {n_in};
     static const unsigned n_filt = {n_filt};
     static const nnet::Pool_Op pool_op = nnet::{pool_op};
+    typedef {accum_t.name} accum_t;
 }};\n"""

 global_pooling2d_config_template = """struct config{index} : nnet::pooling2d_config {{
@@ -51,6 +61,7 @@
     static const unsigned in_width = {in_width};
     static const unsigned n_filt = {n_filt};
     static const nnet::Pool_Op pool_op = nnet::{pool_op};
+    typedef {accum_t.name} accum_t;
 }};\n"""

 pooling1d_function_template = 'nnet::pooling1d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output});'
@@ -58,7 +69,7 @@
 global_pooling1d_function_template = 'nnet::global_pooling1d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output});'
 global_pooling2d_function_template = 'nnet::global_pooling2d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output});'

-pooling_include_list = ['nnet_utils/nnet_pooling.h']
+pooling_include_list = ['nnet_utils/nnet_pooling.h', 'nnet_utils/nnet_pooling_stream.h']

 class PoolingConfigTemplate(LayerConfigTemplate):
     def __init__(self):
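For context, the new filt_width/n_chan/in_width aliases exist so the streaming pooling kernels can reuse the conv-style line-buffer helpers, which address windows in terms of filter width and channel count. A hypothetical rendering of the updated 1D template (layer index, sizes and precision are made up for illustration, not taken from the PR):

struct config4 : nnet::pooling1d_config {
    static const unsigned n_in = 124;
    static const unsigned n_out = 62;
    static const unsigned filt_width = 2;   // alias of pool_width, consumed by the stream kernels
    static const unsigned n_filt = 16;
    static const unsigned n_chan = 16;      // alias of n_filt, consumed by the stream kernels
    static const unsigned in_width = 124;   // alias of n_in
    static const unsigned pad_left = 0;
    static const unsigned pad_right = 0;
    static const nnet::Pool_Op pool_op = nnet::Max;
    typedef ac_fixed<18,8,true> accum_t;    // new accum_t typedef used by global pooling
};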
diff --git a/hls4ml/backends/quartus/passes/reshaping_templates.py b/hls4ml/backends/quartus/passes/reshaping_templates.py
index 3c8e52303a..fde574e7de 100644
--- a/hls4ml/backends/quartus/passes/reshaping_templates.py
+++ b/hls4ml/backends/quartus/passes/reshaping_templates.py
@@ -28,7 +28,7 @@
 zeropad1d_function_template = 'nnet::zeropad1d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output});'
 zeropad2d_function_template = 'nnet::zeropad2d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output});'

-padding_include_list = ['nnet_utils/nnet_padding.h']
+padding_include_list = ['nnet_utils/nnet_padding.h', 'nnet_utils/nnet_padding_stream.h']

 class ZeroPaddingConfigTemplate(LayerConfigTemplate):
     def __init__(self):
@@ -72,7 +72,7 @@ def format(self, node):
 }};\n"""

 resize_function_template = 'nnet::resize_{algorithm}<{input_t}, {config}>({input}, {output});'
-resize_include_list = ['nnet_utils/nnet_resize.h']
+resize_include_list = ['nnet_utils/nnet_resize.h', 'nnet_utils/nnet_resize_stream.h']

 class ResizeConfigTemplate(LayerConfigTemplate):
     def __init__(self):
@@ -108,7 +108,7 @@ def format(self, node):
 }};\n"""

 transpose_function_template = 'nnet::transpose_{dim}<{input_t}, {output_t}, {config}>({input}, {output});'
-transpose_include_list = ['nnet_utils/nnet_transpose.h']
+transpose_include_list = ['nnet_utils/nnet_transpose.h', 'nnet_utils/nnet_transpose_stream.h']

 class TransposeConfigTemplate(LayerConfigTemplate):
     def __init__(self):
diff --git a/hls4ml/templates/quartus/firmware/defines.h b/hls4ml/templates/quartus/firmware/defines.h
index fc28a415d5..6e9b243d83 100644
--- a/hls4ml/templates/quartus/firmware/defines.h
+++ b/hls4ml/templates/quartus/firmware/defines.h
@@ -50,5 +50,6 @@ using stream_out = ihc::stream_out<T>;

 #define DIV_ROUNDUP(n,d) ((n + d - 1) / d)
 #define MIN(n,d) (n > d ? d : n)
+#define MAX(n,d) (n < d ? d : n)

 #endif
diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation.h
index f244901c52..20790a390a 100755
--- a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation.h
+++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation.h
@@ -131,10 +131,12 @@ enum class softmax_implementation {latency=0, legacy=1, stable=2, argmax=3};
 template <class data_T, typename CONFIG_T>
 inline unsigned softmax_stable_idx_from_real_val(const data_T x){
     // Number of address bits for table
-    static constexpr int N = ceillog2(CONFIG_T::table_size);
+    static constexpr int N = ceillog2(CONFIG_T::table_size);

     // Slice the top N bits of the input
-    hls_register ac_int<N, false> y = x.template slc<N>(x.width-N-1);
+    hls_register ac_int<N, false> y = x.template slc<N>(x.width-N-1);
+    // If x is the most negative value, the slice will be 0, so we need to set the 0-th bit to ensure correctness
+    if (x != 0 && y == 0) y[0] = 1;
     return y.to_uint();
 }

@@ -158,11 +160,18 @@ void softmax_stable(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]){
     Op_max<data_T> op_max;
     hls_register data_T x_max = reduce<data_T, data_T, Op_max<data_T>>(data, op_max);

+    // For the diffs, use the same type as the input but force rounding and saturation
+    hls_register ac_fixed<data_T::width, data_T::i_width, true, AC_RND, AC_SAT> d_xi_xmax[CONFIG_T::n_in];
+    #pragma unroll
+    for(unsigned i = 0; i < CONFIG_T::n_in; i++){
+        d_xi_xmax[i] = data[i] - x_max;
+    }
+
     // Calculate all the e^x's
     hls_register typename CONFIG_T::exp_table_t exp_res[CONFIG_T::n_in];
     #pragma unroll
     for(unsigned i = 0; i < CONFIG_T::n_in; i++) {
-        exp_res[i] = exp_table[softmax_stable_idx_from_real_val(data[i] - x_max)];
+        exp_res[i] = exp_table[softmax_stable_idx_from_real_val(d_xi_xmax[i])];
     }

     // Explicitly sum previously calculated exponentials with an adder tree
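The new guard matters because the most negative fixed-point value slices to an all-zero index, which would alias it with x == 0 in the exponential table. A host-side sketch of the same index mapping using plain integers (widths assumed: 16-bit input, 1024-entry table, i.e. N = 10; not taken from the PR):

#include <cstdint>
#include <cstdio>

// Mirrors softmax_stable_idx_from_real_val: take the N bits just below the sign bit.
uint32_t softmax_idx(int16_t raw) {
    const unsigned W = 16, N = 10;
    uint32_t y = (static_cast<uint16_t>(raw) >> (W - N - 1)) & ((1u << N) - 1);
    if (raw != 0 && y == 0) y |= 1;  // the new guard: keep x != 0 away from index 0
    return y;
}

int main() {
    printf("%u %u %u\n", softmax_idx(0), softmax_idx(INT16_MIN), softmax_idx(1));
    // Prints "0 1 1": the most negative input no longer aliases with zero.
}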
+#include "nnet_dense.h" + +namespace nnet { + +/* +* void kernel_shift(shift_buffer, kernel_window) +* +* Args: +* shift_buffer - array elements popped from the line the buffer during the shift line buffer operation +* kernel_window - array of values from the input curently being convolved with the kernel +* +* Values from shift_buffer are inserted into kernel_window, updating the values to be convolved +*/ +template +void kernel_shift_1d( + typename data_T::value_type shift_buffer[CONFIG_T::n_chan], + typename data_T::value_type kernel_window[CONFIG_T::filt_width * CONFIG_T::n_chan] +) { + /* + * Manually shift kernel_window by one step to the left + * Not possible to use nnet::shift_reg as the kernel window is convolved with the kernel weights using dense matrix multiplication + * Dense matrix multiplication is only implemented for arrays + * However, provided certain timing constrains are met, Intel HLS automatically infers a shift operation and implements kernel_window as a shift register + * To verify, see synthesis report in report.html > Area Analysis of System + */ + KernelShiftWidth: + #pragma unroll + for (int col = 0; col < CONFIG_T::filt_width - 1; col++) { + KernelShiftChannel: + #pragma unroll + for (int channel = 0; channel < CONFIG_T::n_chan; channel++) { + kernel_window[col * CONFIG_T::n_chan + channel] = kernel_window[(col + 1) * CONFIG_T::n_chan + channel]; + } + } + + // Insert shift_buffer values into the last column of the kernel window + KernelPushChannel: + #pragma unroll + for (int channel = 0; channel < CONFIG_T::n_chan; channel++) { + kernel_window[(CONFIG_T::filt_width - 1) * CONFIG_T::n_chan + channel] = shift_buffer[channel]; + } +} + +/* +* void shift_line_buffer(in_element, line_buffer, shift_buffer) +* +* Args: +* in_element - current elements from input image, data_T type is usually nnet::array, size of array corresponds to number of channels +* line_buffer - chained array of shift registers, one for each row of the kernel and channel +* shift_buffer - array elements popped from the line the buffer during the shift operation +* +* Values from in_element are inserted into the line buffer, causing all other elements to be shifted by one +* Popped elements are later used to update the kernel window, during the kernel_shift operation +*/ +template +void shift_line_buffer_1d( + const data_T &in_elem, + nnet::shift_reg line_buffer[CONFIG_T::n_chan], + typename data_T::value_type shift_buffer[CONFIG_T::n_chan] +) { + // For every channel, insert the incoming pixel at end of the shift buffer + UpdateBuffer: + #pragma unroll + for (int channel = 0; channel < CONFIG_T::n_chan; channel++) { + shift_buffer[channel] = in_elem[channel]; + } +} + +/* +* void compute_output_buffer(in_element, res_stream, line_buffer, kernel_window, weights, biases) +* +* Args: +* in_element - current elements from input image, data_T type is usually nnet::array, size of array corresponds to number of channels +* res_stream - output stream, passed by reference to allow direct writing +* line_buffer - chained array of shift registers, one for each row of the kernel and channel +* kernel_window - array of values from the input curently convolved with the kernel +* weights - Conv1D layer weights +* biases - Conv1D layer biases +* +* Function executes 4 steps: +* (1) Shift line buffer - updates the contents of the chained shift registers, inserting the new inputs and removing last elements +* (2) Kernel shift - updates the elements of the kernel window, by storing the new inputs and popped 
elements from the line buffer +* (3) Matrix mulitplication - performs dense matrix multiplication between the current input window and kernel weights +* (4) Counter housekeeping - keeps track of current pixel and stride +*/ +template +void compute_output_buffer_1d( + const data_T &in_elem, + stream &res_stream, + nnet::shift_reg line_buffer[CONFIG_T::n_chan], + typename data_T::value_type kernel_window[CONFIG_T::filt_width * CONFIG_T::n_chan], + const typename CONFIG_T::weight_t weights[CONFIG_T::kernel_size * CONFIG_T::n_chan * CONFIG_T::n_filt], + const typename CONFIG_T::bias_t biases[CONFIG_T::n_filt] +) { + // Thresholds + static constexpr int lShiftX = CONFIG_T::filt_width - 1; + + // X position pixel + static int pX = 0; + + // X strides + static int sX = 0; + + // Step 1 - Shift line buffer + hls_register typename data_T::value_type shift_buffer[CONFIG_T::n_chan]; + nnet::shift_line_buffer_1d(in_elem, line_buffer, shift_buffer); + + // Step 2 - Kernel shift + nnet::kernel_shift_1d(shift_buffer, kernel_window); + + // Check to see if we have a full kernel + if ((sX - lShiftX) == 0 && pX > (lShiftX - 1)) { + // Step 3 - Dense matrix multiplication + hls_register typename res_T::value_type res_out[CONFIG_T::n_filt]; + dense_resource(kernel_window, res_out, weights, biases); + + // Write result to output stream + hls_register res_T res_pack; + CastLoop: + #pragma unroll + for (int channel = 0; channel < CONFIG_T::n_filt; channel++) { + res_pack[channel] = res_out[channel]; + } + res_stream.write(res_pack); + } + + // Reached end of image + if ((pX + 1) == (CONFIG_T::in_width + CONFIG_T::pad_left + CONFIG_T::pad_right)) { + pX = 0; + sX = 0; + // Move to the right + } else { + pX++; + sX = ((sX - lShiftX) == 0) ? (sX - CONFIG_T::stride_width + 1) : (sX + 1); + } +} + + +template +void conv_1d_cl( + stream &data, + stream &res, + const typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], + const typename CONFIG_T::bias_t biases[CONFIG_T::n_filt] +) { + // Line buffer and kernel window + hls_register static nnet::shift_reg line_buffer[CONFIG_T::n_chan]; + hls_register static typename data_T::value_type kernel_window[CONFIG_T::filt_width * CONFIG_T::n_chan]; + + // An array of length CONFIG_T::n_chan, with elements set to zero (padding for each channel) + static const data_T padds(0); + + // Input image left-side padding + PaddingLeftWidth: + for (int col = 0; col < CONFIG_T::pad_left; col++) { + compute_output_buffer_1d(padds, res, line_buffer, kernel_window, weights, biases); + } + + // Read input image + ReadInputWidth: + for (int col = 0; col < CONFIG_T::in_width; col++) { + compute_output_buffer_1d(data.read(), res, line_buffer, kernel_window, weights, biases); + } + + // Input image right-side padding + PaddingRightWidth: + for (int col = 0; col < CONFIG_T::pad_right; col++) { + compute_output_buffer_1d(padds, res, line_buffer, kernel_window, weights, biases); + } +} + +} + +#endif \ No newline at end of file diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_conv2d_stream.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_conv2d_stream.h new file mode 100644 index 0000000000..d60397ba9e --- /dev/null +++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_conv2d_stream.h @@ -0,0 +1,236 @@ +#ifndef NNET_CONV2D_STREAM_H_ +#define NNET_CONV2D_STREAM_H_ + +#include "nnet_types.h" +#include "nnet_dense.h" + +namespace nnet { + +/* +* void kernel_shift(shift_buffer, kernel_window) +* +* Args: +* shift_buffer - array elements 
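A quick host-side check of the counter housekeeping in compute_output_buffer_1d helps confirm when outputs fire. This standalone sketch replays only the pX/sX bookkeeping with assumed geometry (in_width=8, one pixel of padding each side, 3-wide filter, stride 2; values are illustrative, not from the PR):

#include <cstdio>

int main() {
    const int in_width = 8, pad_left = 1, pad_right = 1, filt_width = 3, stride_width = 2;
    const int lShiftX = filt_width - 1;
    int pX = 0, sX = 0;
    for (int col = 0; col < pad_left + in_width + pad_right; col++) {
        // Same condition compute_output_buffer_1d checks after shifting the window
        if ((sX - lShiftX) == 0 && pX > (lShiftX - 1))
            printf("window full at pixel %d -> one output written\n", pX);
        // Counter housekeeping, copied from the kernel
        if ((pX + 1) == (in_width + pad_left + pad_right)) {
            pX = 0; sX = 0;
        } else {
            pX++;
            sX = ((sX - lShiftX) == 0) ? (sX - stride_width + 1) : (sX + 1);
        }
    }
    // Prints pixels 2, 4, 6, 8 -> four outputs, matching (8 + 2 - 3) / 2 + 1.
}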
diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_conv2d_stream.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_conv2d_stream.h
new file mode 100644
index 0000000000..d60397ba9e
--- /dev/null
+++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_conv2d_stream.h
@@ -0,0 +1,236 @@
+#ifndef NNET_CONV2D_STREAM_H_
+#define NNET_CONV2D_STREAM_H_
+
+#include "nnet_types.h"
+#include "nnet_dense.h"
+
+namespace nnet {
+
+/*
+* void kernel_shift(shift_buffer, kernel_window)
+*
+* Args:
+*   shift_buffer - array elements popped from the line buffer during the shift line buffer operation
+*   kernel_window - array of values from the input currently being convolved with the kernel
+*
+* Values from shift_buffer are inserted into kernel_window, updating the values to be convolved
+*/
+template <class data_T, typename CONFIG_T>
+void kernel_shift_2d(
+    typename data_T::value_type shift_buffer[CONFIG_T::filt_height][CONFIG_T::n_chan],
+    typename data_T::value_type kernel_window[CONFIG_T::filt_width * CONFIG_T::filt_height * CONFIG_T::n_chan]
+) {
+    /*
+    * Manually shift kernel_window by one step to the left
+    * Not possible to use nnet::shift_reg as the kernel window is convolved with the kernel weights using dense matrix multiplication
+    * Dense matrix multiplication is only implemented for arrays
+    * However, provided certain timing constraints are met, Intel HLS automatically infers a shift operation and implements kernel_window as a shift register
+    * To verify, see synthesis report in report.html > Area Analysis of System
+    */
+    KernelShiftWidth:
+    #pragma unroll
+    for (int col = 0; col < CONFIG_T::filt_width - 1; col++) {
+        KernelShiftHeight:
+        #pragma unroll
+        for (int row = 0; row < CONFIG_T::filt_height; row++) {
+            KernelShiftChannel:
+            #pragma unroll
+            for (int channel = 0; channel < CONFIG_T::n_chan; channel++) {
+                kernel_window[row * CONFIG_T::filt_width * CONFIG_T::n_chan + col * CONFIG_T::n_chan + channel] = kernel_window[row * CONFIG_T::filt_width * CONFIG_T::n_chan + (col + 1) * CONFIG_T::n_chan + channel];
+            }
+        }
+    }
+
+    // Insert shift_buffer values into the last column of the kernel window
+    KernelPushHeight:
+    #pragma unroll
+    for (int col = 0; col < CONFIG_T::filt_height; col++) {
+        KernelPushChannel:
+        #pragma unroll
+        for (int channel = 0; channel < CONFIG_T::n_chan; channel++) {
+            kernel_window[(CONFIG_T::filt_width - 1) * CONFIG_T::n_chan + col * CONFIG_T::filt_width * CONFIG_T::n_chan + channel] = shift_buffer[col][channel];
+        }
+    }
+}
+
+/*
+* void shift_line_buffer(in_element, line_buffer, shift_buffer)
+*
+* Args:
+*   in_element - current elements from input image, data_T type is usually nnet::array, size of array corresponds to number of channels
+*   line_buffer - chained array of shift registers, one for each row of the kernel and channel
+*   shift_buffer - array elements popped from the line buffer during the shift operation
+*
+* Values from in_element are inserted into the line buffer, causing all other elements to be shifted by one
+* Popped elements are later used to update the kernel window, during the kernel_shift operation
+*/
+template <class data_T, typename CONFIG_T>
+void shift_line_buffer_2d(
+    const data_T &in_elem,
+    nnet::shift_reg<typename data_T::value_type, CONFIG_T::in_width> line_buffer[CONFIG_T::filt_height - 1][CONFIG_T::n_chan],
+    typename data_T::value_type shift_buffer[CONFIG_T::filt_height][CONFIG_T::n_chan]
+) {
+    // For every channel, insert the incoming pixel at end of the shift buffer
+    UpdateBuffer:
+    #pragma unroll
+    for (int channel = 0; channel < CONFIG_T::n_chan; channel++) {
+        shift_buffer[CONFIG_T::filt_height - 1][channel] = in_elem[channel];
+    }
+
+    // Shift line buffer and save popped values to shift buffer
+    LineBufferDataIn:
+    #pragma unroll
+    for (int channel = 0; channel < CONFIG_T::n_chan; channel++) {
+        LineBufferShift:
+        #pragma unroll
+        for (unsigned col = 1; col < CONFIG_T::filt_height; col++) {
+            // Shift the line buffer, return the popped pixel
+            typename data_T::value_type pop = line_buffer[col - 1][channel].shift(shift_buffer[CONFIG_T::filt_height - col][channel]);
+
+            // Place popped pixel into the shift buffer, one row above
+            shift_buffer[CONFIG_T::filt_height - col - 1][channel] = pop;
+        }
+    }
+}
+
+/*
+* void compute_output_buffer(in_element, res_stream, line_buffer, kernel_window, weights, biases)
+*
+* Args:
+*   in_element - current elements from input image, data_T type is usually nnet::array, size of array corresponds to number of channels
+*   res_stream - output stream, passed by reference to allow direct writing
+*   line_buffer - chained array of shift registers, one for each row of the kernel and channel
+*   kernel_window - array of values from the input currently convolved with the kernel
+*   weights - Conv1D/Conv2D layer weights
+*   biases - Conv1D/Conv2D layer biases
+*
+* Function executes 4 steps:
+*   (1) Shift line buffer - updates the contents of the chained shift registers, inserting the new inputs and removing last elements
+*   (2) Kernel shift - updates the elements of the kernel window, by storing the new inputs and popped elements from the line buffer
+*   (3) Matrix multiplication - performs dense matrix multiplication between the current input window and kernel weights
+*   (4) Counter housekeeping - keeps track of current pixel and stride
+*/
+template <class data_T, class res_T, typename CONFIG_T>
+void compute_output_buffer_2d(
+    const data_T &in_elem,
+    stream<res_T> &res_stream,
+    nnet::shift_reg<typename data_T::value_type, CONFIG_T::in_width> line_buffer[MAX(CONFIG_T::filt_height - 1, 1)][CONFIG_T::n_chan],
+    typename data_T::value_type kernel_window[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan],
+    const typename CONFIG_T::weight_t weights[CONFIG_T::kernel_size * CONFIG_T::n_chan * CONFIG_T::n_filt],
+    const typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]
+) {
+    // Thresholds
+    static constexpr int lShiftX = CONFIG_T::filt_width - 1;
+    static constexpr int lShiftY = CONFIG_T::filt_height - 1;
+
+    // X, Y position pixels
+    static int pX = 0;
+    static int pY = 0;
+
+    // X, Y strides
+    static int sX = 0;
+    static int sY = 0;
+
+    // Step 1 - Shift line buffer
+    hls_register typename data_T::value_type shift_buffer[CONFIG_T::filt_height][CONFIG_T::n_chan];
+    nnet::shift_line_buffer_2d<data_T, CONFIG_T>(in_elem, line_buffer, shift_buffer);
+
+    // Step 2 - Kernel shift
+    nnet::kernel_shift_2d<data_T, CONFIG_T>(shift_buffer, kernel_window);
+
+    // Check to see if we have a full kernel
+    if ((sX - lShiftX) == 0 && (sY - lShiftY) == 0 && pY > (lShiftY - 1) && pX > (lShiftX - 1)) {
+        // Step 3 - Dense matrix multiplication
+        hls_register typename res_T::value_type res_out[CONFIG_T::n_filt];
+        dense_resource<typename data_T::value_type, typename res_T::value_type, typename CONFIG_T::mult_config>(kernel_window, res_out, weights, biases);
+
+        // Write result to output stream
+        hls_register res_T res_pack;
+        CastLoop:
+        #pragma unroll
+        for (int channel = 0; channel < CONFIG_T::n_filt; channel++) {
+            res_pack[channel] = res_out[channel];
+        }
+        res_stream.write(res_pack);
+    }
+
+    // Reached end of image
+    if ((pX + 1) == (CONFIG_T::in_width + CONFIG_T::pad_left + CONFIG_T::pad_right) && (pY + 1) == (CONFIG_T::in_height + CONFIG_T::pad_top + CONFIG_T::pad_bottom)) {
+        pX = 0;
+        sX = 0;
+        pY = 0;
+        sY = 0;
+    // Reached end of row
+    } else if ((pX + 1) == (CONFIG_T::in_width + CONFIG_T::pad_left + CONFIG_T::pad_right)) {
+        pX = 0;
+        sX = 0;
+        pY++;
+        sY = ((sY - lShiftY) == 0) ? (sY - CONFIG_T::stride_height + 1) : (sY + 1);
+    // Same row, same column, therefore, move to the right
+    } else {
+        pX++;
+        sX = ((sX - lShiftX) == 0) ? (sX - CONFIG_T::stride_width + 1) : (sX + 1);
+    }
+}
+
+template <class data_T, class res_T, typename CONFIG_T>
+void conv_2d_cl(
+    stream<data_T> &data,
+    stream<res_T> &res,
+    const typename CONFIG_T::weight_t weights[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt],
+    const typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]
+) {
+
+    // Line buffer and kernel window
+    hls_register static nnet::shift_reg<typename data_T::value_type, CONFIG_T::in_width> line_buffer[MAX(CONFIG_T::filt_height - 1, 1)][CONFIG_T::n_chan];
+    hls_register static typename data_T::value_type kernel_window[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan];
+
+    // An array of length CONFIG_T::n_chan, with elements set to zero (padding for each channel)
+    static const data_T padds(0);
+
+    // Padding above input image
+    PaddingTopHeight:
+    #pragma loop_coalesce 2
+    for (int row = 0; row < CONFIG_T::pad_top; row++) {
+        PaddingTopWidth:
+        for (int col = 0; col < CONFIG_T::pad_left + CONFIG_T::in_width + CONFIG_T::pad_right; col++) {
+            compute_output_buffer_2d<data_T, res_T, CONFIG_T>(padds, res, line_buffer, kernel_window, weights, biases);
+        }
+    }
+
+    ReadInputHeight:
+    #pragma loop_coalesce 2
+    for (int row = 0; row < CONFIG_T::in_height; row++) {
+        // Input image left-side padding
+        PaddingLeftWidth:
+        for (int col = 0; col < CONFIG_T::pad_left; col++) {
+            compute_output_buffer_2d<data_T, res_T, CONFIG_T>(padds, res, line_buffer, kernel_window, weights, biases);
+        }
+
+        // Read input image
+        ReadInputWidth:
+        for (int col = 0; col < CONFIG_T::in_width; col++) {
+            compute_output_buffer_2d<data_T, res_T, CONFIG_T>(data.read(), res, line_buffer, kernel_window, weights, biases);
+        }
+
+        // Input image right-side padding
+        PaddingRightWidth:
+        for (int col = 0; col < CONFIG_T::pad_right; col++) {
+            compute_output_buffer_2d<data_T, res_T, CONFIG_T>(padds, res, line_buffer, kernel_window, weights, biases);
+        }
+    }
+
+    // Padding below input image
+    PaddingBottomHeight:
+    #pragma loop_coalesce 2
+    for (int row = 0; row < CONFIG_T::pad_bottom; row++) {
+        PaddingBottomWidth:
+        for (int col = 0; col < CONFIG_T::pad_left + CONFIG_T::in_width + CONFIG_T::pad_right; col++) {
+            compute_output_buffer_2d<data_T, res_T, CONFIG_T>(padds, res, line_buffer, kernel_window, weights, biases);
+        }
+    }
+}
+
+}
+
+#endif
\ No newline at end of file
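The 2D kernel window is kept flat, in row-major, channel-last order, and kernel_shift_2d inserts the newest column at col = filt_width - 1 of every row. A tiny host-side check of that index arithmetic (3x3 kernel, 2 channels; sizes assumed for illustration):

#include <cstdio>

int main() {
    const int filt_height = 3, filt_width = 3, n_chan = 2;
    // kernel_window[row * filt_width * n_chan + col * n_chan + channel]
    for (int row = 0; row < filt_height; row++)
        for (int col = 0; col < filt_width; col++)
            for (int ch = 0; ch < n_chan; ch++)
                printf("window[%d][%d] ch%d -> flat index %d\n", row, col, ch,
                       row * filt_width * n_chan + col * n_chan + ch);
    // dense_resource then treats this flat array exactly like a
    // (filt_height * filt_width * n_chan)-wide dense-layer input.
}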
diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_helpers.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_helpers.h
index 8244b1ac1a..dfdaa5c8d7 100755
--- a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_helpers.h
+++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_helpers.h
@@ -29,26 +29,6 @@
 #include
 #include

-#ifndef __INTELFPGA_COMPILER__
-#include "stream.h"
-template<class T>
-using stream = nnet::stream<T>;
-template<class T>
-using stream_in = nnet::stream<T>;
-template<class T>
-using stream_out = nnet::stream<T>;
-#else
-#include "HLS/hls.h"
-#include "HLS/ac_int.h"
-#include "HLS/ac_fixed.h"
-template<class T>
-using stream = ihc::stream<T>;
-template<class T>
-using stream_in = ihc::stream_in<T>;
-template<class T>
-using stream_out = ihc::stream_out<T>;
-#endif
-
 namespace nnet {

 template
@@ -109,6 +89,17 @@ void save_output_array(data_T *data, save_T *ptr, size_t layer_size) {
     }
 }

+template<class data_T, class save_T>
+void save_output_array(stream<data_T> &data, save_T *ptr, size_t layer_size) {
+    for (size_t i = 0; i < layer_size / data_T::size; i++) {
+        data_T ctype = data.read();
+        for (size_t j = 0; j < data_T::size; j++) {
+            ptr[i * data_T::size + j] = static_cast<save_T>(ctype[j].to_double());
+        }
+        data.write(ctype);
+    }
+}
+
 // We don't want to include save_T in this function because it will be inserted into myproject.cpp
 // so a workaround with element size is used
 template<class data_T>
@@ -141,6 +132,40 @@ void save_layer_output(data_T *data, const char *layer_name, size_t layer_size) {
     }
 }

+template<class data_T>
+void save_layer_output(stream<data_T> &data, const char *layer_name, size_t layer_size) {
+    if (!trace_enabled) return;
+
+    if (trace_outputs) {
+        if (trace_outputs->count(layer_name) > 0) {
+            if (trace_type_size == 4) {
+                save_output_array(data, (float *) (*trace_outputs)[layer_name], layer_size);
+            } else if (trace_type_size == 8) {
+                save_output_array(data, (double *) (*trace_outputs)[layer_name], layer_size);
+            } else {
+                std::cout << "Unknown trace type!" << std::endl;
+            }
+        } else {
+            std::cout << "Layer name: " << layer_name << " not found in debug storage!" << std::endl;
+        }
+    } else {
+        std::ostringstream filename;
+        filename << "./tb_data/" << layer_name << "_output.log"; //TODO if run as a shared lib, path should be ../tb_data
+        std::fstream out;
+        out.open(filename.str(), std::ios::app);
+        assert(out.is_open());
+        for (size_t i = 0; i < layer_size / data_T::size; i++) {
+            data_T ctype = data.read();
+            for (size_t j = 0; j < data_T::size; j++) {
+                out << ctype[j] << " ";
+            }
+            data.write(ctype);
+        }
+        out << std::endl;
+        out.close();
+    }
+}
+
 }

 #endif
diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_padding_stream.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_padding_stream.h
new file mode 100644
index 0000000000..78e8fb4be7
--- /dev/null
+++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_padding_stream.h
@@ -0,0 +1,87 @@
+#ifndef NNET_PADDING_STREAM_H_
+#define NNET_PADDING_STREAM_H_
+
+namespace nnet {
+
+template<class res_T, typename CONFIG_T>
+inline void fill_zero(stream<res_T> &res) {
+    hls_register res_T res_part;
+    #pragma unroll
+    for (int i = 0; i < CONFIG_T::n_chan; i++) {
+        res_part[i] = 0;
+    }
+    res.write(res_part);
+}
+
+template<class data_T, class res_T, typename CONFIG_T>
+inline void fill_data(stream<data_T> &data, stream<res_T> &res) {
+    hls_register data_T data_part = data.read();
+    hls_register res_T res_part;
+    #pragma unroll
+    for (int i = 0; i < CONFIG_T::n_chan; i++) {
+        res_part[i] = data_part[i];
+    }
+    res.write(res_part);
+}
+
+template<class data_T, class res_T, typename CONFIG_T>
+void zeropad1d_cl(stream<data_T> &data, stream<res_T> &res) {
+    PadLeft:
+    for (int i = 0; i < CONFIG_T::pad_left; i++) {
+        fill_zero<res_T, CONFIG_T>(res);
+    }
+
+    CopyMain:
+    for (int i = 0; i < CONFIG_T::in_width; i++) {
+        fill_data<data_T, res_T, CONFIG_T>(data, res);
+    }
+
+    PadRight:
+    for (int i = 0; i < CONFIG_T::pad_right; i++) {
+        fill_zero<res_T, CONFIG_T>(res);
+    }
+}
+
+template<class data_T, class res_T, typename CONFIG_T>
+void zeropad2d_cl(stream<data_T> &data, stream<res_T> &res) {
+    PadTop:
+    #pragma loop_coalesce 2
+    for (int i = 0; i < CONFIG_T::pad_top; i++) {
+        PadTopWidth:
+        for (int j = 0; j < CONFIG_T::out_width; j++) {
+            fill_zero<res_T, CONFIG_T>(res);
+        }
+    }
+
+    PadMain:
+    #pragma loop_coalesce 2
+    for (int i = 0; i < CONFIG_T::in_height; i++) {
+
+        PadLeft:
+        for (int j = 0; j < CONFIG_T::pad_left; j++) {
+            fill_zero<res_T, CONFIG_T>(res);
+        }
+
+        CopyMain:
+        for (int j = 0; j < CONFIG_T::in_width; j++) {
+            fill_data<data_T, res_T, CONFIG_T>(data, res);
+        }
+
+        PadRight:
+        for (int j = 0; j < CONFIG_T::pad_right; j++) {
+            fill_zero<res_T, CONFIG_T>(res);
+        }
+    }
+
+    PadBottom:
+    for (int i = 0; i < CONFIG_T::pad_bottom; i++) {
+        PadBottomWidth:
+        for (int j = 0; j < CONFIG_T::out_width; j++) {
+            fill_zero<res_T, CONFIG_T>(res);
+        }
+    }
+}
+
+}
+
+#endif
\ No newline at end of file
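The streaming padding functions simply interleave zero packets with copied packets, so the output stream carries pad_left + in_width + pad_right packets per row. A host-side sketch with std::queue standing in for the HLS stream (sizes assumed; one int stands in for one channel packet):

#include <queue>
#include <cstdio>

int main() {
    const int in_width = 4, pad_left = 2, pad_right = 1;
    std::queue<int> data, res;
    for (int i = 1; i <= in_width; i++) data.push(i * 10);

    for (int i = 0; i < pad_left; i++) res.push(0);    // PadLeft  -> fill_zero
    for (int i = 0; i < in_width; i++) {               // CopyMain -> fill_data
        res.push(data.front()); data.pop();
    }
    for (int i = 0; i < pad_right; i++) res.push(0);   // PadRight -> fill_zero

    while (!res.empty()) { printf("%d ", res.front()); res.pop(); }
    // Prints: 0 0 10 20 30 40 0
}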
"nnet_types.h" + +namespace nnet { + +/* +* void compute_pool_buffer_1d(in_element, res_stream, line_buffer, kernel_window) +* +* Args: +* in_element - current elements from input image, data_T type is usually nnet::array, size of array corresponds to number of channels +* res_stream - output stream, passed by reference to allow direct writing +* line_buffer - chained array of shift registers, one for each row of the pool and channel +* kernel_window - array of values from the input curently being pooled +* +* Function executes 4 steps: +* (1) Shift line buffer - updates the contents of the chained shift registers, inserting the new inputs and removing last elements +* (2) Kernel shift - updates the elements of the kernel window, by storing the new inputs and popped elements from the line buffer +* (3) Pooling - performs dense matrix multiplication between the current input window and kernel weights +* (4) Counter housekeeping - performs the required pooling operation +* +*/ +template +void compute_pool_buffer_1d( + const data_T &in_elem, + stream &res_stream, + nnet::shift_reg line_buffer[CONFIG_T::n_filt], + typename data_T::value_type kernel_window[CONFIG_T::pool_width * CONFIG_T::n_filt] +) { + // Thresholds + static constexpr int lShiftX = CONFIG_T::pool_width - 1; + + // X position pixels + static int pX = 0; + + // X strides + static int sX = 0; + + // Step 1 - Shift line buffer + hls_register typename data_T::value_type shift_buffer[CONFIG_T::n_filt]; + nnet::shift_line_buffer_1d(in_elem, line_buffer, shift_buffer); + + // Step 2 - Kernel shift + nnet::kernel_shift_1d(shift_buffer, kernel_window); + + // Check to see if we have a full pool window + if ((sX - lShiftX) == 0 && pX > (lShiftX - 1)) { + hls_register res_T res_pack; + + FiltLoop: + #pragma unroll + for(int filter = 0; filter < CONFIG_T::n_filt; filter++) { + hls_register typename data_T::value_type pool_window[CONFIG_T::pool_width]; + + // Retrieve data for current channel + PoolLoop: + #pragma unroll + for(int i = 0; i < CONFIG_T::pool_width; i++) { + pool_window[i] = kernel_window[i * CONFIG_T::n_filt + filter]; + } + + // Step 3 - Pooling + res_pack[filter] = static_cast(pool_op(pool_window)); + } + + // Write result to output stream + res_stream.write(res_pack); + } + + // Reached end of image + if ((pX + 1) == (CONFIG_T::in_width + CONFIG_T::pad_left + CONFIG_T::pad_right)) { + pX = 0; + sX = 0; + // Move to the right + } else { + pX++; + sX = ((sX - lShiftX) == 0) ? 
(sX - CONFIG_T::stride_width + 1) : (sX + 1); + } +} + + +template +void pooling1d_cl(stream &data, stream &res) { + assert(CONFIG_T::pool_width == CONFIG_T::stride_width); + assert(CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0); + + // Line buffer and kernel window + hls_register static nnet::shift_reg line_buffer[CONFIG_T::n_filt]; + hls_register static typename data_T::value_type kernel_window[CONFIG_T::pool_width * CONFIG_T::n_filt]; + + // Read input image + ReadInputWidth: + for (int col = 0; col < CONFIG_T::in_width; col++) { + compute_pool_buffer_1d(data.read(), res, line_buffer, kernel_window); + } +} + +/* +* void compute_pool_buffer_2d(in_element, res_stream, line_buffer, kernel_window) +* +* Args: +* in_element - current elements from input image, data_T type is usually nnet::array, size of array corresponds to number of channels +* res_stream - output stream, passed by reference to allow direct writing +* line_buffer - chained array of shift registers, one for each row of the pool and channel +* kernel_window - array of values from the input curently being pooled +* +* Function executes 4 steps: +* (1) Shift line buffer - updates the contents of the chained shift registers, inserting the new inputs and removing last elements +* (2) Kernel shift - updates the elements of the kernel window, by storing the new inputs and popped elements from the line buffer +* (3) Pooling - performs dense matrix multiplication between the current input window and kernel weights +* (4) Counter housekeeping - performs the required pooling operation +* +*/ +template +void compute_pool_buffer_2d( + const data_T &in_elem, + stream &res_stream, + nnet::shift_reg line_buffer[CONFIG_T::pool_height - 1][CONFIG_T::n_filt], + typename data_T::value_type kernel_window[CONFIG_T::pool_height * CONFIG_T::pool_width * CONFIG_T::n_filt] +) { + // Thresholds + static constexpr int lShiftX = CONFIG_T::pool_width - 1; + static constexpr int lShiftY = CONFIG_T::pool_height - 1; + + // X, Y position pixels + static int pX = 0; + static int pY = 0; + + // X, Y strides + static int sX = 0; + static int sY = 0; + + // Step 1 - Shift line buffer + hls_register typename data_T::value_type shift_buffer[CONFIG_T::pool_height][CONFIG_T::n_filt]; + nnet::shift_line_buffer_2d(in_elem, line_buffer, shift_buffer); + + // Step 2 - Kernel shift + nnet::kernel_shift_2d(shift_buffer, kernel_window); + + // Check to see if we have a full pool window + if ((sX - lShiftX) == 0 && (sY - lShiftY) == 0 && pY > (lShiftY - 1) && pX > (lShiftX - 1)) { + hls_register res_T res_pack; + + FiltLoop: + #pragma unroll + for(int filter = 0; filter < CONFIG_T::n_filt; filter++) { + hls_register typename data_T::value_type pool_window[CONFIG_T::pool_height * CONFIG_T::pool_width]; + + // Retrieve data for current channel + PoolLoop: + #pragma unroll + for(int i = 0; i < CONFIG_T::pool_height * CONFIG_T::pool_width; i++) { + pool_window[i] = kernel_window[i * CONFIG_T::n_filt + filter]; + } + + // Step 3 - Pooling + res_pack[filter] = static_cast(pool_op(pool_window)); + } + + // Write result to output stream + res_stream.write(res_pack); + } + + // Reached end of image + if ((pX + 1) == (CONFIG_T::in_width + CONFIG_T::pad_left + CONFIG_T::pad_right) && (pY + 1) == (CONFIG_T::in_height + CONFIG_T::pad_top + CONFIG_T::pad_bottom)) { + pX = 0; + sX = 0; + pY = 0; + sY = 0; + // Reached end of row + } else if ((pX + 1) == (CONFIG_T::in_width + CONFIG_T::pad_left + CONFIG_T::pad_right)) { + pX = 0; + sX = 0; + pY++; + sY = ((sY - lShiftY) == 0) ? 
(sY - CONFIG_T::stride_height + 1) : (sY + 1); + // Same row, same colum, therefore, move to the right + } else { + pX++; + sX = ((sX - lShiftX) == 0) ? (sX - CONFIG_T::stride_width + 1) : (sX + 1); + } +} + +template +void pooling2d_cl(stream &data, stream &res) { + assert(CONFIG_T::pool_height == CONFIG_T::stride_height && CONFIG_T::pool_width == CONFIG_T::stride_width); + assert(CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0); + assert(CONFIG_T::pad_top == 0 && CONFIG_T::pad_bottom == 0); + + // Line buffer and kernel window + hls_register static nnet::shift_reg line_buffer[MAX(CONFIG_T::pool_height - 1,1)][CONFIG_T::n_filt]; + hls_register static typename data_T::value_type kernel_window[CONFIG_T::pool_height * CONFIG_T::pool_width * CONFIG_T::n_filt]; + + ReadInputHeight: + #pragma loop_coalesce 2 + for (int row = 0; row < CONFIG_T::in_height; row++) { + // Read input image + ReadInputWidth: + for (int col = 0; col < CONFIG_T::in_width; col++) { + compute_pool_buffer_2d(data.read(), res, line_buffer, kernel_window); + } + } +} + +/* +* A function used with Global Pooling +* Returns the value before pooling +* Max : Return the minimal possible value +* Avg : Return 0 +*/ +template +inline T init_pool_value() { + switch(op){ + case Max: { + T x = 0; + x[x.width - 1] = 1; + return x; + } + case Average: return 0; + } +} + +/* +* A function used with Global Pooling +* Updates the output pooling value +* Max : Return the maximum between the previous maximum and current input +* Avg : Returns the cumulative sum +*/ +template +inline T_y reduce_global_pool(T_y y, T_x x) { + if (op == Max) { + return (x > y) ? (T_y) x : y; + } else { + return (T_y) (x + y); + } +} + +/* +* A function used with Global Pooling +* For every filter, it updates the value by summing the current input (Average) or updating the maximum value (Max) +*/ +template +void compute_global_pool(const data_T& in_elem, typename CONFIG_T::accum_t data_input[CONFIG_T::n_filt]) { + #pragma unroll + for (unsigned i = 0; i < CONFIG_T::n_filt; i++) { + data_input[i] = reduce_global_pool(data_input[i], in_elem[i]); + } +} + +template +void global_pooling1d_cl(stream &data, stream &res) { + assert(CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0); + + hls_register typename CONFIG_T::accum_t data_input[CONFIG_T::n_filt]; + + #pragma unroll + for (int i = 0; i < CONFIG_T::n_filt; i++) { + data_input[i] = init_pool_value(); + } + + for (int i = 0; i < CONFIG_T::n_in; i++) { + compute_global_pool(data.read(), data_input); + } + + hls_register res_T res_pack; + if (CONFIG_T::pool_op == Average) { + #pragma unroll + for (int i = 0; i < CONFIG_T::n_filt; i++) { + res_pack[i] = static_cast(data_input[i] / CONFIG_T::n_in); + } + } else { + #pragma unroll + for (int i = 0; i < CONFIG_T::n_filt; i++) { + res_pack[i] = static_cast(data_input[i]); + } + } + + res.write(res_pack); +} + +template +void global_pooling2d_cl(stream &data, stream &res) { + assert(CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0); + assert(CONFIG_T::pad_top == 0 && CONFIG_T::pad_bottom == 0); + + hls_register typename CONFIG_T::accum_t data_input[CONFIG_T::n_filt]; + + #pragma unroll + for (int i = 0; i < CONFIG_T::n_filt; i++) { + data_input[i] = init_pool_value(); + } + + for (int i = 0; i < CONFIG_T::in_height; i++) { + for (int j = 0; j < CONFIG_T::in_width; j++) { + compute_global_pool(data.read(), data_input); + } + } + + hls_register res_T res_pack; + if (CONFIG_T::pool_op == Average) { + #pragma unroll + for (int i = 0; i < CONFIG_T::n_filt; i++) { + 
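For Max, init_pool_value produces the most negative representable value by setting only the sign bit of the two's-complement word; reduce_global_pool then folds a running max (or a running sum for Average). The same idea with a plain int8_t accumulator (host-side sketch; values assumed):

#include <cstdint>
#include <cstdio>
#include <algorithm>

int main() {
    // init_pool_value<T, Max>: x = 0; x[x.width - 1] = 1;  -> sign bit only
    int8_t acc = (int8_t)0x80;                       // == -128, the int8_t minimum
    const int8_t pixels[] = {-3, 7, 0, 5};
    for (int8_t p : pixels) acc = std::max(acc, p);  // reduce_global_pool<.., Max>
    printf("max = %d\n", acc);                       // 7

    // Average starts at 0, accumulates, then divides by the pixel count
    int32_t sum = 0;
    for (int8_t p : pixels) sum += p;                // reduce_global_pool<.., Average>
    printf("avg = %d\n", sum / 4);                   // 9 / 4 = 2 (integer division, as in fixed point)
}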
diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_resize_stream.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_resize_stream.h
new file mode 100644
index 0000000000..c356a01a31
--- /dev/null
+++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_resize_stream.h
@@ -0,0 +1,57 @@
+#ifndef NNET_IMAGE_STREAM_H_
+#define NNET_IMAGE_STREAM_H_
+
+#include "nnet_common.h"
+
+namespace nnet {
+
+template<class data_T, typename CONFIG_T>
+void resize_nearest(stream<data_T> &image, stream<data_T> &resized) {
+    assert(CONFIG_T::new_height % CONFIG_T::height == 0);
+    assert(CONFIG_T::new_width % CONFIG_T::width == 0);
+
+    constexpr unsigned ratio_height = CONFIG_T::new_height / CONFIG_T::height;
+    constexpr unsigned ratio_width = CONFIG_T::new_width / CONFIG_T::width;
+
+    ImageHeight:
+    for (unsigned h = 0; h < CONFIG_T::height; h++) {
+        hls_register data_T data_in_row[CONFIG_T::width];
+
+        ImageWidth:
+        for (unsigned i = 0; i < CONFIG_T::width; i++) {
+            hls_register data_T in_data = image.read();
+
+            ImageChan:
+            #pragma unroll
+            for (unsigned j = 0; j < CONFIG_T::n_chan; j++) {
+                data_in_row[i][j] = in_data[j];
+            }
+        }
+
+        ResizeHeight:
+        for (unsigned i = 0; i < ratio_height; i++) {
+
+            ImageWidth2:
+            for (unsigned l = 0; l < CONFIG_T::width; l++) {
+
+                ResizeWidth:
+                for (unsigned j = 0; j < ratio_width; j++) {
+
+                    hls_register data_T out_data;
+
+                    ResizeChan:
+                    #pragma unroll
+                    for (unsigned k = 0; k < CONFIG_T::n_chan; k++) {
+                        out_data[k] = data_in_row[l][k];
+                    }
+
+                    resized.write(out_data);
+                }
+            }
+        }
+    }
+}
+
+}
+
+#endif
diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_stream.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_stream.h
index 4b2c8a4859..45e821adc5 100644
--- a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_stream.h
+++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_stream.h
@@ -1,5 +1,7 @@
-#ifndef NNET_STREAM_H
-#define NNET_STREAM_H
+#ifndef NNET_CLONE_H
+#define NNET_CLONE_H
+
+#include "nnet_common.h"

 namespace nnet {

@@ -31,6 +33,30 @@ void clone_stream(stream<data_T> &data, stream<res_T> &res1, stream<res_T> &res2) {
     }
 }

+template <class data_T, class res_T, int N>
+void clone_stream(stream<data_T> &data, stream<res_T> &res1, stream<res_T> &res2, stream<res_T> &res3) {
+    CloneLoop:
+    #pragma ii 1
+    for (int i = 0; i < N / data_T::size; i++) {
+        data_T in_data = data.read();
+        res_T out_data1;
+        res_T out_data2;
+        res_T out_data3;
+
+        ClonePack:
+        #pragma unroll
+        for (int j = 0; j < data_T::size; j++) {
+            out_data1[j] = in_data[j];
+            out_data2[j] = in_data[j];
+            out_data3[j] = in_data[j];
+        }
+
+        res1.write(out_data1);
+        res2.write(out_data2);
+        res3.write(out_data3);
+    }
+}
+
 }

 #endif
diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_transpose.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_transpose.h
index 9ed54958d5..920ac53cab 100644
--- a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_transpose.h
+++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_transpose.h
@@ -1,5 +1,5 @@
-#ifndef NNET_ARRAY_H_
-#define NNET_ARRAY_H_
+#ifndef NNET_TRANSPOSE_H_
+#define NNET_TRANSPOSE_H_

 namespace nnet {

@@ -29,7 +29,7 @@ void transpose_3d(
     res_T res[CONFIG_T::depth * CONFIG_T::height * CONFIG_T::width]
 ) {
     static constexpr unsigned dim_data[3] = { CONFIG_T::depth, CONFIG_T::height, CONFIG_T::width };
-    static constexpr unsigned dim_res[3] = { dim_data[CONFIG_T::perm[0], dim_data[CONFIG_T::perm[1], dim_data[CONFIG_T::perm[2] };
+    static constexpr unsigned dim_res[3] = { dim_data[CONFIG_T::perm[0]], dim_data[CONFIG_T::perm[1]], dim_data[CONFIG_T::perm[2]] };

     int index_data[3] = {0}, index_res[3] = {0};

@@ -42,7 +42,7 @@ void transpose_3d(
         index_res[1] = index_data[CONFIG_T::perm[1]];
         index_res[2] = index_data[CONFIG_T::perm[2]];

-        data_t[index_res[0] * dim_res[1] * dim_res[2] + index_res[1] * dim_res[2] + index_res[2]] = static_cast<res_T>(data[index_data[0] * dim_data[1] * dim_data[2] + index_data[1] * dim_data[2] + index_data[2]]);
+        res[index_res[0] * dim_res[1] * dim_res[2] + index_res[1] * dim_res[2] + index_res[2]] = static_cast<res_T>(data[index_data[0] * dim_data[1] * dim_data[2] + index_data[1] * dim_data[2] + index_data[2]]);
     }
 }
diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_transpose_stream.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_transpose_stream.h
new file mode 100644
index 0000000000..75ef67a87b
--- /dev/null
+++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_transpose_stream.h
@@ -0,0 +1,33 @@
+#ifndef NNET_TRANSPOSE_STREAM_H_
+#define NNET_TRANSPOSE_STREAM_H_
+
+namespace nnet {
+
+template<class data_T, class res_T, typename CONFIG_T>
+void transpose_2d(stream<data_T> &data, stream<res_T> &res) {
+    hls_register typename data_T::value_type data_array[CONFIG_T::height * CONFIG_T::width];
+
+    for (int i = 0; i < CONFIG_T::height * CONFIG_T::width / data_T::size; i++) {
+        hls_register data_T in_data = data.read();
+
+        #pragma unroll
+        for (int j = 0; j < data_T::size; j++) {
+            data_array[i * data_T::size + j] = typename data_T::value_type(in_data[j]);
+        }
+    }
+
+    for (int i = 0; i < CONFIG_T::height * CONFIG_T::width / res_T::size; i++) {
+        hls_register res_T out_data;
+
+        #pragma unroll
+        for (int j = 0; j < res_T::size; j++) {
+            out_data[j] = typename res_T::value_type(data_array[j * data_T::size + i]);
+        }
+
+        res.write(out_data);
+    }
+}
+
+}
+
+#endif
\ No newline at end of file
diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_types.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_types.h
index 5913354ff7..cc0293c3c4 100644
--- a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_types.h
+++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_types.h
@@ -15,6 +15,15 @@
 struct array {
     T data[N];

+    array() {}
+
+    array(T x) {
+        #pragma unroll
+        for (int i = 0 ; i < N ; i++) {
+            data[i] = x;
+        }
+    }
+
     T& operator[](size_t pos) {
         return data[pos];
     }
@@ -37,6 +46,40 @@ struct array {
     }
 };

+/*
+* HLS Shift Register Implementation
+* To verify a shift register is used in hardware, go to report.html > Area Analysis of System
+* Unrolling the shift loop minimizes resource usage and latency at the same time
+* The shift loop should be either fully unrolled or not unrolled at all
+* Unrolling with a specific unroll factor or pipelining with certain ii's, can cause an irregular access pattern, which wouldn't allow shift register usage in RTL
+*/
+template <typename T, unsigned N>
+struct shift_reg {
+    private:
+        T data[N];
+
+    public:
+        // Default constructor
+        shift_reg() {}
+
+        // Shift queue, insert new element and return element from the front
+        T shift(T inp) {
+            T out = data[N-1];
+
+            #pragma unroll
+            for(int i = N - 1; i > 0; i--) {
+                data[i] = data[i-1];
+            }
+            data[0] = inp;
+
+            return out;
+        }
+
+        T read(int pos) {
+            return data[pos];
+        }
+};
+
 }

-#endif
+#endif
\ No newline at end of file
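The new nnet::shift_reg is the building block of the line buffers above: shift() pushes the new element at index 0 and returns the element that falls off the end, so off-chip it behaves like a fixed-length FIFO. A host-side exercise of the same structure (hls_register and the unroll pragma are HLS-only annotations; storage is zero-initialized here so the first pops are defined, which the on-chip version does not guarantee):

#include <cstdio>

// Minimal host-side copy of the shift_reg from nnet_types.h.
template <typename T, unsigned N>
struct shift_reg {
    T data[N] = {};
    T shift(T inp) {
        T out = data[N - 1];
        for (int i = N - 1; i > 0; i--) data[i] = data[i - 1];
        data[0] = inp;
        return out;
    }
};

int main() {
    shift_reg<int, 3> reg;
    for (int i = 1; i <= 5; i++)
        printf("in %d -> out %d\n", i, reg.shift(i));
    // Outputs 0 0 0 1 2: each value re-emerges after N = 3 shifts.
}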
diff --git a/hls4ml/writer/quartus_writer.py b/hls4ml/writer/quartus_writer.py
index 58cfe937be..b9c3d4c9cc 100644
--- a/hls4ml/writer/quartus_writer.py
+++ b/hls4ml/writer/quartus_writer.py
@@ -177,8 +177,9 @@ def write_project_cpp(self, model):
             # Insert HLS pragmas such as maximum frequency, initiation interval etc.
             elif '//hls-fpga-machine-learning insert cpragmas' in line:
                 newline = line
-                newline += 'hls_max_concurrency(0)\n'
-                newline += 'hls_component_ii({})\n'.format(self.get_max_reuse_factor(model))
+                if io_type == 'io_parallel':
+                    newline += 'hls_max_concurrency(0)\n'
+                    newline += 'hls_component_ii({})\n'.format(self.get_max_reuse_factor(model))

                 clock_mhz = 1000 / (model.config.get_config_value('ClockPeriod'))
                 newline += 'hls_scheduler_target_fmax_mhz({})\n'.format(np.ceil(clock_mhz).astype(np.int))
@@ -319,8 +320,9 @@ def write_project_header(self, model):
             elif '//hls-fpga-machine-learning insert cpragmas' in line:
                 newline = line
-                newline += 'hls_max_concurrency(0)\n'
-                newline += 'hls_component_ii({})\n'.format(self.get_max_reuse_factor(model))
+                if io_type == 'io_parallel':
+                    newline += 'hls_max_concurrency(0)\n'
+                    newline += 'hls_component_ii({})\n'.format(self.get_max_reuse_factor(model))

                 clock_mhz = 1000 / (model.config.get_config_value('ClockPeriod'))
                 newline += 'hls_scheduler_target_fmax_mhz({})\n'.format(np.ceil(clock_mhz).astype(np.int))
@@ -367,9 +369,14 @@ def write_defines(self, model):
                 all_precision = OrderedDict()
                 for layer in model.get_layers():
                     layer_precision = layer.get_layer_precision()
-                    all_precision.update(layer_precision)
+                    for type_name, type_var in layer_precision.items():
+                        # Ensure that a layer's types don't override existing types
+                        # This can happen in case of InplaceVariable types
+                        if type_name not in all_precision:
+                            all_precision[type_name] = type_var
                 for used_type in all_precision.values():
                     newline += used_type.definition_cpp()
+
             else:
                 newline = line
             fout.write(newline)
diff --git a/test/pytest/test_cnn_mnist.py b/test/pytest/test_cnn_mnist.py
index 16a7da16f7..3a08d7cfed 100644
--- a/test/pytest/test_cnn_mnist.py
+++ b/test/pytest/test_cnn_mnist.py
@@ -47,6 +47,7 @@ def keras_model(mnist_data):
 @pytest.mark.parametrize('backend,io_type,strategy', [
     ('Quartus', 'io_parallel', 'resource'),
+    ('Quartus', 'io_stream', 'resource'),

     ('Vivado', 'io_parallel', 'resource'),
     ('Vivado', 'io_parallel', 'latency'),
diff --git a/test/pytest/test_cnn_mnist_qkeras.py b/test/pytest/test_cnn_mnist_qkeras.py
index 0edfd16f9f..c34e0965a6 100644
--- a/test/pytest/test_cnn_mnist_qkeras.py
+++ b/test/pytest/test_cnn_mnist_qkeras.py
@@ -35,10 +35,10 @@ def mnist_model():
 @pytest.fixture
 @pytest.mark.parametrize('backend,io_type,strategy', [
     ('Quartus', 'io_parallel', 'resource'),
+    ('Quartus', 'io_stream', 'resource'),
+
     ('Vivado', 'io_parallel', 'resource'),
-    ('Vivado', 'io_parallel', 'latency'),
-    ('Vivado', 'io_stream', 'latency'),
     ('Vivado', 'io_stream', 'resource')
 ])
@@ -61,10 +61,10 @@ def hls_model(mnist_model, backend, io_type, strategy):
 @pytest.mark.parametrize('backend,io_type,strategy', [
     ('Quartus', 'io_parallel', 'resource'),
+    ('Quartus', 'io_stream', 'resource'),
+
     ('Vivado', 'io_parallel', 'resource'),
-    ('Vivado', 'io_parallel', 'latency'),
-    ('Vivado', 'io_stream', 'latency'),
     ('Vivado', 'io_stream', 'resource')
 ])
diff --git a/test/pytest/test_conv1d.py b/test/pytest/test_conv1d.py
index bef486cda5..1d91d80ea3 100644
--- a/test/pytest/test_conv1d.py
+++ b/test/pytest/test_conv1d.py
@@ -23,12 +23,12 @@ def keras_model():
     return model

 @pytest.fixture
-@pytest.mark.parametrize('backend, io_type, strategy', [
+@pytest.mark.parametrize('backend,io_type,strategy', [
     ('Quartus', 'io_parallel', 'resource'),
+    ('Quartus', 'io_stream', 'resource'),
+
     ('Vivado', 'io_parallel', 'resource'),
-    ('Vivado', 'io_parallel', 'latency'),
-    ('Vivado', 'io_stream', 'latency'),
     ('Vivado', 'io_stream', 'resource')
 ])
@@ -56,12 +56,12 @@ def hls_model(keras_model, backend, io_type, strategy):
     hls_model.compile()
     return hls_model

-@pytest.mark.parametrize('backend, io_type, strategy', [
+@pytest.mark.parametrize('backend,io_type,strategy', [
     ('Quartus', 'io_parallel', 'resource'),
+    ('Quartus', 'io_stream', 'resource'),
+
     ('Vivado', 'io_parallel', 'resource'),
-    ('Vivado', 'io_parallel', 'latency'),
-    ('Vivado', 'io_stream', 'latency'),
     ('Vivado', 'io_stream', 'resource')
 ])
diff --git a/test/pytest/test_globalpooling.py b/test/pytest/test_globalpooling.py
index 39de850a4d..79260afbdf 100644
--- a/test/pytest/test_globalpooling.py
+++ b/test/pytest/test_globalpooling.py
@@ -29,14 +29,10 @@ def keras_model_avg_1d():
     model.compile()
     return model

-@pytest.mark.parametrize('backend, io_type', [
-    ('Vivado', 'io_parallel'),
-    ('Vivado','io_stream'),
-    # TODO - Quartus Streaming Global Pooling
-    ('Quartus', 'io_parallel'),
-    ])
+@pytest.mark.parametrize('backend', ['Quartus', 'Vivado'])
 @pytest.mark.parametrize('model_type', ['max', 'avg'])
+@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
 def test_global_pool1d(backend, keras_model_max_1d, keras_model_avg_1d, data_1d, model_type, io_type):
     if model_type == 'avg':
         model = keras_model_avg_1d
@@ -74,10 +70,9 @@ def keras_model_avg_2d():
     model.compile()
     return model

-# TODO - Add Streaming 2D Pooling in Vivado & Quartus
 @pytest.mark.parametrize('backend', ['Quartus', 'Vivado'])
 @pytest.mark.parametrize('model_type', ['max', 'avg'])
-@pytest.mark.parametrize('io_type', ['io_parallel'])
+@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
 def test_global_pool2d(backend, keras_model_max_2d, keras_model_avg_2d, data_2d, model_type, io_type):
     if model_type == 'avg':
diff --git a/test/pytest/test_keras_api.py b/test/pytest/test_keras_api.py
index c6816eb4a2..bd3f175b18 100644
--- a/test/pytest/test_keras_api.py
+++ b/test/pytest/test_keras_api.py
@@ -95,7 +95,8 @@ def test_activations(activation_function, backend, io_type):
 padds_options = ['same', 'valid']
 @pytest.mark.parametrize('padds', padds_options)
 @pytest.mark.parametrize('backend', ['Vivado', 'Quartus'])
-def test_conv1d(padds, backend):
+@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
+def test_conv1d(padds, backend, io_type):
     model = tf.keras.models.Sequential()
     input_shape = (10, 128, 4)
     model.add(Conv1D(filters=32,
@@ -114,45 +115,48 @@ def test_conv1d(padds, backend):
     keras_prediction = model.predict(X_input)

     config = hls4ml.utils.config_from_keras_model(model)
-    output_dir = str(test_root_path / 'hls4mlprj_keras_api_conv1d_{}_{}'.format(padds, backend))
-    hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=config, output_dir=output_dir, backend=backend)
+    output_dir = str(test_root_path / 'hls4mlprj_keras_api_conv1d_{}_{}_{}'.format(padds, backend, io_type))
+    hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type)
     hls_model.compile()
     hls_prediction = hls_model.predict(X_input).reshape(keras_prediction.shape)

     # 5e-2 might be too high
     np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=5e-2)

-    assert len(model.layers) + 2 == len(hls_model.get_layers())
-    assert list(hls_model.get_layers())[1].attributes['name'] == model.layers[0]._name
-    assert list(hls_model.get_layers())[1].attributes['class_name'] == 'Conv1D'
-    assert list(hls_model.get_layers())[1].attributes['activation'] == str(model.layers[0].activation).split()[1]
-    assert list(hls_model.get_layers())[1].attributes["in_width"] == model.layers[0]._batch_input_shape[1]
-    assert list(hls_model.get_layers())[1].attributes['filt_width'] == model.layers[0].kernel_size[0]
-    assert list(hls_model.get_layers())[1].attributes['n_chan'] == model.layers[0].input_shape[2]
-    assert list(hls_model.get_layers())[1].attributes['n_filt'] == model.layers[0].filters
-    assert list(hls_model.get_layers())[1].attributes['stride_width'] == model.layers[0].strides[0]
-    assert list(hls_model.get_layers())[1].attributes['padding'] == model.layers[0].padding
-    assert list(hls_model.get_layers())[1].attributes['data_format'] == model.layers[0].data_format
-    assert list(hls_model.get_layers())[1].attributes["out_width"] == list(model.layers[0].output_shape)[1]
-
-    out_width = math.ceil(float(model.layers[0]._batch_input_shape[2]) / float(model.layers[0].strides[0]))
-    pad_along_width = max((out_width - 1) * model.layers[0].strides[0] + model.layers[0].kernel_size[0] - model.layers[0]._batch_input_shape[2], 0)
-    pad_left = pad_along_width // 2
-    pad_right = pad_along_width - pad_left
-
-    if model.layers[0].padding == 'same':
-        assert list(hls_model.get_layers())[1].attributes['pad_left'] == pad_left
-        assert list(hls_model.get_layers())[1].attributes['pad_right'] == pad_right
-    elif model.layers[0].padding == 'valid':
-        assert list(hls_model.get_layers())[1].attributes['pad_left'] == 0
-        assert list(hls_model.get_layers())[1].attributes['pad_right'] == 0
+    if not (backend=='Vivado' and io_type=='io_stream' and padds=='same'):
+        # Vivado inserts an additional layer for 'same' padding in io_stream
+        assert len(model.layers) + 2 == len(hls_model.get_layers())
+        assert list(hls_model.get_layers())[1].attributes['name'] == model.layers[0]._name
+        assert list(hls_model.get_layers())[1].attributes['class_name'] == 'Conv1D'
+        assert list(hls_model.get_layers())[1].attributes['activation'] == str(model.layers[0].activation).split()[1]
+        assert list(hls_model.get_layers())[1].attributes["in_width"] == model.layers[0]._batch_input_shape[1]
+        assert list(hls_model.get_layers())[1].attributes['filt_width'] == model.layers[0].kernel_size[0]
+        assert list(hls_model.get_layers())[1].attributes['n_chan'] == model.layers[0].input_shape[2]
+        assert list(hls_model.get_layers())[1].attributes['n_filt'] == model.layers[0].filters
+        assert list(hls_model.get_layers())[1].attributes['stride_width'] == model.layers[0].strides[0]
+        assert list(hls_model.get_layers())[1].attributes['padding'] == model.layers[0].padding
+        assert list(hls_model.get_layers())[1].attributes['data_format'] == model.layers[0].data_format
+        assert list(hls_model.get_layers())[1].attributes["out_width"] == list(model.layers[0].output_shape)[1]
+
+        out_width = math.ceil(float(model.layers[0]._batch_input_shape[2]) / float(model.layers[0].strides[0]))
+        pad_along_width = max((out_width - 1) * model.layers[0].strides[0] + model.layers[0].kernel_size[0] - model.layers[0]._batch_input_shape[2], 0)
+        pad_left = pad_along_width // 2
+        pad_right = pad_along_width - pad_left
+
+        if model.layers[0].padding == 'same':
+            assert list(hls_model.get_layers())[1].attributes['pad_left'] == pad_left
+            assert list(hls_model.get_layers())[1].attributes['pad_right'] == pad_right
+        elif model.layers[0].padding == 'valid':
+            assert list(hls_model.get_layers())[1].attributes['pad_left'] == 0
+            assert list(hls_model.get_layers())[1].attributes['pad_right'] == 0

 chans_options=['channels_last']
 padds_options=['same', 'valid']
 @pytest.mark.parametrize('chans', chans_options)
 @pytest.mark.parametrize('padds', padds_options)
 @pytest.mark.parametrize('backend', ['Vivado', 'Quartus'])
-def test_conv2d(chans, padds, backend):
+@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
+def test_conv2d(chans, padds, backend, io_type):
     model = tf.keras.models.Sequential()
     input_shape = (28,28,3)
     model.add(Conv2D(filters=32,
@@ -169,8 +173,8 @@ def test_conv2d(chans, padds, backend):
     keras_prediction = model.predict(X_input)

     config = hls4ml.utils.config_from_keras_model(model)
-    output_dir = str(test_root_path / 'hls4mlprj_keras_api_conv2d_{}_{}_{}'.format(backend, chans, padds))
-    hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=config, output_dir=output_dir, backend=backend)
+    output_dir = str(test_root_path / 'hls4mlprj_keras_api_conv2d_{}_{}_{}_{}'.format(backend, chans, padds, io_type))
+    hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type)
     hls_model.compile()
     hls_prediction = hls_model.predict(X_input).reshape(keras_prediction.shape)
diff --git a/test/pytest/test_transpose_concat.py b/test/pytest/test_transpose_concat.py
index d9f6e8217c..70e2614885 100644
--- a/test/pytest/test_transpose_concat.py
+++ b/test/pytest/test_transpose_concat.py
@@ -1,9 +1,8 @@
 import pytest
 import hls4ml
 import numpy as np
-from tensorflow.keras.models import model_from_json, Model
+from tensorflow.keras.models import Model
 from tensorflow.keras.layers import Input, Permute, Concatenate, Activation
-import yaml

 @pytest.fixture(scope='module')
 def data():
@@ -21,9 +20,9 @@ def keras_model():
     return model

 @pytest.fixture
-@pytest.mark.parametrize('io_type', ['io_parallel',
-                                     'io_stream'])
-def hls_model(keras_model, io_type):
+@pytest.mark.parametrize('io_type', ['io_stream', 'io_parallel'])
+@pytest.mark.parametrize('backend', ['Vivado', 'Quartus'])
+def hls_model(keras_model, backend, io_type):
     hls_config = hls4ml.utils.config_from_keras_model(keras_model,
                                                       default_precision='ap_fixed<16,3,AP_RND_CONV,AP_SAT>',
                                                       granularity='name')
@@ -31,14 +30,14 @@ def hls_model(keras_model, io_type):
     hls_model = hls4ml.converters.convert_from_keras_model(keras_model,
                                                            hls_config=hls_config,
                                                            io_type=io_type,
-                                                           output_dir='hls4mlprj_transpose_{}'.format(io_type))
+                                                           backend=backend,
+                                                           output_dir='hls4mlprj_transpose_{}_{}'.format(backend, io_type))
     hls_model.compile()
     return hls_model

-# TODO - Add Quartus test after merging PR #634
-@pytest.mark.parametrize('io_type', ['io_parallel',
-                                     'io_stream'])
+@pytest.mark.parametrize('io_type', ['io_stream', 'io_parallel'])
+@pytest.mark.parametrize('backend', ['Vivado', 'Quartus'])
 def test_accuracy(data, keras_model, hls_model):
     X = data
     model = keras_model
diff --git a/test/pytest/test_upsampling.py b/test/pytest/test_upsampling.py
index 764fe311d4..7e698fd907 100644
--- a/test/pytest/test_upsampling.py
+++ b/test/pytest/test_upsampling.py
@@ -40,12 +40,8 @@ def keras_model_2d():
     return model

-@pytest.mark.parametrize('backend, io_type', [
-    ('Quartus', 'io_parallel'),
-
-    ('Vivado', 'io_parallel'),
-    ('Vivado','io_stream')
-    ])
+@pytest.mark.parametrize('io_type', ['io_stream', 'io_parallel'])
+@pytest.mark.parametrize('backend', ['Vivado', 'Quartus'])
 @pytest.mark.parametrize('model_type', ['1d', '2d'])
 def test_upsampling(keras_model_1d, keras_model_2d, data_1d, data_2d, model_type, io_type, backend):
     if model_type == '1d':
diff --git a/test/pytest/test_zeropadding.py b/test/pytest/test_zeropadding.py
index 841dc7ebd8..219f727c06 100644
--- a/test/pytest/test_zeropadding.py
+++ b/test/pytest/test_zeropadding.py
@@ -44,12 +44,8 @@ def keras_model_2d():
     return model

-@pytest.mark.parametrize('backend, io_type', [
-    ('Quartus', 'io_parallel'),
-
-    ('Vivado', 'io_parallel'),
-    ('Vivado','io_stream')
-    ])
+@pytest.mark.parametrize('io_type', ['io_stream', 'io_parallel'])
+@pytest.mark.parametrize('backend', ['Vivado', 'Quartus'])
 @pytest.mark.parametrize('model_type', ['1d', '2d'])
 def test_zeropadding(keras_model_1d, keras_model_2d, data_1d, data_2d, model_type, io_type, backend):
     if model_type == '1d':