Skip to content

Commit 93e759c

Browse files
authored
Merge pull request #934 from calad0i/replace_node_improvment
better repalce_node fn
2 parents 7916ff5 + 40d5461 commit 93e759c

File tree

3 files changed

+89
-34
lines changed

3 files changed

+89
-34
lines changed

hls4ml/model/graph.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -577,13 +577,24 @@ def replace_node(self, old_node, new_node):
577577
new_node (Layer): The new node
578578
579579
"""
580-
prev_node = self.graph.get(old_node.inputs[0])
581-
next_node = next((x for x in self.graph.values() if x.inputs[0] == old_node.outputs[0]), None)
582-
if next_node is not None:
583-
next_node.inputs[0] = new_node.outputs[0]
584-
if prev_node is not None:
585-
if new_node.inputs is None or len(new_node.inputs) == 0: # Check if already rewired
586-
new_node.inputs = [prev_node.outputs[0]]
580+
581+
# fmt: off
582+
assert len(new_node.inputs) == len(old_node.inputs), \
583+
f'{new_node.name} and {old_node.name} have different number of inputs'
584+
assert len(new_node.outputs) == len(old_node.outputs), \
585+
f'{new_node.name} and {old_node.name} have different number of outputs'
586+
# fmt: on
587+
588+
repl = {old_name: new_name for old_name, new_name in zip(old_node.outputs, new_node.outputs)}
589+
repl.update({old_name: new_name for old_name, new_name in zip(old_node.inputs, new_node.inputs)})
590+
591+
for node in self.graph.values():
592+
for i, n in enumerate(node.inputs):
593+
if n in repl:
594+
node.inputs[i] = repl[n]
595+
for i, n in enumerate(node.outputs):
596+
if n in repl:
597+
node.outputs[i] = repl[n]
587598

588599
self.graph = OrderedDict((new_node.name, new_node) if k == old_node.name else (k, v) for k, v in self.graph.items())
589600
self._update_model_outputs()
@@ -648,7 +659,9 @@ def compile(self):
648659
Users should call this function if they want to use `predict` functionality for simulation.
649660
"""
650661
self.write()
662+
self._compile()
651663

664+
def _compile(self):
652665
lib_name = self.config.backend.compile(self)
653666
if self._top_function_lib is not None:
654667
if platform.system() == "Linux":

test/pytest/test_repack_precision.py

Lines changed: 0 additions & 27 deletions
This file was deleted.

test/pytest/test_repack_stream.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
from pathlib import Path
2+
3+
import numpy as np
4+
import pytest
5+
from tensorflow import keras
6+
7+
from hls4ml.converters import convert_from_keras_model
8+
9+
test_root_path = Path(__file__).parent
10+
11+
12+
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])
13+
def test_repack_precision(backend: str):
14+
inp = keras.Input(shape=(3, 3), name='inp')
15+
out = keras.layers.Reshape((3, 3), name='reshape')(inp)
16+
out = keras.layers.Conv1D(2, 2, name='conv')(out)
17+
model = keras.Model(inp, out)
18+
19+
layer_conf = {
20+
'inp': {'Precision': 'fixed<20,10>'},
21+
'reshape': {'Precision': 'fixed<20,10>'},
22+
'conv': {'Precision': 'fixed<20,10>'},
23+
}
24+
25+
hls_config = {'Model': {'Precision': 'fixed<2,1>', 'ReuseFactor': 1}, 'LayerName': layer_conf}
26+
27+
# Repack only happens in io_stream
28+
model_hls = convert_from_keras_model(
29+
model,
30+
backend=backend,
31+
output_dir=str(test_root_path / f'hls4mlprj_repack_precision_{backend}'),
32+
hls_config=hls_config,
33+
io_type='io_stream',
34+
)
35+
model_hls.write() # Not needed for this test, but useful for debugging
36+
assert 'repack_reshape' in model_hls.graph, 'repack_reshape not found in graph'
37+
repack_precision = model_hls.graph['repack_reshape'].attributes['result_t'].precision
38+
assert repack_precision.integer == 10, 'Precision mismatch'
39+
assert repack_precision.fractional == 10, 'Precision mismatch'
40+
assert repack_precision.width == 20, 'Precision mismatch'
41+
assert repack_precision.signed is True, 'Precision mismatch'
42+
43+
44+
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])
45+
@pytest.mark.parametrize('strategy', ['Latency', 'Resource'])
46+
def test_repack(backend: str, strategy: str):
47+
inp1 = keras.Input(shape=(4,), name='inp1')
48+
inp2 = keras.Input(shape=(4,), name='inp2')
49+
r1 = keras.layers.Reshape((2, 2), name='reshape1')(inp1)
50+
r2 = keras.layers.Reshape((2, 2), name='reshape2')(inp2)
51+
out = keras.layers.Concatenate(name='concat')([r1, r2])
52+
model = keras.Model([inp1, inp2], out)
53+
54+
hls_config = {'Model': {'Precision': 'ap_ufixed<8,8>', 'ReuseFactor': 1}, 'Strategy': strategy}
55+
model_hls = convert_from_keras_model(
56+
model,
57+
io_type='io_stream',
58+
backend=backend,
59+
hls_config=hls_config,
60+
output_dir=str(test_root_path / f'hls4mlprj_repack_{backend}_{strategy}'),
61+
)
62+
model_hls.compile()
63+
inp_data = [
64+
np.random.randint(0, 2**8, (100, 4)).astype(np.float32),
65+
np.random.randint(0, 2**8, (100, 4)).astype(np.float32),
66+
]
67+
out_target = np.concatenate([inp_data[0].reshape(100, 2, 2), inp_data[1].reshape(100, 2, 2)], axis=-1)
68+
out_data: np.ndarray = model_hls.predict(inp_data) # type: ignore
69+
assert np.all(out_data.reshape(out_target.shape) == out_target), 'Concatenate failed: mismatching output'

0 commit comments

Comments
 (0)