Skip to content

Commit 1c302ad

Browse files
apfuscojmduarte
authored andcommitted
predict: Fix number of args for multiple outputs
1 parent d188120 commit 1c302ad

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

hls4ml/model/graph.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,7 @@ def _get_top_function(self, x):
591591
xlist = [x]
592592
else:
593593
xlist = x
594+
n_outputs = len(self.get_output_variables())
594595

595596
for xi in xlist:
596597
if not isinstance(xi, np.ndarray):
@@ -610,7 +611,7 @@ def _get_top_function(self, x):
610611

611612

612613
top_function.restype = None
613-
top_function.argtypes = [npc.ndpointer(ctype, flags="C_CONTIGUOUS") for i in range(len(xlist)+1)]
614+
top_function.argtypes = [npc.ndpointer(ctype, flags="C_CONTIGUOUS") for i in range(len(xlist) + n_outputs)]
614615

615616
return top_function, ctype
616617

@@ -637,6 +638,7 @@ def predict(self, x):
637638
top_function, ctype = self._get_top_function(x)
638639
n_samples = self._compute_n_samples(x)
639640
n_inputs = len(self.get_input_variables())
641+
n_outputs = len(self.get_output_variables())
640642

641643
curr_dir = os.getcwd()
642644
os.chdir(self.config.get_output_dir() + '/firmware')
@@ -647,20 +649,20 @@ def predict(self, x):
647649

648650
try:
649651
for i in range(n_samples):
650-
predictions = np.zeros(self.get_output_variables()[0].size(), dtype=ctype)
652+
predictions = [np.zeros(yj.size(), dtype=ctype) for yj in self.get_output_variables()]
651653
if n_inputs == 1:
652-
top_function(x[i], predictions)
654+
inp = [x[i]]
653655
else:
654656
inp = [xj[i] for xj in x]
655-
argtuple = inp
656-
argtuple += [predictions]
657-
argtuple = tuple(argtuple)
658-
top_function(*argtuple)
657+
argtuple = inp
658+
argtuple += predictions
659+
argtuple = tuple(argtuple)
660+
top_function(*argtuple)
659661
output.append(predictions)
660662

661663

662664
#Convert to numpy array
663-
output = np.asarray(output)
665+
output = np.asarray(output, dtype=object)
664666
finally:
665667
os.chdir(curr_dir)
666668

0 commit comments

Comments
 (0)