Fixes for GRU/LSTM in Vivado backend #598

Merged: 16 commits, Jul 15, 2022
4 changes: 2 additions & 2 deletions hls4ml/backends/vivado/passes/recurrent_templates.py
@@ -17,8 +17,8 @@
typedef {bias_t.name} bias_t;
typedef {weight_t.name} weight_t;
typedef ap_{index_t} index_t;
template<class x_T, class y_T, class res_T>
using product = nnet::product::{product_type}<x_T, y_T, res_T>;
template<class x_T, class y_T>
using product = nnet::product::{product_type}<x_T, y_T>;
}};\n"""

#activation templates
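The config template above now instantiates the product alias with two template parameters instead of three, matching the two-argument nnet::product functors used by the Vivado kernels. As a rough sketch of how the Python format string expands (the type names and product_type value below are invented for illustration, not taken from hls4ml output):

    # Sketch only: expanding a recurrent config template like the one above.
    # The Precision stub and the values passed to format() are assumptions.
    class Precision:
        def __init__(self, name):
            self.name = name

    recr_config_template = (
        "typedef {bias_t.name} bias_t;\n"
        "typedef {weight_t.name} weight_t;\n"
        "typedef ap_{index_t} index_t;\n"
        "template<class x_T, class y_T>\n"
        "using product = nnet::product::{product_type}<x_T, y_T>;\n"
    )

    print(recr_config_template.format(
        bias_t=Precision('ap_fixed<16,6>'),
        weight_t=Precision('ap_fixed<16,6>'),
        index_t='uint<1>',
        product_type='mult',
    ))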
6 changes: 0 additions & 6 deletions hls4ml/backends/vivado/vivado_backend.py
@@ -241,9 +241,6 @@ def init_lstm(self, layer):
reuse_factor = layer.model.config.get_reuse_factor(layer)
layer.set_attr('recurrent_reuse_factor', reuse_factor)

recurrent_bias = np.zeros(layer.weights['recurrent_weight'].shape[1])
layer.add_weights_variable(name='recurrent_bias', var_name='br{index}', data=recurrent_bias)

index_t = IntegerPrecisionType(width=1, signed=False)

if 'table_t' not in layer.attributes:
@@ -267,9 +264,6 @@ def init_gru(self, layer):
reuse_factor = layer.model.config.get_reuse_factor(layer)
layer.set_attr('recurrent_reuse_factor', reuse_factor)

recurrent_bias = np.zeros(layer.weights['recurrent_weight'].shape[1])
layer.add_weights_variable(name='recurrent_bias', var_name='br{index}', data=recurrent_bias)

index_t = IntegerPrecisionType(width=1, signed=False)

if 'table_t' not in layer.attributes:
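These deletions move the all-zero recurrent bias out of the backend initializers; the LSTM and GRU layer classes now create their own bias variables (see the layers.py changes below). A small sketch of the shape relation the moved code relies on, with an arbitrary layer size:

    import numpy as np

    # Illustration only: for a Keras LSTM with n_out units the recurrent kernel
    # has shape (n_out, 4 * n_out) (gates i, f, c, o), so the placeholder
    # recurrent bias handed to the HLS kernel has length 4 * n_out.
    n_out = 32
    recurrent_weight = np.zeros((n_out, 4 * n_out))       # stands in for the Keras recurrent_kernel
    recurrent_bias = np.zeros(recurrent_weight.shape[1])  # same construction as in layers.py
    assert recurrent_bias.shape == (4 * n_out,)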
30 changes: 27 additions & 3 deletions hls4ml/model/layers.py
@@ -885,12 +885,17 @@ def initialize(self):
self.add_output_variable(state_shape, state_dims, out_name=self.outputs[1], var_name='layer{index}_h', type_name='layer{index}_h_t')
self.add_output_variable(state_shape, state_dims, out_name=self.outputs[2], var_name='layer{index}_c', type_name='layer{index}_c_t')

#weights
self.add_weights()
self.add_bias()

#recurrent weights
recurrent_weight = self.model.get_weights_data(self.name, 'recurrent_kernel')
self.add_weights_variable(name='recurrent_weight', var_name='wr{index}', data=recurrent_weight)

#biases
biases = self.model.get_weights_data(self.name , 'bias')
self.add_weights_variable(name='bias', var_name='b{index}', data=biases)

class LSTM(Layer):
_expected_attributes = [
Attribute('n_out'),
@@ -904,10 +909,12 @@ class LSTM(Layer):
WeightAttribute('weight'),
WeightAttribute('bias'),
WeightAttribute('recurrent_weight'),
WeightAttribute('recurrent_bias'),

TypeAttribute('weight'),
TypeAttribute('bias'),
TypeAttribute('recurrent_weight'),
TypeAttribute('recurrent_bias'),
]

def initialize(self):
@@ -926,12 +933,20 @@ def initialize(self):
self.add_output_variable(state_shape, state_dims, out_name=self.outputs[1], var_name='layer{index}_h', type_name='layer{index}_h_t')
self.add_output_variable(state_shape, state_dims, out_name=self.outputs[2], var_name='layer{index}_c', type_name='layer{index}_c_t')

#weights
self.add_weights()
self.add_bias()

#recurrent weights
recurrent_weight = self.model.get_weights_data(self.name, 'recurrent_kernel')
self.add_weights_variable(name='recurrent_weight', var_name='wr{index}', data=recurrent_weight)

#biases
biases = self.model.get_weights_data(self.name , 'bias')
self.add_weights_variable(name='bias', var_name='b{index}', data=biases)

recurrent_bias = np.zeros(recurrent_weight.shape[1])
self.add_weights_variable(name='recurrent_bias', var_name='br{index}', data=recurrent_bias)

class GRU(Layer):
_expected_attributes = [
Attribute('n_out'),
@@ -946,10 +961,12 @@ class GRU(Layer):
WeightAttribute('weight'),
WeightAttribute('bias'),
WeightAttribute('recurrent_weight'),
WeightAttribute('recurrent_bias'),

TypeAttribute('weight'),
TypeAttribute('bias'),
TypeAttribute('recurrent_weight'),
TypeAttribute('recurrent_bias'),
]

def initialize(self):
@@ -968,12 +985,19 @@ def initialize(self):
self.add_output_variable(state_shape, state_dims, out_name=self.outputs[1], var_name='layer{index}_h', type_name='layer{index}_h_t')
self.add_output_variable(state_shape, state_dims, out_name=self.outputs[2], var_name='layer{index}_c', type_name='layer{index}_c_t')

#weights
self.add_weights()
self.add_bias()

#recurrent weights
recurrent_weight = self.model.get_weights_data(self.name, 'recurrent_kernel')
self.add_weights_variable(name='recurrent_weight', var_name='wr{index}', data=recurrent_weight)

#biases array is actually a 2-dim array of arrays (bias + recurrent bias)
#both arrays have shape: n_units * 3 (z, r, h_cand)
biases = self.model.get_weights_data(self.name , 'bias')
self.add_weights_variable(name='bias', var_name='b{index}', data=biases[0])
self.add_weights_variable(name='recurrent_bias', var_name='br{index}', data=biases[1])

class GarNet(Layer):
ref_impl = False

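The bias handling above mirrors how Keras stores these weights: an LSTM keeps a single bias vector of length 4 * n_out, so an all-zero recurrent bias is added for the HLS kernel, while a GRU with the TF2 default reset_after=True stores a (2, 3 * n_out) array whose rows are the input bias and the recurrent bias. A quick way to check those shapes (layer sizes here are arbitrary):

    import tensorflow as tf

    # Inspect the weight shapes the new initialize() methods rely on,
    # using 8 units and 3 input features as example sizes.
    lstm = tf.keras.layers.LSTM(8)
    lstm.build((None, 5, 3))
    print([w.shape for w in lstm.get_weights()])
    # expected: [(3, 32), (8, 32), (32,)] -> kernel, recurrent_kernel, single bias of length 4*8

    gru = tf.keras.layers.GRU(8)  # reset_after=True by default in TF2
    gru.build((None, 5, 3))
    print([w.shape for w in gru.get_weights()])
    # expected: [(3, 24), (8, 24), (2, 24)] -> the bias entry holds input bias and recurrent bias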
2 changes: 1 addition & 1 deletion hls4ml/templates/vivado/nnet_utils/nnet_recurrent.h
@@ -288,7 +288,7 @@ template<class data_T, class res_T, typename CONFIG_T>
data_in[i_pack] = data_pack[i_pack];
}
if (CONFIG_T::use_static)
nnet::lstm_static<typename data_T::value_type, typename res_T::value_type, CONFIG_T>(reset_state,data_in,h_newstate, param,param_r,param_b, param_br);
nnet::lstm_static<typename data_T::value_type, typename res_T::value_type, CONFIG_T>(reset_state,data_in,h_newstate, s_newstate, param,param_r,param_b, param_br);
else
nnet::lstm<typename data_T::value_type, typename res_T::value_type, CONFIG_T>(reset_state,data_in,h_newstate, s_newstate, param,param_r,param_b, param_br);
if (CONFIG_T::n_sequence_out > 1){
67 changes: 57 additions & 10 deletions test/pytest/test_rnn.py
@@ -1,13 +1,11 @@
import pytest
import hls4ml
import tensorflow as tf
import numpy as np
from pathlib import Path
from tensorflow.keras import optimizers
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Embedding, SimpleRNN, LSTM, GRU
import math
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Input, SimpleRNN, LSTM, GRU

test_root_path = Path(__file__).parent

@@ -46,13 +44,62 @@ def test_rnn_parsing(rnn_layer, return_sequences):
assert hls_layer.get_output_variable().shape == model_output.shape.as_list()[1:] # Ignore the batch size

# Compare weights
hls_weights = list(hls_layer.get_weights()) # [weights, bias, recurrent_weights, "recurrent_bias" hack]
hls_weights = list(hls_layer.get_weights()) # [weights, recurrent_weights, bias, recurrent_bias]
rnn_weights = keras_layer.get_weights() # [weights, recurrent_weights, bias]

assert hls_weights[0].data.shape == rnn_weights[0].shape
assert hls_weights[2].data.shape == rnn_weights[1].shape
assert hls_weights[1].data.shape == rnn_weights[2].shape

assert hls_weights[1].data.shape == rnn_weights[1].shape
if 'gru' in rnn_layer.__name__.lower():
# GRU has both bias and recurrent bias
assert hls_weights[2].data.shape == rnn_weights[2][0].shape
assert hls_weights[3].data.shape == rnn_weights[2][1].shape
else:
# LSTM and SimpleRNN only have bias
assert hls_weights[2].data.shape == rnn_weights[2].shape

np.testing.assert_array_equal(hls_weights[0].data, rnn_weights[0])
np.testing.assert_array_equal(hls_weights[2].data, rnn_weights[1])
np.testing.assert_array_equal(hls_weights[1].data, rnn_weights[2])
np.testing.assert_array_equal(hls_weights[1].data, rnn_weights[1])
if 'gru' in rnn_layer.__name__.lower():
np.testing.assert_array_equal(hls_weights[2].data, rnn_weights[2][0])
np.testing.assert_array_equal(hls_weights[3].data, rnn_weights[2][1])
else:
np.testing.assert_array_equal(hls_weights[2].data, rnn_weights[2])

@pytest.mark.parametrize('rnn_layer', [LSTM, GRU])
@pytest.mark.parametrize('return_sequences', [True, False])
@pytest.mark.parametrize('backend', ['Vivado'])
@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
@pytest.mark.parametrize('static', [True, False])
def test_rnn_accuracy(rnn_layer, return_sequences, backend, io_type, static):
# Subtract 0.5 to include negative values
input_shape = (5, 8)
X = np.random.rand(50, *input_shape) - 0.5

layer_name = rnn_layer.__name__.lower()
keras_model = Sequential()
keras_model.add(rnn_layer(units=32, input_shape=input_shape, kernel_initializer='lecun_uniform', bias_initializer='lecun_uniform', return_sequences=return_sequences, name=layer_name))
keras_model.compile()

default_precision = 'ap_fixed<32, 16>' if backend == 'Vivado' else 'ac_fixed<32, 16, true>'
hls_config = hls4ml.utils.config_from_keras_model(keras_model, granularity='name', default_precision=default_precision)
hls_config['LayerName'][layer_name]['static'] = static
output_dir = 'hls4mlprj_rnn_accuracy_{}_static_{}_ret_seq_{}_{}_{}'.format(
rnn_layer.__name__.lower(),
int(static),
int(return_sequences),
backend,
io_type
)

hls_model = hls4ml.converters.convert_from_keras_model(
keras_model,
hls_config=hls_config,
output_dir=output_dir,
backend=backend,
io_type=io_type
)
hls_model.compile()

keras_prediction = keras_model.predict(X)
hls_prediction = hls_model.predict(X)
np.testing.assert_allclose(hls_prediction.flatten(), keras_prediction.flatten(), rtol=0.0, atol=3e-2)
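
For reference, the conversion path exercised by the new accuracy test can also be driven by hand; the sketch below mirrors the test with a single GRU layer (layer size, precision, and output directory are illustrative choices, not requirements of this PR):

    import numpy as np
    import hls4ml
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import GRU

    keras_model = Sequential([GRU(32, input_shape=(5, 8), name='gru')])

    config = hls4ml.utils.config_from_keras_model(
        keras_model, granularity='name', default_precision='ap_fixed<32, 16>')
    config['LayerName']['gru']['static'] = False  # per-layer toggle for the static implementation

    hls_model = hls4ml.converters.convert_from_keras_model(
        keras_model, hls_config=config, output_dir='hls4mlprj_gru_demo',
        backend='Vivado', io_type='io_parallel')
    hls_model.compile()

    X = np.random.rand(10, 5, 8) - 0.5
    print(np.max(np.abs(hls_model.predict(X) - keras_model.predict(X))))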