Vivado Backend GRU/LSTM support #560
Merged
11 commits:
e780073  working version of gru/lstm for vivado backend (dsrankin)
c35b4cc  removing unused mask param (dsrankin)
3143171  fixing transpose issues (dsrankin)
f183ba9  hls cleanup (dsrankin)
08e3b69  hls cleanup (dsrankin)
367d1f0  moving static to extended attributes in backend (dsrankin)
03823c4  changing to return just vector in case return_sequences=False (dsrankin)
cb76f4e  cleaning up transpose for resource strategy (dsrankin)
0a5d209  combining templates (dsrankin)
5bddddf  adding comment (dsrankin)
91ed034  fixing SimpleRNN shape issue (dsrankin)
171 changes: 171 additions & 0 deletions
hls4ml/backends/vivado/passes/recurrent_templates.py
from hls4ml.backends.backend import get_backend
from hls4ml.model.layers import LSTM, GRU
from hls4ml.backends.template import LayerConfigTemplate, FunctionCallTemplate

# recurrent multiplication template

recr_mult_config_template = """struct config{index} : nnet::dense_config {{
    static const unsigned n_in = {n_in};
    static const unsigned n_out = {n_out};
    static const unsigned strategy = nnet::{strategy};
    static const unsigned reuse_factor = {reuse};
    static const unsigned n_zeros = {nzeros};
    static const unsigned n_nonzeros = {nonzeros};
    static const bool store_weights_in_bram = false;
    typedef {accum_t.name} accum_t;
    typedef {bias_t.name} bias_t;
    typedef {weight_t.name} weight_t;
    typedef ap_{index_t} index_t;
    template<class x_T, class y_T, class res_T>
    using product = nnet::product::{product_type}<x_T, y_T, res_T>;
}};\n"""

# activation templates

activ_config_template = """struct {type}_config{index} : nnet::activ_config {{
    static const unsigned n_in = {n_in};
    static const unsigned table_size = {table_size};
    static const unsigned io_type = nnet::{iotype};
    static const unsigned reuse_factor = {reuse};
    typedef ap_{table_t} table_t;
}};\n"""

recr_activ_config_template = """struct {type}_config{index}_recr : nnet::activ_config {{
    static const unsigned n_in = {n_in};
    static const unsigned table_size = {table_size};
    static const unsigned io_type = nnet::{iotype};
    static const unsigned reuse_factor = {reuse};
    typedef ap_{table_t} table_t;
}};\n"""

# LSTM + GRU templates

recr_config_template = """struct config{index} : nnet::{recr_type}_config {{
    typedef {accum_t.name} accum_t;
    typedef {weight_t.name} weight_t;  // Matrix
    typedef {bias_t.name} bias_t;      // Vector
    typedef {config_mult_t1} mult_config1;
    typedef {config_mult_t2} mult_config2;
    typedef {recr_act_t} ACT_CONFIG_{RECR_TYPE};
    template<class x_T, class y_T, class config_T>
    using activation_recr = nnet::activation::{recurrent_activation}<x_T, y_T, config_T>;
    typedef {act_t} ACT_CONFIG_T;
    template<class x_T, class y_T, class config_T>
    using activation = nnet::activation::{activation}<x_T, y_T, config_T>;
    static const unsigned n_in = {n_in};
    static const unsigned n_out = {n_out};
    static const unsigned n_state = {n_state};
    static const unsigned n_sequence = {n_sequence};
    static const unsigned n_sequence_out = {n_sequence_out};
    static const unsigned io_type = nnet::{strategy};
    static const unsigned reuse_factor = {reuse};
    static const bool store_weights_in_bram = false;
    static const bool use_static = {static};
}};\n"""

recr_function_template = 'nnet::{recr_type}_stack<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {wr}, {b}, {br});'

recr_include_list = ['nnet_utils/nnet_recurrent.h']


class RecurrentConfigTemplate(LayerConfigTemplate):
    def __init__(self):
        super().__init__((LSTM, GRU))
        self.template = recr_config_template
        self.act_template = activ_config_template
        self.recr_act_template = recr_activ_config_template
        self.mult1_template = recr_mult_config_template
        self.mult2_template = recr_mult_config_template

    def format(self, node):
        params = self._default_config_params(node)

        params['n_in'] = node.get_input_variable().dim_names[1]
        params['n_sequence'] = node.get_input_variable().dim_names[0]
        if node.get_attr('return_sequences'):
            params['n_sequence_out'] = node.get_output_variable().dim_names[0]
            params['n_state'] = node.get_output_variable().dim_names[1]
            params['n_out'] = node.get_output_variable().dim_names[1]
        else:
            params['n_sequence_out'] = 1
            params['n_state'] = node.get_output_variable().dim_names[0]
            params['n_out'] = node.get_output_variable().dim_names[0]
        params['config_mult_t1'] = 'config{}_1'.format(node.index)
        params['config_mult_t2'] = 'config{}_2'.format(node.index)
        params['recr_act_t'] = '{}_config{}_recr'.format(node.get_attr('recurrent_activation'), node.index)
        params['act_t'] = '{}_config{}'.format(node.get_attr('activation'), node.index)
        params['strategy'] = node.get_attr('strategy')
        params['static'] = 'true' if node.attributes['static'] else 'false'
        params['recr_type'] = node.class_name.lower()
        params['RECR_TYPE'] = node.class_name

        if node.class_name == 'LSTM':
            n_recr_mult = 4
        else:  # GRU
            n_recr_mult = 3

        recr_config = self.template.format(**params)

        act_params = self._default_config_params(node)
        recr_act_params = self._default_config_params(node)

        act_params['type'] = node.get_attr('activation')
        recr_act_params['type'] = node.get_attr('recurrent_activation')
        if node.get_attr('return_sequences'):
            act_params['n_in'] = node.get_output_variable().dim_names[1]
            recr_act_params['n_in'] = node.get_output_variable().dim_names[1] + ' * %i' % (n_recr_mult - 1)
        else:
            act_params['n_in'] = node.get_output_variable().dim_names[0]
            recr_act_params['n_in'] = node.get_output_variable().dim_names[0] + ' * %i' % (n_recr_mult - 1)

        act_config = self.act_template.format(**act_params)
        recr_act_config = self.recr_act_template.format(**recr_act_params)

        mult_params1 = self._default_config_params(node)
        mult_params2 = self._default_config_params(node)

        mult_params1['n_in'] = node.get_input_variable().dim_names[1]
        if node.get_attr('return_sequences'):
            mult_params1['n_out'] = node.get_output_variable().dim_names[1] + ' * %i' % n_recr_mult
        else:
            mult_params1['n_out'] = node.get_output_variable().dim_names[0] + ' * %i' % n_recr_mult
        mult_params1['product_type'] = get_backend('vivado').product_type(
            node.get_input_variable().type.precision, node.get_weights('weight').type.precision)
        mult_params1['reuse'] = params['reuse']
        mult_params1['index'] = str(node.index) + '_1'
        mult_params1['nzeros'] = node.get_weights('weight').nzeros
        mult_params1['nonzeros'] = node.get_weights('weight').nonzeros
        if node.get_attr('return_sequences'):
            mult_params2['n_in'] = node.get_output_variable().dim_names[1]
            mult_params2['n_out'] = node.get_output_variable().dim_names[1] + ' * %i' % n_recr_mult
        else:
            mult_params2['n_in'] = node.get_output_variable().dim_names[0]
            mult_params2['n_out'] = node.get_output_variable().dim_names[0] + ' * %i' % n_recr_mult
        mult_params2['product_type'] = get_backend('vivado').product_type(
            node.get_input_variable().type.precision, node.get_weights('recurrent_weight').type.precision)
        mult_params2['reuse'] = node.attributes['recurrent_reuse_factor']
        mult_params2['index'] = str(node.index) + '_2'
        mult_params2['nzeros'] = node.get_weights('recurrent_weight').nzeros
        mult_params2['nonzeros'] = node.get_weights('recurrent_weight').nonzeros

        mult_config1 = self.mult1_template.format(**mult_params1)
        mult_config2 = self.mult2_template.format(**mult_params2)

        return mult_config1 + '\n' + mult_config2 + '\n' + recr_act_config + '\n' + act_config + '\n' + recr_config


class RecurrentFunctionTemplate(FunctionCallTemplate):
    def __init__(self):
        super().__init__((LSTM, GRU), include_header=recr_include_list)
        self.template = recr_function_template

    def format(self, node):
        params = self._default_function_params(node)
        params['w'] = node.get_weights('weight').name
        params['b'] = node.get_weights('bias').name
        params['wr'] = node.get_weights('recurrent_weight').name
        params['br'] = node.get_weights('recurrent_bias').name
        params['activation'] = node.get_attr('activation')
        params['recurrent_activation'] = node.get_attr('recurrent_activation')
        params['recr_type'] = node.class_name.lower()

        return self.template.format(**params)
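To make the templates above concrete, here is a sketch of the HLS configuration they would emit for a hypothetical LSTM layer (index 2, 10 input features, 20 hidden units, sequence length 8, return_sequences=False). The fixed-point widths, strategy, and variable names are illustrative assumptions; the real generator substitutes model-specific dimension macros and precisions rather than literal values.

// Illustrative output of recr_config_template for an assumed LSTM layer;
// type widths, sizes, and strategy are placeholders, not from a real model.
struct config2 : nnet::lstm_config {
    typedef ap_fixed<16,6> accum_t;
    typedef ap_fixed<16,6> weight_t;  // Matrix
    typedef ap_fixed<16,6> bias_t;    // Vector
    typedef config2_1 mult_config1;   // input-to-gates dense config
    typedef config2_2 mult_config2;   // recurrent dense config
    typedef sigmoid_config2_recr ACT_CONFIG_LSTM;
    template<class x_T, class y_T, class config_T>
    using activation_recr = nnet::activation::sigmoid<x_T, y_T, config_T>;
    typedef tanh_config2 ACT_CONFIG_T;
    template<class x_T, class y_T, class config_T>
    using activation = nnet::activation::tanh<x_T, y_T, config_T>;
    static const unsigned n_in = 10;
    static const unsigned n_out = 20;
    static const unsigned n_state = 20;
    static const unsigned n_sequence = 8;
    static const unsigned n_sequence_out = 1;  // return_sequences=False
    static const unsigned io_type = nnet::latency;
    static const unsigned reuse_factor = 1;
    static const bool store_weights_in_bram = false;
    static const bool use_static = true;
};

// ...and the matching call emitted by recr_function_template
// (input/output/weight names here are assumed for illustration):
// nnet::lstm_stack<input_t, result_t, config2>(input2, output2, w2, wr2, b2, br2);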
63 changes: 63 additions & 0 deletions
hls4ml/templates/vivado/nnet_utils/nnet_recr_activations.h
#ifndef NNET_RECR_ACTIVATION_H_
#define NNET_RECR_ACTIVATION_H_

#include "nnet_common.h"
#include "nnet_helpers.h"
#include "nnet_activation.h"
#include "hls_stream.h"
#include <math.h>

namespace nnet {

namespace activation {

template<class data_T, class res_T, typename CONFIG_T>
class Activation {
  public:
    // *************************************************
    // Blank Activation
    // *************************************************
    static void activation(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) {} // Nothing to do here
};

template<class data_T, class res_T, typename CONFIG_T>
class relu : public Activation<data_T, res_T, CONFIG_T> {
  public:
    // *************************************************
    // Relu Activation
    // *************************************************
    static void activation(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])
    {
        nnet::relu<data_T, res_T, CONFIG_T>(data, res);
    }
};

template<class data_T, class res_T, typename CONFIG_T>
class sigmoid : public Activation<data_T, res_T, CONFIG_T> {
  public:
    // *************************************************
    // Sigmoid Activation
    // *************************************************
    static void activation(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])
    {
        nnet::sigmoid<data_T, res_T, CONFIG_T>(data, res);
    }
};

template<class data_T, class res_T, typename CONFIG_T>
class tanh : public Activation<data_T, res_T, CONFIG_T> {
  public:
    // *************************************************
    // TanH Activation
    // *************************************************
    static void activation(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])
    {
        nnet::tanh<data_T, res_T, CONFIG_T>(data, res);
    }
};

} // namespace activation

} // namespace nnet

#endif
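These wrapper classes let the generated config structs select an activation as a compile-time template alias instead of a runtime branch. Below is a minimal sketch of that dispatch pattern, assuming a generated config like config2 above; apply_gate_activation is a hypothetical helper written for illustration, not code from this PR's nnet_recurrent.h.

#include "nnet_recr_activations.h"

// Hypothetical helper: applies the activation selected by a generated
// config struct. CONFIG_T::activation resolves at compile time to e.g.
// nnet::activation::tanh; the third template argument supplies the
// activ_config (n_in, table_size, ...) that sizes the arrays.
template<class data_T, class res_T, typename CONFIG_T>
void apply_gate_activation(data_T gates_in[CONFIG_T::ACT_CONFIG_T::n_in],
                           res_T  gates_out[CONFIG_T::ACT_CONFIG_T::n_in])
{
    CONFIG_T::template activation<data_T, res_T, typename CONFIG_T::ACT_CONFIG_T>::activation(gates_in, gates_out);
}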