Skip to content

Quartus multi out with stream fix #908

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions hls4ml/backends/fpga/passes/clone.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,19 @@ def initialize(self):
class CloneFunctionTemplate(FunctionCallTemplate):
def __init__(self):
super().__init__(Clone, include_header=clone_include_list)
self.template = None # to be filled once number of clones known

def format(self, node):
params = self._default_function_params(node)
for i, _output in enumerate(node.outputs):
params['output' + str(i + 1)] = node.variables[node.outputs[i]].name

if self.template is None:
self.template = (
'nnet::clone_stream<{input_t}, {output_t}, {size}>({input}, '
+ ', '.join(['{output' + str(i + 1) + '}' for i in range(len(node.outputs))])
+ ');'
)
template = (
'nnet::clone_stream<{input_t}, {output_t}, {size}>({input}, '
+ ', '.join(['{output' + str(i + 1) + '}' for i in range(len(node.outputs))])
+ ');'
)

return self.template.format(**params)
return template.format(**params)


def register_clone(backend):
Expand Down
4 changes: 2 additions & 2 deletions hls4ml/model/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ def make_node(self, kind, name, attributes, inputs, outputs=None):
node = layer_cls(self, name, attributes, inputs, outputs)
for o in node.outputs:
out_var = node.get_output_variable(output_name=o)
if o in self.outputs:
if len(self.outputs) == 1 and o in self.outputs:
out_var.type.name = 'result_t'
self.output_vars[o] = out_var
return node
Expand Down Expand Up @@ -608,7 +608,7 @@ def get_input_variables(self):
return variables

def register_output_variable(self, out_name, variable):
if out_name in self.outputs:
if len(self.outputs) == 1 and out_name in self.outputs:
variable.type.name = 'result_t'
self.output_vars[out_name] = variable

Expand Down
14 changes: 9 additions & 5 deletions hls4ml/writer/quartus_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ def write_project_cpp(self, model):
for inp in model_inputs:
newline += indent + f'stream_in<{inp.type.name}> &{inp.name}_stream,\n'
for out in model_outputs:
newline += indent + f'stream_out<{out.type.name}> &{out.name}_stream'
newline += indent + f'stream_out<{out.type.name}> &{out.name}_stream,\n'
newline = newline[:-2] # Remove the tailing ',\n'
if model_brams:
newline += ',\n' + brams_str
newline += '\n) {\n'
Expand All @@ -191,7 +192,8 @@ def write_project_cpp(self, model):
for inp in model_inputs:
newline += indent + f'stream_in<{inp.type.name}> &{inp.name}_stream,\n'
for out in model_outputs:
newline += indent + f'stream_out<{out.type.name}> &{out.name}_stream'
newline += indent + f'stream_out<{out.type.name}> &{out.name}_stream,\n'
newline = newline[:-2] # Remove the tailing ',\n'\
if model_brams:
newline += ',\n' + brams_str
newline += '\n) {\n'
Expand Down Expand Up @@ -277,7 +279,7 @@ def write_project_cpp(self, model):
newline += indent + f' {out.type.name} tmp = {out.name}.read();\n'
newline += indent + f' {out.name}_stream.write(tmp);\n'
newline += indent + '}\n'
newline += '}\n'
newline += '}\n'
else:
newline = line
newline += indent + 'return outputs;\n'
Expand Down Expand Up @@ -330,7 +332,8 @@ def write_project_header(self, model):
for inp in model_inputs:
newline += indent + f'stream_in<{inp.type.name}> &{inp.name}_stream,\n'
for out in model_outputs:
newline += indent + f'stream_out<{out.type.name}> &{out.name}_stream'
newline += indent + f'stream_out<{out.type.name}> &{out.name}_stream,\n'
newline = newline[:-2] # Remove the tailing ',\n'
if model_brams:
newline += ',\n' + brams_str
newline += '\n);\n'
Expand All @@ -350,7 +353,8 @@ def write_project_header(self, model):
for inp in model_inputs:
newline += indent + f'stream_in<{inp.type.name}> &{inp.name}_stream,\n'
for out in model_outputs:
newline += indent + f'stream_out<{out.type.name}> &{out.name}_stream'
newline += indent + f'stream_out<{out.type.name}> &{out.name}_stream,\n'
newline = newline[:-2] # Remove the tailing ',\n'
if model_brams:
newline += ',\n' + brams_str
newline += '\n);\n'
Expand Down
52 changes: 52 additions & 0 deletions test/pytest/test_multiout_network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from pathlib import Path

