Use correct number of args for multiple outputs #487

Merged
merged 3 commits on Jun 19, 2022

28 changes: 17 additions & 11 deletions hls4ml/model/graph.py
@@ -591,6 +591,7 @@ def _get_top_function(self, x):
            xlist = [x]
        else:
            xlist = x
+        n_outputs = len(self.get_output_variables())

        for xi in xlist:
            if not isinstance(xi, np.ndarray):
@@ -610,7 +611,7 @@ def _get_top_function(self, x):


        top_function.restype = None
-        top_function.argtypes = [npc.ndpointer(ctype, flags="C_CONTIGUOUS") for i in range(len(xlist)+1)]
+        top_function.argtypes = [npc.ndpointer(ctype, flags="C_CONTIGUOUS") for i in range(len(xlist) + n_outputs)]

        return top_function, ctype
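
For context: the compiled bridge function is called through ctypes with one C-contiguous array per model input followed by one per model output, so argtypes now needs len(xlist) + n_outputs entries instead of the old len(xlist) + 1, which silently assumed a single output. A rough sketch of that calling convention follows; the library path, symbol name, and sizes are purely illustrative and not taken from this diff:

import ctypes
import numpy as np
import numpy.ctypeslib as npc

# Illustrative only: a bridge library produced by hls_model.compile() for a
# model with one 10-element input and two outputs of sizes 5 and 2.
lib = ctypes.cdll.LoadLibrary('firmware/myproject-abc123.so')  # hypothetical path
top = lib.myproject_float                                      # hypothetical symbol name
top.restype = None
# One pointer argument per input plus one per output: 1 + 2 here.
top.argtypes = [npc.ndpointer(ctypes.c_float, flags="C_CONTIGUOUS") for _ in range(1 + 2)]

x = np.zeros(10, dtype=np.float32)
y1 = np.zeros(5, dtype=np.float32)
y2 = np.zeros(2, dtype=np.float32)
top(x, y1, y2)  # inputs first, then one output buffer per model output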

@@ -637,6 +638,7 @@ def predict(self, x):
        top_function, ctype = self._get_top_function(x)
        n_samples = self._compute_n_samples(x)
        n_inputs = len(self.get_input_variables())
+        n_outputs = len(self.get_output_variables())

        curr_dir = os.getcwd()
        os.chdir(self.config.get_output_dir() + '/firmware')
@@ -647,25 +649,29 @@ def predict(self, x):

        try:
            for i in range(n_samples):
-                predictions = np.zeros(self.get_output_variables()[0].size(), dtype=ctype)
+                predictions = [np.zeros(yj.size(), dtype=ctype) for yj in self.get_output_variables()]
                if n_inputs == 1:
-                    top_function(x[i], predictions)
+                    inp = [x[i]]
                else:
                    inp = [xj[i] for xj in x]
-                    argtuple = inp
-                    argtuple += [predictions]
-                    argtuple = tuple(argtuple)
-                    top_function(*argtuple)
+                argtuple = inp
+                argtuple += predictions
+                argtuple = tuple(argtuple)
+                top_function(*argtuple)
                output.append(predictions)

-            #Convert to numpy array
-            output = np.asarray(output)
+            # Convert to list of numpy arrays (one for each output)
+            output = [np.asarray([output[i_sample][i_output] for i_sample in range(n_samples)]) for i_output in range(n_outputs)]
        finally:
            os.chdir(curr_dir)

-        if n_samples == 1:
+        if n_samples == 1 and n_outputs == 1:
+            return output[0][0]
+        elif n_outputs == 1:
            return output[0]
+        elif n_samples == 1:
+            return [output_i[0] for output_i in output]
        else:
            return output

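In short, after this change ModelGraph.predict returns a single array for a single-output model (as before) and a list of arrays, one per output, for a multi-output model; the leading sample dimension is dropped when only one sample is passed. A minimal usage sketch, assuming hls_model is an already compiled two-output ModelGraph like the one built in the new test below (names and shapes are illustrative):

import numpy as np

X = np.random.rand(32, 10)    # 32 samples, 10 features
y_hls = hls_model.predict(X)  # hls_model: hypothetical compiled ModelGraph with outputs of sizes 5 and 2
# With more than one output, predict() returns a list with one array per output:
#   y_hls[0].shape == (32, 5)
#   y_hls[1].shape == (32, 2)

y_one = hls_model.predict(np.random.rand(1, 10))
# A single sample still yields one array per output, now without the batch axis:
#   y_one[0].shape == (5,) and y_one[1].shape == (2,)
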
28 changes: 28 additions & 0 deletions test/pytest/test_graph.py
@@ -175,3 +175,31 @@ def test_broadcast_stream(shapes, layer):
    y = model.predict([X1, X2])
    y_hls = hls_model.predict([X1, X2]).reshape(y.shape)
    np.testing.assert_allclose(y, y_hls, rtol=0)
+
+@pytest.mark.parametrize('batch', [1, 32])
+def test_multiple_outputs(batch):
+    ''' Test case for multiple outputs '''
+    input1 = tf.keras.layers.Input(shape=(10,))
+    inputs = [input1]
+    output1 = tf.keras.layers.Dense(5, kernel_initializer='ones', use_bias=False)(input1)
+    output2 = tf.keras.layers.Dense(2, kernel_initializer='ones', use_bias=False)(input1)
+    outputs = [output1, output2]
+    model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
+
+    # create the ModelGraph
+    config = hls4ml.utils.config_from_keras_model(model, granularity='model', default_precision='ap_fixed<32,16>')
+    odir = str(test_root_path / 'hls4mlprj_graph_multiple_outputs')
+    hls_model = hls4ml.converters.convert_from_keras_model(model,
+                                                           output_dir=odir,
+                                                           backend='Vivado',
+                                                           io_type='io_serial',
+                                                           hls_config=config)
+    hls_model.compile()
+
+    # Test with integers (for exact agreement)
+    X1 = np.random.randint(0, 100, size=(batch, 10)).astype(float)
+    y = model.predict(X1)
+    y_hls = hls_model.predict(X1)
+    for y_i, y_hls_i in zip(y, y_hls):
+        y_hls_i = y_hls_i.reshape(y_i.shape)
+        np.testing.assert_allclose(y_i, y_hls_i, rtol=0)
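
The new case can be run on its own with pytest, for example:

pytest test/pytest/test_graph.py -k test_multiple_outputs

Running it end to end assumes a working C++ compiler for hls_model.compile().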