@@ -591,6 +591,7 @@ def _get_top_function(self, x):
591
591
xlist = [x ]
592
592
else :
593
593
xlist = x
594
+ n_outputs = len (self .get_output_variables ())
594
595
595
596
for xi in xlist :
596
597
if not isinstance (xi , np .ndarray ):
@@ -610,7 +611,7 @@ def _get_top_function(self, x):
610
611
611
612
612
613
top_function .restype = None
613
- top_function .argtypes = [npc .ndpointer (ctype , flags = "C_CONTIGUOUS" ) for i in range (len (xlist )+ 1 )]
614
+ top_function .argtypes = [npc .ndpointer (ctype , flags = "C_CONTIGUOUS" ) for i in range (len (xlist ) + n_outputs )]
614
615
615
616
return top_function , ctype
616
617
@@ -637,6 +638,7 @@ def predict(self, x):
637
638
top_function , ctype = self ._get_top_function (x )
638
639
n_samples = self ._compute_n_samples (x )
639
640
n_inputs = len (self .get_input_variables ())
641
+ n_outputs = len (self .get_output_variables ())
640
642
641
643
curr_dir = os .getcwd ()
642
644
os .chdir (self .config .get_output_dir () + '/firmware' )
@@ -647,25 +649,29 @@ def predict(self, x):
647
649
648
650
try :
649
651
for i in range (n_samples ):
650
- predictions = np .zeros (self . get_output_variables ()[ 0 ]. size (), dtype = ctype )
652
+ predictions = [ np .zeros (yj . size (), dtype = ctype ) for yj in self . get_output_variables ()]
651
653
if n_inputs == 1 :
652
- top_function ( x [i ], predictions )
654
+ inp = [ x [i ]]
653
655
else :
654
656
inp = [xj [i ] for xj in x ]
655
- argtuple = inp
656
- argtuple += [ predictions ]
657
- argtuple = tuple (argtuple )
658
- top_function (* argtuple )
657
+ argtuple = inp
658
+ argtuple += predictions
659
+ argtuple = tuple (argtuple )
660
+ top_function (* argtuple )
659
661
output .append (predictions )
660
662
661
663
662
- #Convert to numpy array
663
- output = np .asarray (output )
664
+ # Convert to list of numpy arrays (one for each output)
665
+ output = [ np .asarray ([ output [ i_sample ][ i_output ] for i_sample in range ( n_samples )]) for i_output in range ( n_outputs )]
664
666
finally :
665
667
os .chdir (curr_dir )
666
-
667
- if n_samples == 1 :
668
+
669
+ if n_samples == 1 and n_outputs == 1 :
670
+ return output [0 ][0 ]
671
+ elif n_outputs == 1 :
668
672
return output [0 ]
673
+ elif n_samples == 1 :
674
+ return [output_i [0 ] for output_i in output ]
669
675
else :
670
676
return output
671
677
0 commit comments