4
4
import pynq .lib .dma
5
5
import numpy as np
6
6
7
+
7
8
class NeuralNetworkOverlay (Overlay ):
8
- def __init__ (self , bitfile_name , dtbo = None , download = True , ignore_version = False , device = None ):
9
-
10
- super ().__init__ (bitfile_name , dtbo = dtbo , download = download , ignore_version = ignore_version , device = device )
11
-
9
def __init__(self, bitfile_name, x_shape, y_shape, dtype=np.float32, dtbo=None, download=True,
             ignore_version=False, device=None):
    """
    Load the overlay and prepare the DMA channels and the I/O buffers.

    Parameters:
    - bitfile_name : bitstream to load; forwarded to the `Overlay` constructor.
    - x_shape : shape of the input buffer allocated for the accelerator.
    - y_shape : shape of the output buffer allocated for the accelerator.
    - dtype : data type of the elements of both buffers (default `np.float32`).
    - dtbo, download, ignore_version, device : forwarded unchanged to the `Overlay` constructor.
    """
    # Forward the caller's options: hard-coding the defaults here would silently
    # ignore whatever dtbo/download/ignore_version/device the caller asked for.
    super().__init__(bitfile_name, dtbo=dtbo, download=download, ignore_version=ignore_version, device=device)
    # Keep short-hand references to the two DMA channels used by `predict`.
    dma = self.hier_0.axi_dma_0
    self.sendchannel = dma.sendchannel
    self.recvchannel = dma.recvchannel
    # Buffers are allocated once here and reused on every `predict` call.
    self.input_buffer = allocate(shape=x_shape, dtype=dtype)
    self.output_buffer = allocate(shape=y_shape, dtype=dtype)
16
+
12
17
def _print_dt (self , timea , timeb , N ):
13
- dt = (timeb - timea )
14
- dts = dt .seconds + dt .microseconds * 10 ** - 6
18
+ dt = (timeb - timea )
19
+ dts = dt .seconds + dt .microseconds * 10 ** - 6
15
20
rate = N / dts
16
21
print ("Classified {} samples in {} seconds ({} inferences / s)" .format (N , dts , rate ))
17
22
return dts , rate
18
- def predict (self , X , y_shape , dtype = np .float32 , debug = None , profile = False , encode = None , decode = None ):
23
+
24
+ def predict (self , X , debug = False , profile = False , encode = None , decode = None ):
19
25
"""
20
26
Obtain the predictions of the NN implemented in the FPGA.
21
27
Parameters:
22
28
- X : the input vector. Should be numpy ndarray.
23
- - y_shape : the shape of the output vector. Needed to the accelerator to set the TLAST bit properly and
24
- for sizing the output vector shape.
25
- - dtype : the data type of the elements of the input/output vectors.
26
- Note: it should be set depending on the interface of the accelerator; if it uses 'float'
27
- types for the 'data' AXI-Stream field, 'np.float32' dtype is the correct one to use.
29
+ - dtype : the data type of the elements of the input/output vectors.
30
+ Note: it should be set depending on the interface of the accelerator; if it uses 'float'
31
+ types for the 'data' AXI-Stream field, 'np.float32' dtype is the correct one to use.
28
32
Instead if it uses 'ap_fixed<A,B>', 'np.intA' is the correct one to use (note that A cannot
29
- any integer value, but it can assume {..., 8, 16, 32, ...} values. Check `numpy`
33
+ be any integer value; it can only assume the values {..., 8, 16, 32, ...}. Check the `numpy`
30
34
doc for more info).
31
- In this case the encoding/decoding has to be computed by the PS. For example for
32
- 'ap_fixed<16,6>' type the following 2 functions are the correct one to use for encode/decode
35
+ In this case the encoding/decoding has to be computed by the PS. For example for
36
+ 'ap_fixed<16,6>' type the following 2 functions are the correct ones to use for encode/decode
33
37
'float' -> 'ap_fixed<16,6>':
34
38
```
35
39
def encode(xi):
@@ -48,24 +52,24 @@ def decode(yi):
48
52
timea = datetime .now ()
49
53
if encode is not None :
50
54
X = encode (X )
51
- with allocate (shape = X .shape , dtype = dtype ) as input_buffer , \
52
- allocate (shape = y_shape , dtype = dtype ) as output_buffer :
53
- input_buffer [:] = X
54
- self .hier_0 .axi_dma_0 .sendchannel .transfer (input_buffer )
55
- self .hier_0 .axi_dma_0 .recvchannel .transfer (output_buffer )
56
- if debug :
57
- print ("Transfer OK" )
58
- self .hier_0 .axi_dma_0 .sendchannel .wait ()
59
- if debug :
60
- print ("Send OK" )
61
- self .hier_0 .axi_dma_0 .recvchannel .wait ()
62
- if debug :
63
- print ("Receive OK" )
64
- result = output_buffer .copy ()
55
+ self .input_buffer [:] = X
56
+ self .sendchannel .transfer (self .input_buffer )
57
+ self .recvchannel .transfer (self .output_buffer )
58
+ if debug :
59
+ print ("Transfer OK" )
60
+ self .sendchannel .wait ()
61
+ if debug :
62
+ print ("Send OK" )
63
+ self .recvchannel .wait ()
64
+ if debug :
65
+ print ("Receive OK" )
66
+ # result = self.output_buffer.copy()
65
67
if decode is not None :
66
- result = decode (result )
68
+ self .output_buffer = decode (self .output_buffer )
69
+
67
70
if profile :
68
71
timeb = datetime .now ()
69
72
dts , rate = self ._print_dt (timea , timeb , len (X ))
70
- return result , dts , rate
71
- return result
73
+ return self .output_buffer , dts , rate
74
+ else :
75
+ return self .output_buffer
0 commit comments