Skip to content

Commit bcc1d28

Browse files
authored
Merge pull request fastmachinelearning#487 from apfusco/predict-multiple-outputs
Use correct number of args for multiple outputs
2 parents 37ae6cc + 792409f commit bcc1d28

File tree

2 files changed

+45
-11
lines changed

2 files changed

+45
-11
lines changed

hls4ml/model/graph.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,7 @@ def _get_top_function(self, x):
591591
xlist = [x]
592592
else:
593593
xlist = x
594+
n_outputs = len(self.get_output_variables())
594595

595596
for xi in xlist:
596597
if not isinstance(xi, np.ndarray):
@@ -610,7 +611,7 @@ def _get_top_function(self, x):
610611

611612

612613
top_function.restype = None
613-
top_function.argtypes = [npc.ndpointer(ctype, flags="C_CONTIGUOUS") for i in range(len(xlist)+1)]
614+
top_function.argtypes = [npc.ndpointer(ctype, flags="C_CONTIGUOUS") for i in range(len(xlist) + n_outputs)]
614615

615616
return top_function, ctype
616617

@@ -637,6 +638,7 @@ def predict(self, x):
637638
top_function, ctype = self._get_top_function(x)
638639
n_samples = self._compute_n_samples(x)
639640
n_inputs = len(self.get_input_variables())
641+
n_outputs = len(self.get_output_variables())
640642

641643
curr_dir = os.getcwd()
642644
os.chdir(self.config.get_output_dir() + '/firmware')
@@ -647,25 +649,29 @@ def predict(self, x):
647649

648650
try:
649651
for i in range(n_samples):
650-
predictions = np.zeros(self.get_output_variables()[0].size(), dtype=ctype)
652+
predictions = [np.zeros(yj.size(), dtype=ctype) for yj in self.get_output_variables()]
651653
if n_inputs == 1:
652-
top_function(x[i], predictions)
654+
inp = [x[i]]
653655
else:
654656
inp = [xj[i] for xj in x]
655-
argtuple = inp
656-
argtuple += [predictions]
657-
argtuple = tuple(argtuple)
658-
top_function(*argtuple)
657+
argtuple = inp
658+
argtuple += predictions
659+
argtuple = tuple(argtuple)
660+
top_function(*argtuple)
659661
output.append(predictions)
660662

661663

662-
#Convert to numpy array
663-
output = np.asarray(output)
664+
# Convert to list of numpy arrays (one for each output)
665+
output = [np.asarray([output[i_sample][i_output] for i_sample in range(n_samples)]) for i_output in range(n_outputs)]
664666
finally:
665667
os.chdir(curr_dir)
666-
667-
if n_samples == 1:
668+
669+
if n_samples == 1 and n_outputs == 1:
670+
return output[0][0]
671+
elif n_outputs == 1:
668672
return output[0]
673+
elif n_samples == 1:
674+
return [output_i[0] for output_i in output]
669675
else:
670676
return output
671677

test/pytest/test_graph.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,3 +175,31 @@ def test_broadcast_stream(shapes, layer):
175175
y = model.predict([X1, X2])
176176
y_hls = hls_model.predict([X1, X2]).reshape(y.shape)
177177
np.testing.assert_allclose(y, y_hls, rtol=0)
178+
179+
@pytest.mark.parametrize('batch', [1, 32])
180+
def test_multiple_outputs(batch):
181+
''' Test case for multple outputs '''
182+
input1 = tf.keras.layers.Input(shape=(10,))
183+
inputs = [input1]
184+
output1 = tf.keras.layers.Dense(5, kernel_initializer='ones', use_bias=False)(input1)
185+
output2 = tf.keras.layers.Dense(2, kernel_initializer='ones', use_bias=False)(input1)
186+
outputs = [output1, output2]
187+
model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
188+
189+
# create the ModelGraph
190+
config = hls4ml.utils.config_from_keras_model(model, granularity='model', default_precision='ap_fixed<32,16>')
191+
odir = str(test_root_path / 'hls4mlprj_graph_multiple_outputs')
192+
hls_model = hls4ml.converters.convert_from_keras_model(model,
193+
output_dir=odir,
194+
backend='Vivado',
195+
io_type='io_serial',
196+
hls_config=config)
197+
hls_model.compile()
198+
199+
# Test with integers (for exact agreement)
200+
X1 = np.random.randint(0, 100, size=(batch, 10)).astype(float)
201+
y = model.predict(X1)
202+
y_hls = hls_model.predict(X1)
203+
for y_i, y_hls_i in zip(y, y_hls):
204+
y_hls_i = y_hls_i.reshape(y_i.shape)
205+
np.testing.assert_allclose(y_i, y_hls_i, rtol=0)

0 commit comments

Comments
 (0)