Commit 6b9e91f

Quartus GRU
1 parent 19c541a commit 6b9e91f

8 files changed: +435, -44 lines
Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
from hls4ml.backends.backend import get_backend
from hls4ml.model.layers import GRU
from hls4ml.backends.template import LayerConfigTemplate, FunctionCallTemplate

recurrent_include_list = ['nnet_utils/nnet_recurrent.h', 'nnet_utils/nnet_recurrent_stream.h']

# Shared Matrix Multiplication Template (Dense)
recr_mult_config_template = '''struct config{index}_mult : nnet::dense_config {{
    static const unsigned n_in = {n_in};
    static const unsigned n_out = {n_out};

    static const unsigned rf_pad = {rfpad};
    static const unsigned bf_pad = {bfpad};
    static const unsigned reuse_factor = {reuse};
    static const unsigned reuse_factor_rounded = reuse_factor + rf_pad;
    static const unsigned block_factor = DIV_ROUNDUP(n_in*n_out, reuse_factor);
    static const unsigned block_factor_rounded = block_factor + bf_pad;
    static const unsigned multiplier_factor = MIN(n_in, reuse_factor);
    static const unsigned multiplier_limit = DIV_ROUNDUP(n_in*n_out, multiplier_factor);
    static const unsigned multiplier_scale = multiplier_limit/n_out;
    typedef {accum_t.name} accum_t;
    typedef {bias_t.name} bias_t;
    typedef {weight_t.name} weight_t;

    template<class x_T, class y_T>
    using product = nnet::product::{product_type}<x_T, y_T>;
}};\n'''

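For intuition, a small numeric sketch (not part of the commit) of the derived constants above, assuming a hypothetical GRU with 16 inputs and 32 units, so the x-side matrix has n_out = 3 * 32 = 96, with reuse_factor = 16 and no padding (rf_pad = bf_pad = 0):

def div_roundup(a, b):
    # Python counterpart of the DIV_ROUNDUP macro used in the template
    return (a + b - 1) // b

n_in, n_out, reuse = 16, 96, 16                                   # hypothetical sizes
rf_pad = bf_pad = 0

reuse_factor_rounded = reuse + rf_pad                             # 16
block_factor = div_roundup(n_in * n_out, reuse)                   # 96 weights handled per reuse step
block_factor_rounded = block_factor + bf_pad                      # 96
multiplier_factor = min(n_in, reuse)                              # 16
multiplier_limit = div_roundup(n_in * n_out, multiplier_factor)   # 96 parallel multipliers
multiplier_scale = multiplier_limit // n_out                      # 1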
# Activation Template
activ_config_template = '''struct {type}_config{index} : nnet::activ_config {{
    static const unsigned n_in = {n_in};
    static const unsigned table_size = {table_size};
    static const unsigned io_type = nnet::{iotype};
    static const unsigned reuse_factor = {reuse};
}};\n'''

# GRU Template
gru_config_template = '''struct config{index} : nnet::gru_config {{
    static const unsigned n_in = {n_in};
    static const unsigned n_out = {n_out};
    static const unsigned n_units = {n_units};
    static const unsigned n_timesteps = {n_timesteps};
    static const unsigned n_outputs = {n_outputs};
    static const bool return_sequences = {return_sequences};

    typedef {accum_t.name} accum_t;
    typedef {weight_t.name} weight_t;
    typedef {bias_t.name} bias_t;

    typedef {config_mult_x} mult_config_x;
    typedef {config_mult_h} mult_config_h;

    typedef {act_t} ACT_CONFIG_T;
    template<class x_T, class y_T, class config_T>
    using activation = nnet::activation::{activation}<x_T, y_T, config_T>;

    typedef {act_recurrent_t} ACT_CONFIG_RECURRENT_T;
    template<class x_T, class y_T, class config_T>
    using activation_recr = nnet::activation::{recurrent_activation}<x_T, y_T, config_T>;

    static const unsigned reuse_factor = {reuse};
    static const bool store_weights_in_bram = false;
}};\n'''

gru_function_template = 'nnet::gru<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {wr}, {b}, {br});'

class GRUConfigTemplate(LayerConfigTemplate):
    def __init__(self):
        super().__init__(GRU)
        self.gru_template = gru_config_template
        self.act_template = activ_config_template
        self.recr_act_template = activ_config_template
        self.mult_x_template = recr_mult_config_template
        self.mult_h_template = recr_mult_config_template

    def format(self, node):
        # Input has shape (n_timesteps, inp_dimensionality)
        # Output / hidden units has shape (1 if !return_sequences else n_timesteps, n_units)
        params = self._default_config_params(node)
        params['n_units'] = node.get_attr('n_out')
        params['n_outputs'] = node.get_attr('n_timesteps') if node.get_attr('return_sequences', False) else '1'
        params['return_sequences'] = 'true' if node.get_attr('return_sequences', False) else 'false'
        params['config_mult_x'] = 'config{}_x_mult'.format(node.index)
        params['config_mult_h'] = 'config{}_h_mult'.format(node.index)
        params['act_t'] = '{}_config{}'.format(node.get_attr('activation'), str(node.index) + '_act')
        params['act_recurrent_t'] = '{}_config{}'.format(node.get_attr('recurrent_activation'), str(node.index) + '_rec_act')
        gru_config = self.gru_template.format(**params)

        # Activation is on candidate hidden state, dimensionality (1, n_units)
        act_params = self._default_config_params(node)
        act_params['type'] = node.get_attr('activation')
        act_params['n_in'] = node.get_attr('n_out')
        act_params['index'] = str(node.index) + '_act'
        act_config = self.act_template.format(**act_params)

        # Recurrent activation is on reset and update gates (therefore x2), dimensionality (1, n_units)
        recr_act_params = self._default_config_params(node)
        recr_act_params['type'] = node.get_attr('recurrent_activation')
        recr_act_params['n_in'] = str(node.get_attr('n_out')) + ' * 2'
        recr_act_params['index'] = str(node.index) + '_rec_act'
        recr_act_config = self.recr_act_template.format(**recr_act_params)

        # Multiplication config for matrix multiplications of type Wx (reset, update and candidate states)
        mult_params_x = self._default_config_params(node)
        mult_params_x['n_in'] = node.get_attr('n_in')
        mult_params_x['n_out'] = str(node.get_attr('n_out')) + ' * 3'
        mult_params_x['product_type'] = get_backend('quartus').product_type(node.get_input_variable().type.precision, node.get_weights('weight').type.precision)
        mult_params_x['index'] = str(node.index) + '_x'
        mult_config_x = self.mult_x_template.format(**mult_params_x)

        # Multiplication config for matrix multiplications of type Wh (reset, update and candidate states)
        mult_params_h = self._default_config_params(node)
        mult_params_h['n_in'] = node.get_attr('n_out')
        mult_params_h['n_out'] = str(node.get_attr('n_out')) + ' * 3'
        mult_params_h['reuse_factor'] = params['recurrent_reuse_factor']
        mult_params_h['product_type'] = get_backend('quartus').product_type(node.get_input_variable().type.precision, node.get_weights('recurrent_weight').type.precision)
        mult_params_h['index'] = str(node.index) + '_h'
        mult_config_h = self.mult_h_template.format(**mult_params_h)

        return mult_config_x + '\n' + mult_config_h + '\n' + recr_act_config + '\n' + act_config + '\n' + gru_config

class GRUFunctionTemplate(FunctionCallTemplate):
    def __init__(self):
        super().__init__(GRU, include_header=recurrent_include_list)
        self.template = gru_function_template

    def format(self, node):
        params = self._default_function_params(node)
        params['w'] = node.get_weights('weight').name
        params['b'] = node.get_weights('bias').name
        params['wr'] = node.get_weights('recurrent_weight').name
        params['br'] = node.get_weights('recurrent_bias').name
        return self.template.format(**params)
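To make the template mechanics concrete, here is a hypothetical rendering of gru_function_template (variable names made up, not produced by this commit) for a GRU node with index 2:

# Hypothetical example: how GRUFunctionTemplate.format() fills the call template
rendered = gru_function_template.format(
    input_t='input_t', output_t='layer2_t', config='config2',
    input='layer1_out', output='layer2_out',
    w='w2', wr='wr2', b='b2', br='br2',
)
print(rendered)
# nnet::gru<input_t, layer2_t, config2>(layer1_out, layer2_out, w2, wr2, b2, br2);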
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
import numpy as np
from hls4ml.model.optimizer import OptimizerPass
from hls4ml.model.layers import Dense, GRU

class ApplyResourceStrategy(OptimizerPass):
    ''' Transposes the weights to use the dense_resource matrix multiply routine '''
    def match(self, node):
        node_matches = isinstance(node, (Dense, GRU))
        is_resource_strategy = True # node.get_attr('strategy', '').lower() == 'resource' ... Quartus only supports resource strategy
        already_transformed = node.get_attr('_weights_transposed', False) == True
        return node_matches and is_resource_strategy and not already_transformed

    def transform(self, model, node):
        if isinstance(node, Dense) and not node.model.config.get_compression(node):
            rf = node.get_attr('reuse_factor')
            bf = int((node.attributes['n_in']*node.attributes['n_out'])/rf)
            bf_rounded = int(pow(2, np.ceil(np.log2(bf))))
            rf_rounded = int(pow(2, np.ceil(np.log2(rf))))

            node.weights['weight'].data = np.transpose(node.weights['weight'].data).flatten()

            if(node.attributes['n_in']*node.attributes['n_out'] > 2048 and rf_rounded != rf):
                node.set_attr('rfpad', rf_rounded-rf)
                node.set_attr('bfpad', bf_rounded-bf)

                temp = np.empty([bf_rounded, rf_rounded])
                for i in range(rf_rounded):
                    for j in range(bf_rounded):
                        if (i < rf and j < bf):
                            w_index = i + rf * j
                            temp[j][i] = node.weights['weight'].data[w_index]
                        else:
                            temp[j][i] = 0
                node.weights['weight'].data = temp.flatten()
                node.weights['weight'].data_length = node.weights['weight'].data.size

        elif isinstance(node, GRU):
            node.weights['weight'].data = np.transpose(node.weights['weight'].data)
            node.weights['recurrent_weight'].data = np.transpose(node.weights['recurrent_weight'].data)

        else:
            raise Exception('Unexpected layer {} with resource strategy'.format(node.class_name))

        node.set_attr('_weights_transposed', True)
        return False
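As a standalone illustration (not from the commit) of the Dense branch above, the following sketch applies the same transpose-and-pad reordering to a made-up weight matrix with n_in = n_out = 48 and reuse_factor = 24, which triggers the padding path because 48 * 48 = 2304 > 2048 and 24 is not a power of two:

import numpy as np

n_in, n_out, rf = 48, 48, 24
weights = np.random.rand(n_in, n_out)            # made-up weights

bf = (n_in * n_out) // rf                        # 96
bf_rounded = int(2 ** np.ceil(np.log2(bf)))      # 128
rf_rounded = int(2 ** np.ceil(np.log2(rf)))      # 32, so rfpad = 8 and bfpad = 32

flat = np.transpose(weights).flatten()

temp = np.zeros((bf_rounded, rf_rounded))        # padded entries stay zero
for i in range(rf):
    for j in range(bf):
        temp[j][i] = flat[i + rf * j]            # same reordering as in the pass

padded = temp.flatten()
print(padded.size)                               # 4096 = bf_rounded * rf_rounded entries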

hls4ml/backends/quartus/quartus_backend.py

Lines changed: 38 additions & 38 deletions
@@ -1,20 +1,14 @@
-import numpy as np
-import math
 import os
-import copy
-import webbrowser
-from calmjs.parse import es5
-from calmjs.parse import asttypes
-from tabulate import tabulate
-from ast import literal_eval
+from hls4ml.model.attributes import Attribute
+import numpy as np
 from contextlib import contextmanager
 
+from hls4ml.backends import FPGABackend
 from hls4ml.model.types import NamedType, IntegerPrecisionType, FixedPrecisionType
-from hls4ml.model.layers import Embedding, Layer, Dense, BatchNormalization, Activation, ParametrizedActivation, PReLU, Softmax
-from hls4ml.model.optimizer import get_backend_passes, layer_optimizer, model_optimizer
+from hls4ml.model.layers import Embedding, Layer, Dense, Activation, Softmax, GRU
 from hls4ml.model.flow import register_flow
-from hls4ml.backends import FPGABackend
 from hls4ml.report import parse_quartus_report
+from hls4ml.model.optimizer import get_backend_passes, layer_optimizer
 
 @contextmanager
 def chdir(newdir):
@@ -28,14 +22,23 @@ def chdir(newdir):
 class QuartusBackend(FPGABackend):
     def __init__(self):
         super(QuartusBackend, self).__init__('Quartus')
+        self._register_layer_attributes()
         self._register_flows()
 
+    def _register_layer_attributes(self):
+        extended_attrs = {
+            GRU: [Attribute('recurrent_reuse_factor', default=1)],
+        }
+        self.attribute_map.update(extended_attrs)
+
+
     def _register_flows(self):
         initializers = self._get_layer_initializers()
         init_flow = register_flow('init_layers', initializers, requires=['optimize'], backend=self.name)
 
         quartus_types = [
             'quartus:transform_types',
+            'quartus:apply_resource_strategy'
         ]
         quartus_types_flow = register_flow('specific_types', quartus_types, requires=[init_flow], backend=self.name)
 
@@ -90,31 +93,6 @@ def create_initial_config(self, part='Arria10', clock_period=5, io_type='io_para
 
         return config
 
-    def gen_quartus_weight_array(self, layer):
-        rf = layer.get_attr('reuse_factor')
-        block_factor = int((layer.attributes['n_in']*layer.attributes['n_out'])/rf)
-        bf_rounded = int(pow(2, np.ceil(np.log2(block_factor))))
-        rf_rounded = int(pow(2, np.ceil(np.log2(rf))))
-
-        layer.weights['weight'].data = np.transpose(layer.weights['weight'].data).flatten()
-
-        if(layer.attributes['n_in']*layer.attributes['n_out'] > 2048 and rf_rounded != rf):
-            layer.set_attr('rfpad', rf_rounded-rf)
-            layer.set_attr('bfpad', bf_rounded-block_factor)
-
-        temp = np.empty([bf_rounded, rf_rounded])
-        for i in range(rf_rounded):
-            for j in range(bf_rounded):
-                if (i < rf and j < block_factor):
-                    w_index = i + rf * j
-                    temp[j][i] = layer.weights['weight'].data[w_index]
-                else:
-                    temp[j][i] = 0
-        layer.weights['weight'].data = temp.flatten()
-
-        layer.weights['weight'].data_length = layer.weights['weight'].data.size
-        return
-
     def build(self, model, synth=True, fpgasynth=False):
         """
         Builds the project using Intel HLS compiler.
@@ -167,7 +145,6 @@ def init_dense(self, layer):
         else:
             n_in, n_out = self.get_layer_mult_size(layer)
             self.set_closest_reuse_factor(layer, n_in, n_out)
-            self.gen_quartus_weight_array(layer)
             layer.set_attr('strategy', 'resource')
 
         if layer.model.config.is_resource_strategy(layer):
@@ -198,4 +175,27 @@ def init_softmax(self, layer):
     @layer_optimizer(Embedding)
     def init_embed(self, layer):
         if layer.attributes['n_in'] is None:
-            raise Exception('Input length of Embedding layer must be specified.')
+            raise Exception('Input length of Embedding layer must be specified.')
+
+    @layer_optimizer(GRU)
+    def init_gru(self, layer):
+        reuse_factor = layer.model.config.get_reuse_factor(layer)
+        layer.set_attr('recurrent_reuse_factor', reuse_factor)
+
+        # Dense multiplication properties
+        layer.set_attr('rfpad', 0)
+        layer.set_attr('bfpad', 0)
+
+        index_t = IntegerPrecisionType(width=1, signed=False)
+
+        if 'table_t' not in layer.attributes:
+            layer.set_attr('table_t', FixedPrecisionType(width=18, integer=8))
+        if 'table_size' not in layer.attributes:
+            layer.set_attr('table_size', 1024)
+        if True: # layer.model.config.is_resource_strategy(layer): ... Quartus only supports Dense resource multiplication
+            n_in, n_out, n_in_recr, n_out_recr = self.get_layer_mult_size(layer)
+            self.set_closest_reuse_factor(layer, n_in, n_out)
+            self.set_closest_reuse_factor(layer, n_in_recr, n_out_recr, attribute='recurrent_reuse_factor')
+            layer.set_attr('strategy', 'resource')
+
+        layer.set_attr('index_t', index_t)
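As a usage sketch (not part of the commit, assuming the standard hls4ml Keras conversion helpers), a GRU model can now be pointed at the Quartus backend roughly like this:

# Hypothetical end-to-end example: convert a small Keras GRU model with backend='Quartus'
import hls4ml
from tensorflow.keras.layers import GRU, Input
from tensorflow.keras.models import Model

inputs = Input(shape=(8, 16))                    # (n_timesteps, n_in)
outputs = GRU(32, activation='tanh', recurrent_activation='sigmoid')(inputs)
keras_model = Model(inputs, outputs)

config = hls4ml.utils.config_from_keras_model(keras_model, granularity='name')
hls_model = hls4ml.converters.convert_from_keras_model(
    keras_model,
    hls_config=config,
    backend='Quartus',
    output_dir='quartus_gru_prj',
)
hls_model.compile()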
