Skip to content

Commit 033d438

Browse files
authored
Merge pull request #907 from calad0i/fix_repack_precision
Let repack_stream optimizer inheirt original precision
2 parents 22a1054 + de20cfc commit 033d438

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

hls4ml/backends/fpga/passes/repack_stream.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ def transform(self, model, node):
5959

6060
# Insert new Repack node instead of Reshape
6161
repack_layer = model.make_node(Repack, 'repack_' + node.name, attrs, node.inputs.copy())
62+
# As result_t attribute is not honored by type conversion, set it manually here
63+
repack_layer.attributes[repack_layer.name].type = node.attributes[node.name].type
6264
model.replace_node(node, repack_layer)
6365

6466
return True

test/pytest/test_repack_precision.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from tensorflow import keras
2+
3+
from hls4ml.converters import convert_from_keras_model
4+
5+
6+
def test_repack_precision():
7+
inp = keras.Input(shape=(3, 3), name='inp')
8+
out = keras.layers.Reshape((3, 3), name='reshape')(inp)
9+
out = keras.layers.Conv1D(2, 2, name='conv')(out)
10+
model = keras.Model(inp, out)
11+
12+
layer_conf = {
13+
'inp': {'Precision': 'fixed<20,10>'},
14+
'reshape': {'Precision': 'fixed<20,10>'},
15+
'conv': {'Precision': 'fixed<20,10>'},
16+
}
17+
18+
hls_config = {'Model': {'Precision': 'fixed<2,1>', 'ReuseFactor': 1}, 'LayerName': layer_conf}
19+
20+
# Repack only happens in io_stream
21+
model_hls = convert_from_keras_model(model, hls_config=hls_config, io_type='io_stream')
22+
assert 'repack_reshape' in model_hls.graph, 'repack_reshape not found in graph'
23+
repack_precision = model_hls.graph['repack_reshape'].attributes['result_t'].precision
24+
assert repack_precision.integer == 10, 'Precision mismatch'
25+
assert repack_precision.fractional == 10, 'Precision mismatch'
26+
assert repack_precision.width == 20, 'Precision mismatch'
27+
assert repack_precision.signed is True, 'Precision mismatch'

0 commit comments

Comments
 (0)