Add RNN support for Pytorch #850


Merged · 21 commits · Jul 23, 2024
11 changes: 10 additions & 1 deletion hls4ml/backends/quartus/passes/recurrent_templates.py
@@ -66,6 +66,7 @@
using activation_recr = nnet::activation::{recurrent_activation}<x_T, y_T, config_T>;

static const unsigned reuse_factor = {reuse};
static const unsigned pytorch_order = {pytorch};
static const bool store_weights_in_bram = false;
}};\n'''

@@ -92,6 +93,7 @@ def format(self, node):
params['config_mult_h'] = f'config{node.index}_h_mult'
params['act_t'] = '{}_config{}'.format(node.get_attr('activation'), str(node.index) + '_act')
params['act_recurrent_t'] = '{}_config{}'.format(node.get_attr('recurrent_activation'), str(node.index) + '_rec_act')
params['pytorch'] = 'true' if node.get_attr('pytorch', False) else 'false'
gru_config = self.gru_template.format(**params)

# Activation is on candidate hidden state, dimensionality (1, n_units)
@@ -256,6 +258,9 @@ def format(self, node):
}};\n"""

simple_rnn_function_template = 'nnet::simple_rnn<{input_t}, {output_t}, {config}>({input}, {output}, {weights});'
simple_rnn_pytorch_function_template = (
'nnet::simple_rnn_pytorch<{input_t}, {output_t}, {config}>({input}, {output}, {weights});'
)


class SimpleRNNConfigTemplate(LayerConfigTemplate):
@@ -301,5 +306,9 @@ def __init__(self):

def format(self, node):
params = self._default_function_params(node)
params['weights'] = 'w{0}, wr{0}, b{0}'.format(str(node.index))
if node.get_attr('pytorch', False):
self.template = simple_rnn_pytorch_function_template
params['weights'] = 'w{0}, wr{0}, b{0}, br{0}'.format(str(node.index))
else:
params['weights'] = 'w{0}, wr{0}, b{0}'.format(str(node.index))
return self.template.format(**params)
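
For illustration, a rough sketch of what the pytorch branch above produces for a hypothetical layer with node.index == 2 (all type and variable names here are placeholders, not values hls4ml is guaranteed to emit):

params = {
    'input_t': 'input_t',
    'output_t': 'result_t',
    'config': 'config2',
    'input': 'layer1_out',
    'output': 'layer2_out',
    'weights': 'w2, wr2, b2, br2',  # note the extra recurrent bias br2
}
print(simple_rnn_pytorch_function_template.format(**params))
# nnet::simple_rnn_pytorch<input_t, result_t, config2>(layer1_out, layer2_out, w2, wr2, b2, br2);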
2 changes: 2 additions & 0 deletions hls4ml/backends/vivado/passes/recurrent_templates.py
@@ -62,6 +62,7 @@
static const unsigned reuse_factor = {reuse};
static const bool store_weights_in_bram = false;
static const bool use_static = {static};
static const bool pytorch_order = {pytorch};
}};\n"""

recr_function_template = 'nnet::{recr_type}_stack<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {wr}, {b}, {br});'
@@ -97,6 +98,7 @@ def format(self, node):
params['act_t'] = '{}_config{}'.format(node.get_attr('activation'), node.index)
params['strategy'] = node.get_attr('strategy')
params['static'] = 'true' if node.attributes['static'] else 'false'
params['pytorch'] = 'true' if node.get_attr('pytorch', False) else 'false'
params['recr_type'] = node.class_name.lower()
params['RECR_TYPE'] = node.class_name

74 changes: 74 additions & 0 deletions hls4ml/converters/pytorch/recurrent.py
@@ -0,0 +1,74 @@
import warnings

import numpy as np

from hls4ml.converters.pytorch_to_hls import pytorch_handler

rnn_layers = ['RNN', 'LSTM', 'GRU']


@pytorch_handler(*rnn_layers)
def parse_rnn_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
assert operation in rnn_layers

layer = {}

layer["name"] = layer_name

layer['inputs'] = [input_names[0]]
if len(input_names) > 1:
warnings.warn(
'hls4ml disregards the initial value of the hidden state passed to the model, assuming that it is all zeros',
stacklevel=2,
)
layer['class_name'] = operation
if operation == "RNN":
layer['class_name'] = 'SimpleRNN'

layer['return_sequences'] = False # parameter does not exist in pytorch
layer['return_state'] = False # parameter does not exist in pytorch

if layer['class_name'] == 'SimpleRNN':
layer['activation'] = class_object.nonlinearity # Default is tanh, can also be ReLU in pytorch
else:
layer['activation'] = "tanh" # GRU and LSTM are hard-coded to use tanh in pytorch

if layer['class_name'] == 'GRU' or layer['class_name'] == 'LSTM':
layer['recurrent_activation'] = 'sigmoid' # GRU and LSTM are hard-coded to use sigmoid in pytorch

layer['time_major'] = not class_object.batch_first
# TODO Should we handle time_major?
if layer['time_major']:
raise Exception('hls4ml only supports "batch-first == True"')

layer['n_timesteps'] = input_shapes[0][1]
layer['n_in'] = input_shapes[0][2]

layer['n_out'] = class_object.hidden_size

if class_object.num_layers > 1:
raise Exception('hls4ml does not support num_layers > 1')

if class_object.bidirectional:
raise Exception('hls4ml does not support bidirectional RNNs')

if class_object.dropout > 0:
raise Exception('hls4ml does not support RNNs with dropout')

layer['weight_data'] = class_object.weight_ih_l0.data.numpy()
layer['recurrent_weight_data'] = class_object.weight_hh_l0.data.numpy()
layer['bias_data'] = class_object.bias_ih_l0.data.numpy()
layer['recurrent_bias_data'] = class_object.bias_hh_l0.data.numpy()

if class_object.bias is False:
layer['bias_data'] = np.zeros(layer['weight_data'].shape[0])
layer['recurrent_bias_data'] = np.zeros(layer['recurrent_weight_data'].shape[0])

if layer['class_name'] == 'GRU':
layer['apply_reset_gate'] = 'after' # Might be true for pytorch? It's not a free parameter

output_shape = [input_shapes[0][0], layer['n_out']]

layer['pytorch'] = True # need to switch some behaviors to match pytorch implementations

return layer, output_shape
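
As a quick reference for what the handler reads off the torch module, a small PyTorch-only sketch (sizes arbitrary) of the attributes and shapes involved for a single-layer, unidirectional GRU:

import torch

gru = torch.nn.GRU(input_size=4, hidden_size=8, num_layers=1, batch_first=True, bias=True)
print(gru.weight_ih_l0.shape)  # torch.Size([24, 4])  -> (3 * hidden_size, input_size)
print(gru.weight_hh_l0.shape)  # torch.Size([24, 8])  -> (3 * hidden_size, hidden_size)
print(gru.bias_ih_l0.shape)    # torch.Size([24])
print(gru.bias_hh_l0.shape)    # torch.Size([24])
# nn.LSTM stacks 4 gates (4 * hidden_size rows), nn.RNN has a single block and
# additionally exposes .nonlinearity ('tanh' or 'relu'); with bias=False the
# handler above falls back to zero bias vectors.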
19 changes: 17 additions & 2 deletions hls4ml/converters/pytorch_to_hls.py
@@ -199,8 +199,21 @@ def pytorch_to_hls(config):

# parse info from class object
input_names = [inputs_map.get(str(i), str(i)) for i in node.args]
input_shapes = [output_shapes[str(i)] for i in node.args]

if pytorch_class in ["RNN", "GRU", "LSTM"]:
# we currently don't support the passing of the initial value of the hidden state to RNN models
input_names = [inputs_map.get(str(node.args[0]), str(node.args[0]))]
input_shapes = [output_shapes[str(node.args[0])]]
# if a 'getitem' is the input to a node, step back in the graph to find the real source of the input
elif "getitem" in node.args[0].name:
for tmp_node in traced_model.graph.nodes:
if tmp_node.name == node.args[0].name:
if "getitem" in tmp_node.args[0].name:
raise Exception('Nested getitem calls are not resolved at the moment.')
input_names = [inputs_map.get(str(tmp_node.args[0]), str(tmp_node.args[0]))]
input_shapes = [output_shapes[str(tmp_node.args[0])]]
node.args = [tmp_node.args[0]]
else:
input_shapes = [output_shapes[str(i)] for i in node.args]
# for Conv layers
if 'Conv' in pytorch_class:
if not class_object.padding_mode == 'zeros':
@@ -254,6 +267,8 @@ def pytorch_to_hls(config):
operation = layer_name_map[operation]

# only a limited number of functions are supported
if operation == "getitem":
continue
if operation not in supported_layers:
raise Exception(f'Unsupported function {operation}')
if operation == 'PReLU' or operation == 'batch_norm' or operation == 'conv1d' or operation == 'conv2d':
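
Some background on the getitem handling above: nn.RNN, nn.GRU and nn.LSTM return a tuple (output, hidden state), so when torch.fx traces a model the tuple unpacking shows up as operator.getitem nodes between the recurrent module and its consumers. hls4ml drives its own FX trace, but plain symbolic_trace is enough to see the pattern; a minimal sketch (module names are illustrative):

import torch

class TinyRNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = torch.nn.GRU(input_size=4, hidden_size=8, batch_first=True)
        self.fc = torch.nn.Linear(8, 2)

    def forward(self, x):
        output, h_n = self.rnn(x)  # unpacking becomes getitem nodes in the graph
        return self.fc(output)     # the Linear consumes a getitem node, not the GRU node

print(torch.fx.symbolic_trace(TinyRNN()).graph)
# roughly: call_module rnn -> call_function getitem -> call_module fc

The parsing loop therefore skips getitem operations themselves and, when a getitem feeds another node, steps back through the graph to the real source of the data.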
9 changes: 7 additions & 2 deletions hls4ml/model/layers.py
@@ -1042,6 +1042,8 @@ def initialize(self):

# biases
self.add_weights_variable(name='bias', var_name='b{index}')
if "pytorch" in self.attributes.keys():
self.add_weights_variable(name='recurrent_bias', var_name='br{index}')


class LSTM(Layer):
@@ -1093,8 +1095,11 @@ def initialize(self):
# biases
self.add_weights_variable(name='bias', var_name='b{index}')

recurrent_bias = np.zeros(recurrent_weight.shape[1])
self.add_weights_variable(name='recurrent_bias', var_name='br{index}', data=recurrent_bias)
if "pytorch" in self.attributes.keys():
self.add_weights_variable(name='recurrent_bias', var_name='br{index}')
else:
recurrent_bias = np.zeros(recurrent_weight.shape[1])
self.add_weights_variable(name='recurrent_bias', var_name='br{index}', data=recurrent_bias)


class GRU(Layer):
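
A brief rationale for the pytorch branch above: Keras recurrent layers carry a single bias vector, whereas PyTorch stores two (bias_ih_l0 and bias_hh_l0). For SimpleRNN and LSTM the two enter the recurrence additively, so they could in principle be folded offline, but keeping a real recurrent_bias maps the stored PyTorch parameters one-to-one (and for GRU part of the hidden-side bias sits inside the reset-gated term, where folding is not possible). A conceptual sketch of the fold that would be equivalent for SimpleRNN/LSTM only (hypothetical arrays):

import numpy as np

bias_ih = np.random.rand(8)  # PyTorch bias_ih_l0
bias_hh = np.random.rand(8)  # PyTorch bias_hh_l0

# h_t = act(W_ih x_t + b_ih + W_hh h_(t-1) + b_hh), so a single combined bias
# would be mathematically equivalent here:
combined = bias_ih + bias_hh
# hls4ml keeps both instead and emits the second one as 'recurrent_bias' (br{index}).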
6 changes: 3 additions & 3 deletions hls4ml/model/optimizer/passes/convert_to_channels_last.py
@@ -17,14 +17,14 @@ def match(self, node):

def transform(self, model, node):
# If this parameter has not been set, this model does not need to be converted
if 'InputsChannelLast' not in model.config.config['HLSConfig']['Model']:
if 'ChannelsLastConversion' not in model.config.config['HLSConfig']['Model']:
node.channels_last_converted = True
return False
outshape = node.get_output_variable().shape

if isinstance(node, Input):
# if inputs are not yet transposed into channels_last, add transpose layer
if not model.config.config['HLSConfig']['Model']['InputsChannelLast'] and len(outshape) > 1:
if model.config.config['HLSConfig']['Model']['ChannelsLastConversion'] == "full" and len(outshape) > 1:
# Add transpose for input layer
input = node.name
if len(outshape) == 2:
@@ -39,7 +39,7 @@ def transform(self, model, node):
transpose_node.channels_last_converted = True

model.insert_node(transpose_node)
else:
elif model.config.config['HLSConfig']['Model']['ChannelsLastConversion'] == "internal" and len(outshape) > 1:
input_shape = node.get_output_variable().shape
input_shape.append(input_shape.pop(0))
node.get_output_variable().shape = input_shape
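
A note on the configuration entry this pass now reads: the old 'InputsChannelLast' flag was a boolean, while 'ChannelsLastConversion' selects between a full conversion (insert a transpose of the inputs into channels_last) and an internal one (only reinterpret the input shape). A minimal sketch of the dict the pass expects, with standard hls4ml model-level keys for context (any value other than 'full' or 'internal' leaves the graph untouched in this pass; which value a user would normally set to disable it is an assumption here):

hls_config = {
    'Model': {
        'Precision': 'fixed<16,6>',
        'ReuseFactor': 1,
        'ChannelsLastConversion': 'internal',  # or 'full'
    }
}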
137 changes: 135 additions & 2 deletions hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent.h
@@ -87,6 +87,7 @@ struct gru_config {
// Resource reuse info
static const unsigned io_type = io_parallel;
static const unsigned reuse_factor = 1;
static const bool pytorch_order = false;
static const bool store_weights_in_bram = false;

// Activation
@@ -133,7 +134,10 @@ void gru_cell(data_T x[CONFIG_T::n_in], res_T h[CONFIG_T::n_units],
hls_register typename CONFIG_T::accum_t hadamard_r_h[CONFIG_T::n_units];
#pragma unroll recurrent_unroll_factor
for (int i = 0; i < (CONFIG_T::n_units); i++) {
hadamard_r_h[i] = z_r_act[i + CONFIG_T::n_units] * mat_mul_h_wr[i + 2 * CONFIG_T::n_units];
if (CONFIG_T::pytorch_order)
hadamard_r_h[i] = z_r_act[i] * mat_mul_h_wr[i + 2 * CONFIG_T::n_units];
else
hadamard_r_h[i] = z_r_act[i + CONFIG_T::n_units] * mat_mul_h_wr[i + 2 * CONFIG_T::n_units];
}

// The candidate state; X * W_{hx} + hadmard(r(t), h_(t-1)) * W_{hh} + b_{h}
@@ -152,7 +156,11 @@
// Update state
#pragma unroll recurrent_unroll_factor
for (int i = 0; i < (CONFIG_T::n_units); i++) {
h[i] = static_cast<res_T>(h_cand_act[i] * (1 - z_r_act[i]) + h[i] * z_r_act[i]);
if (CONFIG_T::pytorch_order)
h[i] = static_cast<res_T>(h_cand_act[i] * (1 - z_r_act[i + CONFIG_T::n_units]) +
h[i] * z_r_act[i + CONFIG_T::n_units]);
else
h[i] = static_cast<res_T>(h_cand_act[i] * (1 - z_r_act[i]) + h[i] * z_r_act[i]);
}
}
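
A short note on the pytorch_order branches above: Keras stacks the GRU gates as (update z, reset r, candidate h), while PyTorch stacks them as (reset r, update z, candidate n). With pytorch_order set, the kernel therefore reads the reset activation at offset 0 and the update activation at offset n_units of z_r_act, the opposite of the Keras layout. The PyTorch-side layout, for reference (sizes arbitrary):

import torch

gru = torch.nn.GRU(input_size=4, hidden_size=8, batch_first=True)
# weight_ih_l0 is stacked as (W_ir | W_iz | W_in) in PyTorch:
W_ir, W_iz, W_in = gru.weight_ih_l0.chunk(3, dim=0)
print(W_ir.shape, W_iz.shape, W_in.shape)  # three blocks of shape (8, 4)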

@@ -315,6 +323,131 @@ void simple_rnn(data_T data[CONFIG_T::n_timesteps * CONFIG_T::n_in], res_T res[C
}
}
}
//----------------------
// SimpleRNN with pytorch biases
//----------------------

struct simpleRNN_pytorch_config {
// Internal data type definitions
typedef float weight_t;
typedef float bias_t;
typedef float accum_t;

// Layer Sizes
static const unsigned n_in = 1;
static const unsigned n_out = 1;
static const unsigned n_outputs = 1;
static const unsigned n_timesteps = 1;
static const bool return_sequences = false;

// Resource reuse info
static const unsigned io_type = io_parallel;
static const unsigned reuse_factor = 1;
static const bool store_weights_in_bram = false;

// Activation
template <class x_T, class y_T, class config_T> using activation_recr = nnet::activation::relu<x_T, y_T, config_T>;

template <class x_T, class y_T, class config_T> using activation = nnet::activation::relu<x_T, y_T, config_T>;
};

template <class data_T, class res_T, typename CONFIG_T>
void simple_rnn_pytorch_cell(data_T inputs[CONFIG_T::n_in], res_T hidden_state[CONFIG_T::n_out],
res_T hidden_state_o[CONFIG_T::n_out],
const typename CONFIG_T::weight_t kernel[CONFIG_T::n_in * CONFIG_T::n_out],
const typename CONFIG_T::weight_t rec_kernel[CONFIG_T::n_out * CONFIG_T::n_out],
const typename CONFIG_T::bias_t bias[CONFIG_T::n_out],
const typename CONFIG_T::bias_t rec_bias[CONFIG_T::n_out]) {
// Weight multiplication
typename CONFIG_T::accum_t afterW[CONFIG_T::n_out] hls_register;
multiply_W<data_T, typename CONFIG_T::accum_t, typename CONFIG_T::weight_t, CONFIG_T::n_in, CONFIG_T::n_out>(
inputs, afterW, kernel);

// Bias addition
typename CONFIG_T::accum_t afterBias[CONFIG_T::n_out] hls_register;
add_bias<typename CONFIG_T::accum_t, typename CONFIG_T::accum_t, typename CONFIG_T::bias_t, CONFIG_T::n_out>(
afterW, afterBias, bias);

// Hidden state
typename CONFIG_T::accum_t hiddenCand[CONFIG_T::n_out] hls_register;
multiply_U<data_T, typename CONFIG_T::accum_t, typename CONFIG_T::weight_t, CONFIG_T::n_out>(hidden_state, hiddenCand,
rec_kernel);

// Hidden state bias addition
typename CONFIG_T::accum_t hiddenBias[CONFIG_T::n_out] hls_register;
add_bias<typename CONFIG_T::accum_t, typename CONFIG_T::accum_t, typename CONFIG_T::bias_t, CONFIG_T::n_out>(
hiddenCand, hiddenBias, rec_bias);

// Vector addition
typename CONFIG_T::accum_t afterAdd[CONFIG_T::n_out];
add_vectors<typename CONFIG_T::accum_t, typename CONFIG_T::accum_t, CONFIG_T::n_out>(afterBias, hiddenBias, afterAdd);

// Activation
CONFIG_T::template activation<typename CONFIG_T::accum_t, data_T, typename CONFIG_T::ACT_CONFIG_T>::activation(
afterAdd, hidden_state_o);
}

template <class data_T, class res_T, typename CONFIG_T>
void simple_rnn_pytorch(data_T data[CONFIG_T::n_timesteps * CONFIG_T::n_in],
res_T res[CONFIG_T::n_outputs * CONFIG_T::n_out],
const typename CONFIG_T::weight_t kernel[CONFIG_T::n_in * CONFIG_T::n_out],
const typename CONFIG_T::weight_t rec_kernel[CONFIG_T::n_out * CONFIG_T::n_out],
const typename CONFIG_T::bias_t bias[CONFIG_T::n_out],
const typename CONFIG_T::bias_t rec_bias[CONFIG_T::n_out]) {
res_T hidden_state[CONFIG_T::n_out][CONFIG_T::n_timesteps + 1] hls_register;
res_T hidden_state_temp[CONFIG_T::n_out] hls_register;
res_T h[CONFIG_T::n_out] hls_register;
data_T in[CONFIG_T::n_in] hls_register;

// Set initial hidden state (output) to zero
INIT_LOOP:
#pragma unroll
for (int x = 0; x < CONFIG_T::n_out; x++) {
hidden_state[x][0] = 0;
}

#pragma disable_loop_pipelining
for (int i = 0; i < CONFIG_T::n_timesteps; i++) {

// Data at current time step
#pragma unroll
for (int x = 0; x < CONFIG_T::n_in; x++) {
in[x] = data[x + i * CONFIG_T::n_in];
}

// Hidden state at current time step
#pragma unroll
for (int x = 0; x < CONFIG_T::n_out; x++) {
hidden_state_temp[x] = hidden_state[x][i];
}

// Do SimpleRNN
simple_rnn_pytorch_cell<data_T, res_T, CONFIG_T>(in, hidden_state_temp, h, kernel, rec_kernel, bias, rec_bias);

// Write result
#pragma unroll
for (int x = 0; x < CONFIG_T::n_out; x++) {
hidden_state[x][i + 1] = h[x];
}
}

if (CONFIG_T::return_sequences == 0) {
// Output when return_sequences is false
#pragma unroll
for (int x = 0; x < CONFIG_T::n_out; x++) {
res[x] = hidden_state[x][CONFIG_T::n_timesteps];
}
} else {
// Output when return_sequences is true
#pragma unroll
for (int x = 0; x < CONFIG_T::n_timesteps; x++) {
#pragma unroll
for (int h = 0; h < CONFIG_T::n_out; h++) {
res[x * CONFIG_T::n_out + h] = hidden_state[h][x + 1];
}
}
}
}
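
For reference, the cell above implements the standard PyTorch elementary RNN update with both bias vectors applied separately,

    h_t = activation(W_ih * x_t + b_ih + W_hh * h_(t-1) + b_hh)

which differs from the Keras-oriented simple_rnn further up only in the extra rec_bias addition.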

//----------------------
// LSTM