
Commit 75b0b0d

Merge pull request #850 from JanFSchulte/GRUv1
Add RNN support for Pytorch
2 parents 7982c87 + 54d7a34 commit 75b0b0d

13 files changed: +414 -35 lines changed

hls4ml/backends/quartus/passes/recurrent_templates.py

Lines changed: 10 additions & 1 deletion
@@ -66,6 +66,7 @@
     using activation_recr = nnet::activation::{recurrent_activation}<x_T, y_T, config_T>;
 
     static const unsigned reuse_factor = {reuse};
+    static const unsigned pytorch_order = {pytorch};
     static const bool store_weights_in_bram = false;
 }};\n'''
 
@@ -92,6 +93,7 @@ def format(self, node):
         params['config_mult_h'] = f'config{node.index}_h_mult'
         params['act_t'] = '{}_config{}'.format(node.get_attr('activation'), str(node.index) + '_act')
         params['act_recurrent_t'] = '{}_config{}'.format(node.get_attr('recurrent_activation'), str(node.index) + '_rec_act')
+        params['pytorch'] = 'true' if node.get_attr('pytorch', False) else 'false'
         gru_config = self.gru_template.format(**params)
 
         # Activation is on candidate hidden state, dimensionality (1, n_units)
@@ -256,6 +258,9 @@ def format(self, node):
 }};\n"""
 
 simple_rnn_function_template = 'nnet::simple_rnn<{input_t}, {output_t}, {config}>({input}, {output}, {weights});'
+simple_rnn_pytorch_function_template = (
+    'nnet::simple_rnn_pytorch<{input_t}, {output_t}, {config}>({input}, {output}, {weights});'
+)
 
 
 class SimpleRNNConfigTemplate(LayerConfigTemplate):
@@ -301,5 +306,9 @@ def __init__(self):
 
     def format(self, node):
         params = self._default_function_params(node)
-        params['weights'] = 'w{0}, wr{0}, b{0}'.format(str(node.index))
+        if node.get_attr('pytorch', False):
+            self.template = simple_rnn_pytorch_function_template
+            params['weights'] = 'w{0}, wr{0}, b{0}, br{0}'.format(str(node.index))
+        else:
+            params['weights'] = 'w{0}, wr{0}, b{0}'.format(str(node.index))
         return self.template.format(**params)
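
For reference, a minimal sketch (not part of the commit) of what the two SimpleRNN function templates above render to once the parameters are substituted; the layer index 2 and the type, config, and variable names are placeholder values chosen for illustration.

simple_rnn_template = 'nnet::simple_rnn<{input_t}, {output_t}, {config}>({input}, {output}, {weights});'
simple_rnn_pytorch_template = 'nnet::simple_rnn_pytorch<{input_t}, {output_t}, {config}>({input}, {output}, {weights});'

params = {'input_t': 'input_t', 'output_t': 'result_t', 'config': 'config2', 'input': 'layer1_out', 'output': 'layer2_out'}

# Keras-parsed layer: three weight arrays (kernel, recurrent kernel, bias)
params['weights'] = 'w{0}, wr{0}, b{0}'.format(2)
print(simple_rnn_template.format(**params))
# nnet::simple_rnn<input_t, result_t, config2>(layer1_out, layer2_out, w2, wr2, b2);

# PyTorch-parsed layer: the recurrent bias br2 is passed as a fourth array
params['weights'] = 'w{0}, wr{0}, b{0}, br{0}'.format(2)
print(simple_rnn_pytorch_template.format(**params))
# nnet::simple_rnn_pytorch<input_t, result_t, config2>(layer1_out, layer2_out, w2, wr2, b2, br2);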

hls4ml/backends/vivado/passes/recurrent_templates.py

Lines changed: 2 additions & 0 deletions
@@ -62,6 +62,7 @@
     static const unsigned reuse_factor = {reuse};
     static const bool store_weights_in_bram = false;
     static const bool use_static = {static};
+    static const bool pytorch_order = {pytorch};
 }};\n"""
 
 recr_function_template = 'nnet::{recr_type}_stack<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {wr}, {b}, {br});'
@@ -97,6 +98,7 @@ def format(self, node):
         params['act_t'] = '{}_config{}'.format(node.get_attr('activation'), node.index)
         params['strategy'] = node.get_attr('strategy')
         params['static'] = 'true' if node.attributes['static'] else 'false'
+        params['pytorch'] = 'true' if node.get_attr('pytorch', False) else 'false'
         params['recr_type'] = node.class_name.lower()
         params['RECR_TYPE'] = node.class_name
 

Lines changed: 74 additions & 0 deletions (new file)

@@ -0,0 +1,74 @@
+import warnings
+
+import numpy as np
+
+from hls4ml.converters.pytorch_to_hls import pytorch_handler
+
+rnn_layers = ['RNN', 'LSTM', 'GRU']
+
+
+@pytorch_handler(*rnn_layers)
+def parse_rnn_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
+    assert operation in rnn_layers
+
+    layer = {}
+
+    layer["name"] = layer_name
+
+    layer['inputs'] = [input_names[0]]
+    if len(input_names) > 1:
+        warnings.warn(
+            'hls4ml disregards the initial value of the hidden state passed to the model, assuming that it is all zeros',
+            stacklevel=2,
+        )
+    layer['class_name'] = operation
+    if operation == "RNN":
+        layer['class_name'] = 'SimpleRNN'
+
+    layer['return_sequences'] = False  # parameter does not exist in pytorch
+    layer['return_state'] = False  # parameter does not exist in pytorch
+
+    if layer['class_name'] == 'SimpleRNN':
+        layer['activation'] = class_object.nonlinearity  # Default is tanh, can also be ReLU in pytorch
+    else:
+        layer['activation'] = "tanh"  # GRU and LSTM are hard-coded to use tanh in pytorch
+
+    if layer['class_name'] == 'GRU' or layer['class_name'] == 'LSTM':
+        layer['recurrent_activation'] = 'sigmoid'  # GRU and LSTM are hard-coded to use sigmoid in pytorch
+
+    layer['time_major'] = not class_object.batch_first
+    # TODO Should we handle time_major?
+    if layer['time_major']:
+        raise Exception('hls4ml only supports "batch-first == True"')
+
+    layer['n_timesteps'] = input_shapes[0][1]
+    layer['n_in'] = input_shapes[0][2]
+
+    layer['n_out'] = class_object.hidden_size
+
+    if class_object.num_layers > 1:
+        raise Exception('hls4ml does not support num_layers > 1')
+
+    if class_object.bidirectional:
+        raise Exception('hls4ml does not support birectional RNNs')
+
+    if class_object.dropout > 0:
+        raise Exception('hls4ml does not support RNNs with dropout')
+
+    layer['weight_data'] = class_object.weight_ih_l0.data.numpy()
+    layer['recurrent_weight_data'] = class_object.weight_hh_l0.data.numpy()
+    layer['bias_data'] = class_object.bias_ih_l0.data.numpy()
+    layer['recurrent_bias_data'] = class_object.bias_hh_l0.data.numpy()
+
+    if class_object.bias is False:
+        layer['bias_data'] = np.zeros(layer['weight_data'].shape[0])
+        layer['recurrent_bias_data'] = np.zeros(layer['recurrent_weight_data'].shape[0])
+
+    if layer['class_name'] == 'GRU':
+        layer['apply_reset_gate'] = 'after'  # Might be true for pytorch? It's not a free parameter
+
+    output_shape = [input_shapes[0][0], layer['n_out']]
+
+    layer['pytorch'] = True  # need to switch some behaviors to match pytorch implementations
+
+    return layer, output_shape
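
The handler above pulls everything it needs off the torch.nn module object (class_object). A minimal sketch, not part of the commit, of the attributes it reads on a toy single-layer, batch-first GRU; the sizes are placeholders:

import torch

gru = torch.nn.GRU(input_size=4, hidden_size=3, num_layers=1, batch_first=True)

# PyTorch stacks all gates along the first axis of the weight tensors.
print(gru.weight_ih_l0.shape)  # torch.Size([9, 4])  -> 'weight_data' (3 gates * hidden_size rows)
print(gru.weight_hh_l0.shape)  # torch.Size([9, 3])  -> 'recurrent_weight_data'
print(gru.bias_ih_l0.shape)    # torch.Size([9])     -> 'bias_data'
print(gru.bias_hh_l0.shape)    # torch.Size([9])     -> 'recurrent_bias_data' (PyTorch keeps two bias vectors)

# Hyperparameters checked by the handler:
print(gru.batch_first, gru.num_layers, gru.bidirectional, gru.dropout)  # True 1 False 0.0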

hls4ml/converters/pytorch_to_hls.py

Lines changed: 17 additions & 2 deletions
@@ -199,8 +199,21 @@ def pytorch_to_hls(config):
 
             # parse info from class object
             input_names = [inputs_map.get(str(i), str(i)) for i in node.args]
-            input_shapes = [output_shapes[str(i)] for i in node.args]
-
+            if pytorch_class in ["RNN", "GRU", "LSTM"]:
+                # we currently don't support the passing of the initial value of the hidden state to RNN models
+                input_names = [inputs_map.get(str(node.args[0]), str(node.args[0]))]
+                input_shapes = [output_shapes[str(node.args[0])]]
+            # if a 'getitem' is the input to a node, step back in the graph to find the real source of the input
+            elif "getitem" in node.args[0].name:
+                for tmp_node in traced_model.graph.nodes:
+                    if tmp_node.name == node.args[0].name:
+                        if "getitem" in tmp_node.args[0].name:
+                            raise Exception('Nested getitem calles not resolved at the moment.')
+                        input_names = [inputs_map.get(str(tmp_node.args[0]), str(tmp_node.args[0]))]
+                        input_shapes = [output_shapes[str(tmp_node.args[0])]]
+                        node.args = [tmp_node.args[0]]
+            else:
+                input_shapes = [output_shapes[str(i)] for i in node.args]
             # for Conv layers
             if 'Conv' in pytorch_class:
                 if not class_object.padding_mode == 'zeros':
@@ -254,6 +267,8 @@ def pytorch_to_hls(config):
                operation = layer_name_map[operation]
 
             # only a limited number of functions are supported
+            if operation == "getitem":
+                continue
             if operation not in supported_layers:
                 raise Exception(f'Unsupported function {operation}')
             if operation == 'PReLU' or operation == 'batch_norm' or operation == 'conv1d' or operation == 'conv2d':
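
The getitem special-casing above exists because torch RNN modules return a tuple (output sequence, final hidden state), so a typical forward pass indexes that tuple; under torch.fx tracing the indexing shows up as its own getitem node between the RNN and the next layer. A minimal sketch, not part of the commit, using an illustrative toy model:

import torch
from torch import nn


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = nn.GRU(input_size=4, hidden_size=3, batch_first=True)
        self.dense = nn.Linear(3, 2)

    def forward(self, x):
        out = self.rnn(x)[0]  # index the (output, h_n) tuple; becomes a 'getitem' node when traced
        return self.dense(out)


graph = torch.fx.symbolic_trace(ToyModel()).graph
print([node.name for node in graph.nodes])
# roughly: ['x', 'rnn', 'getitem', 'dense', 'output']; the converter skips the getitem node
# and re-wires its consumer ('dense') back to the 'rnn' node it came from.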

hls4ml/model/layers.py

Lines changed: 7 additions & 2 deletions
@@ -1042,6 +1042,8 @@ def initialize(self):
 
         # biases
         self.add_weights_variable(name='bias', var_name='b{index}')
+        if "pytorch" in self.attributes.keys():
+            self.add_weights_variable(name='recurrent_bias', var_name='br{index}')
 
 
 class LSTM(Layer):
@@ -1093,8 +1095,11 @@ def initialize(self):
         # biases
         self.add_weights_variable(name='bias', var_name='b{index}')
 
-        recurrent_bias = np.zeros(recurrent_weight.shape[1])
-        self.add_weights_variable(name='recurrent_bias', var_name='br{index}', data=recurrent_bias)
+        if "pytorch" in self.attributes.keys():
+            self.add_weights_variable(name='recurrent_bias', var_name='br{index}')
+        else:
+            recurrent_bias = np.zeros(recurrent_weight.shape[1])
+            self.add_weights_variable(name='recurrent_bias', var_name='br{index}', data=recurrent_bias)
 
 
 class GRU(Layer):

hls4ml/model/optimizer/passes/convert_to_channels_last.py

Lines changed: 3 additions & 3 deletions
@@ -17,14 +17,14 @@ def match(self, node):
 
     def transform(self, model, node):
         # If this parameter has not been set, this model does not need to be converted
-        if 'InputsChannelLast' not in model.config.config['HLSConfig']['Model']:
+        if 'ChannelsLastConversion' not in model.config.config['HLSConfig']['Model']:
             node.channels_last_converted = True
             return False
         outshape = node.get_output_variable().shape
 
         if isinstance(node, Input):
             # if inputs are not yet transposed into channels_last, add transpose layer
-            if not model.config.config['HLSConfig']['Model']['InputsChannelLast'] and len(outshape) > 1:
+            if model.config.config['HLSConfig']['Model']['ChannelsLastConversion'] == "full" and len(outshape) > 1:
                 # Add transpose for input layer
                 input = node.name
                 if len(outshape) == 2:
@@ -39,7 +39,7 @@ def transform(self, model, node):
                 transpose_node.channels_last_converted = True
 
                 model.insert_node(transpose_node)
-            else:
+            elif model.config.config['HLSConfig']['Model']['ChannelsLastConversion'] == "internal" and len(outshape) > 1:
                 input_shape = node.get_output_variable().shape
                 input_shape.append(input_shape.pop(0))
                 node.get_output_variable().shape = input_shape
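
For context, this pass now keys off a string-valued 'ChannelsLastConversion' entry instead of the old boolean 'InputsChannelLast'. A minimal sketch, not part of the commit, of the model-level configuration it inspects; the surrounding dictionary is abbreviated and only the two values recognized above are shown:

# Abbreviated view of what the pass reads via model.config.config:
hls_config = {
    'HLSConfig': {
        'Model': {
            # 'full': insert an explicit transpose node after each Input layer
            # 'internal': only rotate the input variable's shape to channels-last, no transpose node
            'ChannelsLastConversion': 'full',
        }
    }
}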

hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent.h

Lines changed: 135 additions & 2 deletions
@@ -87,6 +87,7 @@ struct gru_config {
     // Resource reuse info
     static const unsigned io_type = io_parallel;
     static const unsigned reuse_factor = 1;
+    static const bool pytorch_order = false;
     static const bool store_weights_in_bram = false;
 
     // Activation
@@ -133,7 +134,10 @@ void gru_cell(data_T x[CONFIG_T::n_in], res_T h[CONFIG_T::n_units],
     hls_register typename CONFIG_T::accum_t hadamard_r_h[CONFIG_T::n_units];
     #pragma unroll recurrent_unroll_factor
     for (int i = 0; i < (CONFIG_T::n_units); i++) {
-        hadamard_r_h[i] = z_r_act[i + CONFIG_T::n_units] * mat_mul_h_wr[i + 2 * CONFIG_T::n_units];
+        if (CONFIG_T::pytorch_order)
+            hadamard_r_h[i] = z_r_act[i] * mat_mul_h_wr[i + 2 * CONFIG_T::n_units];
+        else
+            hadamard_r_h[i] = z_r_act[i + CONFIG_T::n_units] * mat_mul_h_wr[i + 2 * CONFIG_T::n_units];
     }
 
     // The candidate state; X * W_{hx} + hadmard(r(t), h_(t-1)) * W_{hh} + b_{h}
@@ -152,7 +156,11 @@
     // Update state
     #pragma unroll recurrent_unroll_factor
     for (int i = 0; i < (CONFIG_T::n_units); i++) {
-        h[i] = static_cast<res_T>(h_cand_act[i] * (1 - z_r_act[i]) + h[i] * z_r_act[i]);
+        if (CONFIG_T::pytorch_order)
+            h[i] = static_cast<res_T>(h_cand_act[i] * (1 - z_r_act[i + CONFIG_T::n_units]) +
+                                      h[i] * z_r_act[i + CONFIG_T::n_units]);
+        else
+            h[i] = static_cast<res_T>(h_cand_act[i] * (1 - z_r_act[i]) + h[i] * z_r_act[i]);
     }
 }
 
@@ -315,6 +323,131 @@ void simple_rnn(data_T data[CONFIG_T::n_timesteps * CONFIG_T::n_in], res_T res[C
         }
     }
 }
+//----------------------
+// SimpleRNN with pytorch biases
+//----------------------
+
+struct simpleRNN_pytorch_config {
+    // Internal data type definitions
+    typedef float weight_t;
+    typedef float bias_t;
+    typedef float accum_t;
+
+    // Layer Sizes
+    static const unsigned n_in = 1;
+    static const unsigned n_out = 1;
+    static const unsigned n_outputs = 1;
+    static const unsigned n_timesteps = 1;
+    static const bool return_sequences = false;
+
+    // Resource reuse info
+    static const unsigned io_type = io_parallel;
+    static const unsigned reuse_factor = 1;
+    static const bool store_weights_in_bram = false;
+
+    // Activation
+    template <class x_T, class y_T, class config_T> using activation_recr = nnet::activation::relu<x_T, y_T, config_T>;
+
+    template <class x_T, class y_T, class config_T> using activation = nnet::activation::relu<x_T, y_T, config_T>;
+};
+
+template <class data_T, class res_T, typename CONFIG_T>
+void simple_rnn_pytorch_cell(data_T inputs[CONFIG_T::n_in], res_T hidden_state[CONFIG_T::n_out],
+                             res_T hidden_state_o[CONFIG_T::n_out],
+                             const typename CONFIG_T::weight_t kernel[CONFIG_T::n_in * CONFIG_T::n_out],
+                             const typename CONFIG_T::weight_t rec_kernel[CONFIG_T::n_out * CONFIG_T::n_out],
+                             const typename CONFIG_T::bias_t bias[CONFIG_T::n_out],
+                             const typename CONFIG_T::bias_t rec_bias[CONFIG_T::n_out]) {
+    // Weight multiplication
+    typename CONFIG_T::accum_t afterW[CONFIG_T::n_out] hls_register;
+    multiply_W<data_T, typename CONFIG_T::accum_t, typename CONFIG_T::weight_t, CONFIG_T::n_in, CONFIG_T::n_out>(
+        inputs, afterW, kernel);
+
+    // Bias addition
+    typename CONFIG_T::accum_t afterBias[CONFIG_T::n_out] hls_register;
+    add_bias<typename CONFIG_T::accum_t, typename CONFIG_T::accum_t, typename CONFIG_T::bias_t, CONFIG_T::n_out>(
+        afterW, afterBias, bias);
+
+    // Hidden state
+    typename CONFIG_T::accum_t hiddenCand[CONFIG_T::n_out] hls_register;
+    multiply_U<data_T, typename CONFIG_T::accum_t, typename CONFIG_T::weight_t, CONFIG_T::n_out>(hidden_state, hiddenCand,
+                                                                                                 rec_kernel);
+
+    // Hidden state bias addition
+    typename CONFIG_T::accum_t hiddenBias[CONFIG_T::n_out] hls_register;
+    add_bias<typename CONFIG_T::accum_t, typename CONFIG_T::accum_t, typename CONFIG_T::bias_t, CONFIG_T::n_out>(
+        hiddenCand, hiddenBias, rec_bias);
+
+    // Vector addition
+    typename CONFIG_T::accum_t afterAdd[CONFIG_T::n_out];
+    add_vectors<typename CONFIG_T::accum_t, typename CONFIG_T::accum_t, CONFIG_T::n_out>(afterBias, hiddenBias, afterAdd);
+
+    // Activation
+    CONFIG_T::template activation<typename CONFIG_T::accum_t, data_T, typename CONFIG_T::ACT_CONFIG_T>::activation(
+        afterAdd, hidden_state_o);
+}
+
+template <class data_T, class res_T, typename CONFIG_T>
+void simple_rnn_pytorch(data_T data[CONFIG_T::n_timesteps * CONFIG_T::n_in],
+                        res_T res[CONFIG_T::n_outputs * CONFIG_T::n_out],
+                        const typename CONFIG_T::weight_t kernel[CONFIG_T::n_in * CONFIG_T::n_out],
+                        const typename CONFIG_T::weight_t rec_kernel[CONFIG_T::n_out * CONFIG_T::n_out],
+                        const typename CONFIG_T::bias_t bias[CONFIG_T::n_out],
+                        const typename CONFIG_T::bias_t rec_bias[CONFIG_T::n_out]) {
+    res_T hidden_state[CONFIG_T::n_out][CONFIG_T::n_timesteps + 1] hls_register;
+    res_T hidden_state_temp[CONFIG_T::n_out] hls_register;
+    res_T h[CONFIG_T::n_out] hls_register;
+    data_T in[CONFIG_T::n_in] hls_register;
+
+    // Set initially hidden state (output) to zero
+INIT_LOOP:
+    #pragma unroll
+    for (int x = 0; x < CONFIG_T::n_out; x++) {
+        hidden_state[x][0] = 0;
+    }
+
+    #pragma disable_loop_pipelining
+    for (int i = 0; i < CONFIG_T::n_timesteps; i++) {
+
+        // Data at current time step
+        #pragma unroll
+        for (int x = 0; x < CONFIG_T::n_in; x++) {
+            in[x] = data[x + i * CONFIG_T::n_in];
+        }
+
+        // Hidden state at current time step
+        #pragma unroll
+        for (int x = 0; x < CONFIG_T::n_out; x++) {
+            hidden_state_temp[x] = hidden_state[x][i];
+        }
+
+        // Do SimpleRNN
+        simple_rnn_pytorch_cell<data_T, res_T, CONFIG_T>(in, hidden_state_temp, h, kernel, rec_kernel, bias, rec_bias);
+
+        // Write result
+        #pragma unroll
+        for (int x = 0; x < CONFIG_T::n_out; x++) {
+            hidden_state[x][i + 1] = h[x];
+        }
+    }
+
+    if (CONFIG_T::return_sequences == 0) {
+        // Output when return_sequences is false
+        #pragma unroll
+        for (int x = 0; x < CONFIG_T::n_out; x++) {
+            res[x] = hidden_state[x][CONFIG_T::n_timesteps];
+        }
+    } else {
+        // Output when return_sequences is true
+        #pragma unroll
+        for (int x = 0; x < CONFIG_T::n_timesteps; x++) {
+            #pragma unroll
+            for (int h = 0; h < CONFIG_T::n_out; h++) {
+                res[x * CONFIG_T::n_out + h] = hidden_state[h][x + 1];
+            }
+        }
+    }
+}
 
 //----------------------
 // LSTM
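
The pytorch_order branches in gru_cell above compensate for the different gate packing of the two frontends: Keras stacks the GRU gates as (update, reset, candidate) while PyTorch stacks them as (reset, update, new), so the update and reset slices of the fused z_r_act vector trade places. A minimal sketch, not part of the commit, showing the PyTorch packing on a toy GRU:

import torch

gru = torch.nn.GRU(input_size=4, hidden_size=3, batch_first=True)

# weight_ih_l0 stacks the per-gate input kernels along dim 0 in PyTorch's order: reset, update, new.
w_ir, w_iz, w_in = gru.weight_ih_l0.chunk(3, dim=0)
print(w_ir.shape, w_iz.shape, w_in.shape)  # each torch.Size([3, 4])

# Keras orders the same blocks as update, reset, candidate, which is why the HLS code reads
# z_r_act[i] vs z_r_act[i + CONFIG_T::n_units] depending on CONFIG_T::pytorch_order.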
