16
16
17
17
class VitisUnrollCodeGen (UnrollCodeGenPass ):
18
18
def __init__ (self ):
19
- super ().__init__ ('Dense' , 'Conv1D' , 'Conv2D' )
19
+ super ().__init__ ('Dense' , 'Conv1D' , 'Conv2D' , 'PointwiseConv1D' , 'PointwiseConv2D' )
20
20
21
21
def get_stream_type_name (self , name : str ) -> str :
22
22
return f'{ name } ::value_type'
@@ -66,7 +66,12 @@ class VitisConvPreTemplate(OptimizerPass):
66
66
def match (self , node : Layer ):
67
67
if node .get_attr ('implementation' ) != 'linebuffer' :
68
68
return False
69
- return node .get_attr ('unrolled_codegen' ) and node .class_name in ('Conv1D' , 'Conv2D' )
69
+ return node .get_attr ('unrolled_codegen' ) and node .class_name in (
70
+ 'Conv1D' ,
71
+ 'Conv2D' ,
72
+ 'PointwiseConv1D' ,
73
+ 'PointwiseConv2D' ,
74
+ )
70
75
71
76
def transform (self , model : ModelGraph , node : Layer ):
72
77
io_type = model .config .get_config_value ("IOType" )
@@ -94,9 +99,13 @@ def latency_transform(self, model: ModelGraph, node: Layer):
94
99
node .attributes .attributes ['dense_config' ] = config_cpp
95
100
96
101
# override function_cpp
97
- if node .class_name == 'Conv1D' :
102
+ class_name = node .class_name
103
+ if class_name .startswith ('Pointwise' ):
104
+ class_name = class_name [9 :]
105
+
106
+ if class_name == 'Conv1D' :
98
107
fn_name = f'conv_1d<config{ node .index } >'
99
- elif node . class_name == 'Conv2D' :
108
+ elif class_name == 'Conv2D' :
100
109
fn_name = f'conv_2d<config{ node .index } >'
101
110
else :
102
111
raise ValueError (f'Unsupported layer type { node .class_name } ' )
@@ -107,7 +116,7 @@ def latency_transform(self, model: ModelGraph, node: Layer):
107
116
include_headers = [
108
117
'nnet_utils/nnet_unrolled.h' ,
109
118
'nnet_utils/nnet_dense_latency.h' ,
110
- f'nnet_utils/nnet_{ node . class_name .lower ()} .h' ,
119
+ f'nnet_utils/nnet_{ class_name .lower ()} .h' ,
111
120
'nnet_utils/nnet_conv_stream.h' , # some properties defined in config need this
112
121
]
113
122
node .attributes .attributes ['include_header' ] = include_headers
0 commit comments