import numpy as np
import pytest
from keras.layers import Dense
from tensorflow import keras

from hls4ml.converters import convert_from_keras_model

test_root_path = Path(__file__).parent


@pytest.fixture(scope='module')
def model():
inp = keras.Input(shape=(10,))
x = Dense(10, name='dense1')(inp)
y = Dense(10, name='dense2')(inp)
model = keras.Model(inp, [x, y])
return model


@pytest.fixture(scope='module')
def data():
X = np.random.normal(0, 1, (1000, 10))
X = np.clip(X, -16, 15)
return X


@pytest.mark.parametrize('backend', ['Vivado', 'Quartus', 'Vitis'])
@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
def test_multi_clone(model, data, backend: str, io_type: str):
output_dir = str(test_root_path / f'hls4mlprj_multiout_network_{backend}_{io_type}')
hls_config = {'Model': {'Precision': 'fixed<32,5>', 'ReuseFactor': 1}}
layer_config = {
'dense1': {'Precision': {'result': 'fixed<35,5>'}},
'dense2': {'Precision': {'result': 'fixed<40,5>'}},
'dense1_linear': {'Precision': {'result': 'fixed<35,5>'}},
'dense2_linear': {'Precision': {'result': 'fixed<40,5>'}},
}
hls_config['LayerName'] = layer_config
model_hls = convert_from_keras_model(
model, backend=backend, output_dir=output_dir, hls_config=hls_config, io_type=io_type
)

assert model_hls.graph['dense1'].attributes['result_t'] != model_hls.graph['dense2'].attributes['result_t']

model_hls.compile()
r_hls = model_hls.predict(data)
r_keras = [x.numpy() for x in model(data)]

assert np.allclose(r_hls[0], r_keras[0], atol=1e-5, rtol=0)
assert np.allclose(r_hls[1], r_keras[1], atol=1e-5, rtol=0)
48 changes: 48 additions & 0 deletions test/pytest/test_stream_multi_clone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from pathlib import Path

import numpy as np
import pytest
from keras.layers import Add, Dense
from tensorflow import keras

from hls4ml.converters import convert_from_keras_model

test_root_path = Path(__file__).parent


@pytest.fixture(scope='module')
def model():
inp = keras.Input(shape=(10,))
x = Dense(10)(inp)
y = Dense(10)(inp)
z = Dense(10)(inp)
xy = Add()([x, y]) # 5
xy = Add()([xy, y]) # 5
out = Add()([xy, z]) # 5
model = keras.Model(inp, out)
return model


@pytest.fixture(scope='module')
def data():
X = np.random.normal(0, 1, (1000, 10))
X = np.clip(X, -16, 15)
return X


@pytest.mark.parametrize('backend', ['Vivado', 'Quartus', 'Vitis'])
def test_multi_clone(model, data, backend: str):
output_dir = str(test_root_path / f'hls4mlprj_stream_multi_clone_{backend}')
hls_config = {'Model': {'Precision': 'fixed<32,5>', 'ReuseFactor': 1}}
model_hls = convert_from_keras_model(
model,
backend=backend,
output_dir=output_dir,
hls_config=hls_config,
io_type='io_stream', # clone only happens with stream io.
)
model_hls.compile()
r_hls = model_hls.predict(data)
r_keras = model(data).numpy()

assert np.allclose(r_hls, r_keras, atol=1e-5, rtol=0)