diff --git a/hls4ml/backends/vivado/passes/recurrent_templates.py b/hls4ml/backends/vivado/passes/recurrent_templates.py
index 598c1c59d5..74ec61e823 100644
--- a/hls4ml/backends/vivado/passes/recurrent_templates.py
+++ b/hls4ml/backends/vivado/passes/recurrent_templates.py
@@ -17,8 +17,8 @@
     typedef {bias_t.name} bias_t;
     typedef {weight_t.name} weight_t;
     typedef ap_{index_t} index_t;
-    template
-    using product = nnet::product::{product_type};
+    template
+    using product = nnet::product::{product_type};
 }};\n"""
 
 #activation templates
diff --git a/hls4ml/backends/vivado/vivado_backend.py b/hls4ml/backends/vivado/vivado_backend.py
index 2b910c4cd1..3ca05f3e6c 100644
--- a/hls4ml/backends/vivado/vivado_backend.py
+++ b/hls4ml/backends/vivado/vivado_backend.py
@@ -241,9 +241,6 @@ def init_lstm(self, layer):
         reuse_factor = layer.model.config.get_reuse_factor(layer)
         layer.set_attr('recurrent_reuse_factor', reuse_factor)
 
-        recurrent_bias = np.zeros(layer.weights['recurrent_weight'].shape[1])
-        layer.add_weights_variable(name='recurrent_bias', var_name='br{index}', data=recurrent_bias)
-
         index_t = IntegerPrecisionType(width=1, signed=False)
 
         if 'table_t' not in layer.attributes:
@@ -267,9 +264,6 @@ def init_gru(self, layer):
         reuse_factor = layer.model.config.get_reuse_factor(layer)
         layer.set_attr('recurrent_reuse_factor', reuse_factor)
 
-        recurrent_bias = np.zeros(layer.weights['recurrent_weight'].shape[1])
-        layer.add_weights_variable(name='recurrent_bias', var_name='br{index}', data=recurrent_bias)
-
         index_t = IntegerPrecisionType(width=1, signed=False)
 
         if 'table_t' not in layer.attributes:
diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py
index f821d08e90..8670889e46 100644
--- a/hls4ml/model/layers.py
+++ b/hls4ml/model/layers.py
@@ -885,12 +885,17 @@ def initialize(self):
             self.add_output_variable(state_shape, state_dims, out_name=self.outputs[1], var_name='layer{index}_h', type_name='layer{index}_h_t')
             self.add_output_variable(state_shape, state_dims, out_name=self.outputs[2], var_name='layer{index}_c', type_name='layer{index}_c_t')
 
+        #weights
         self.add_weights()
-        self.add_bias()
 
+        #recurrent weights
         recurrent_weight = self.model.get_weights_data(self.name, 'recurrent_kernel')
         self.add_weights_variable(name='recurrent_weight', var_name='wr{index}', data=recurrent_weight)
 
+        #biases
+        biases = self.model.get_weights_data(self.name, 'bias')
+        self.add_weights_variable(name='bias', var_name='b{index}', data=biases)
+
 
 class LSTM(Layer):
     _expected_attributes = [
         Attribute('n_out'),
@@ -904,10 +909,12 @@ class LSTM(Layer):
         WeightAttribute('weight'),
         WeightAttribute('bias'),
         WeightAttribute('recurrent_weight'),
+        WeightAttribute('recurrent_bias'),
 
         TypeAttribute('weight'),
         TypeAttribute('bias'),
         TypeAttribute('recurrent_weight'),
+        TypeAttribute('recurrent_bias'),
     ]
 
     def initialize(self):
@@ -926,12 +933,20 @@ def initialize(self):
             self.add_output_variable(state_shape, state_dims, out_name=self.outputs[1], var_name='layer{index}_h', type_name='layer{index}_h_t')
             self.add_output_variable(state_shape, state_dims, out_name=self.outputs[2], var_name='layer{index}_c', type_name='layer{index}_c_t')
 
+        #weights
         self.add_weights()
-        self.add_bias()
 
+        #recurrent weights
         recurrent_weight = self.model.get_weights_data(self.name, 'recurrent_kernel')
         self.add_weights_variable(name='recurrent_weight', var_name='wr{index}', data=recurrent_weight)
 
+        #biases
+        biases = self.model.get_weights_data(self.name, 'bias')
+        self.add_weights_variable(name='bias', var_name='b{index}', data=biases)
+
+        recurrent_bias = np.zeros(recurrent_weight.shape[1])
+        self.add_weights_variable(name='recurrent_bias', var_name='br{index}', data=recurrent_bias)
+
 
 class GRU(Layer):
     _expected_attributes = [
         Attribute('n_out'),
@@ -946,10 +961,12 @@ class GRU(Layer):
         WeightAttribute('weight'),
         WeightAttribute('bias'),
         WeightAttribute('recurrent_weight'),
+        WeightAttribute('recurrent_bias'),
 
         TypeAttribute('weight'),
         TypeAttribute('bias'),
         TypeAttribute('recurrent_weight'),
+        TypeAttribute('recurrent_bias'),
     ]
 
     def initialize(self):
@@ -968,12 +985,19 @@ def initialize(self):
             self.add_output_variable(state_shape, state_dims, out_name=self.outputs[1], var_name='layer{index}_h', type_name='layer{index}_h_t')
             self.add_output_variable(state_shape, state_dims, out_name=self.outputs[2], var_name='layer{index}_c', type_name='layer{index}_c_t')
 
+        #weights
         self.add_weights()
-        self.add_bias()
 
+        #recurrent weights
         recurrent_weight = self.model.get_weights_data(self.name, 'recurrent_kernel')
         self.add_weights_variable(name='recurrent_weight', var_name='wr{index}', data=recurrent_weight)
 
+        #biases array is actually a 2-dim array of arrays (bias + recurrent bias)
+        #both arrays have shape: n_units * 3 (z, r, h_cand)
+        biases = self.model.get_weights_data(self.name, 'bias')
+        self.add_weights_variable(name='bias', var_name='b{index}', data=biases[0])
+        self.add_weights_variable(name='recurrent_bias', var_name='br{index}', data=biases[1])
+
 
 class GarNet(Layer):
     ref_impl = False
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_recurrent.h b/hls4ml/templates/vivado/nnet_utils/nnet_recurrent.h
index a7096dde18..e94286aa8e 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_recurrent.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_recurrent.h
@@ -288,7 +288,7 @@ template
         data_in[i_pack] = data_pack[i_pack];
     }
     if (CONFIG_T::use_static)
-        nnet::lstm_static(reset_state,data_in,h_newstate, param,param_r,param_b, param_br);
+        nnet::lstm_static(reset_state,data_in,h_newstate, s_newstate, param,param_r,param_b, param_br);
     else
         nnet::lstm(reset_state,data_in,h_newstate, s_newstate, param,param_r,param_b, param_br);
     if (CONFIG_T::n_sequence_out > 1){
diff --git a/test/pytest/test_rnn.py b/test/pytest/test_rnn.py
index 99695fe22d..bc7ecc7aa1 100644
--- a/test/pytest/test_rnn.py
+++ b/test/pytest/test_rnn.py
@@ -1,13 +1,11 @@
 import pytest
 import hls4ml
-import tensorflow as tf
 import numpy as np
 from pathlib import Path
-from tensorflow.keras import optimizers
-from tensorflow.keras.models import Model
-from tensorflow.keras.layers import Input, Embedding, SimpleRNN, LSTM, GRU
 import math
 from tensorflow.keras import backend as K
+from tensorflow.keras.models import Model, Sequential
+from tensorflow.keras.layers import Input, SimpleRNN, LSTM, GRU
 
 test_root_path = Path(__file__).parent
 
@@ -46,13 +44,62 @@ def test_rnn_parsing(rnn_layer, return_sequences):
     assert hls_layer.get_output_variable().shape == model_output.shape.as_list()[1:]  # Ignore the batch size
 
     # Compare weights
-    hls_weights = list(hls_layer.get_weights())  # [weights, bias, recurrent_weights, "recurrent_bias" hack]
+    hls_weights = list(hls_layer.get_weights())  # [weights, recurrent_weights, bias, recurrent_bias]
     rnn_weights = keras_layer.get_weights()  # [weights, recurrent_weights, bias]
 
     assert hls_weights[0].data.shape == rnn_weights[0].shape
-    assert hls_weights[2].data.shape == rnn_weights[1].shape
-    assert hls_weights[1].data.shape == rnn_weights[2].shape
-
+    assert hls_weights[1].data.shape == rnn_weights[1].shape
+    if 'gru' in rnn_layer.__name__.lower():
+        # GRU has both bias and recurrent bias
+        assert hls_weights[2].data.shape == rnn_weights[2][0].shape
+        assert hls_weights[3].data.shape == rnn_weights[2][1].shape
+    else:
+        # LSTM and SimpleRNN only have bias
+        assert hls_weights[2].data.shape == rnn_weights[2].shape
+
     np.testing.assert_array_equal(hls_weights[0].data, rnn_weights[0])
-    np.testing.assert_array_equal(hls_weights[2].data, rnn_weights[1])
-    np.testing.assert_array_equal(hls_weights[1].data, rnn_weights[2])
+    np.testing.assert_array_equal(hls_weights[1].data, rnn_weights[1])
+    if 'gru' in rnn_layer.__name__.lower():
+        np.testing.assert_array_equal(hls_weights[2].data, rnn_weights[2][0])
+        np.testing.assert_array_equal(hls_weights[3].data, rnn_weights[2][1])
+    else:
+        np.testing.assert_array_equal(hls_weights[2].data, rnn_weights[2])
+
+
+@pytest.mark.parametrize('rnn_layer', [LSTM, GRU])
+@pytest.mark.parametrize('return_sequences', [True, False])
+@pytest.mark.parametrize('backend', ['Vivado'])
+@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
+@pytest.mark.parametrize('static', [True, False])
+def test_rnn_accuracy(rnn_layer, return_sequences, backend, io_type, static):
+    # Subtract 0.5 to include negative values
+    input_shape = (5, 8)
+    X = np.random.rand(50, *input_shape) - 0.5
+
+    layer_name = rnn_layer.__name__.lower()
+    keras_model = Sequential()
+    keras_model.add(rnn_layer(units=32, input_shape=input_shape, kernel_initializer='lecun_uniform', bias_initializer='lecun_uniform', return_sequences=return_sequences, name=layer_name))
+    keras_model.compile()
+
+    default_precision = 'ap_fixed<32, 16>' if backend == 'Vivado' else 'ac_fixed<32, 16, true>'
+    hls_config = hls4ml.utils.config_from_keras_model(keras_model, granularity='name', default_precision=default_precision)
+    hls_config['LayerName'][layer_name]['static'] = static
+    output_dir = 'hls4mlprj_rnn_accuracy_{}_static_{}_ret_seq_{}_{}_{}'.format(
+        layer_name,
+        int(static),
+        int(return_sequences),
+        backend,
+        io_type
+    )
+
+    hls_model = hls4ml.converters.convert_from_keras_model(
+        keras_model,
+        hls_config=hls_config,
+        output_dir=output_dir,
+        backend=backend,
+        io_type=io_type
+    )
+    hls_model.compile()
+
+    keras_prediction = keras_model.predict(X)
+    hls_prediction = hls_model.predict(X)
+    np.testing.assert_allclose(hls_prediction.flatten(), keras_prediction.flatten(), rtol=0.0, atol=3e-2)