Skip to content

Commit cfd4b60

Browse files
committed
support pointwise conv layers
1 parent 3898a8d commit cfd4b60

File tree

1 file changed

+14
-5
lines changed
  • hls4ml/optimization/fused_dotp/optimizer_pass

1 file changed

+14
-5
lines changed

hls4ml/optimization/fused_dotp/optimizer_pass/vitis.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
class VitisUnrollCodeGen(UnrollCodeGenPass):
1818
def __init__(self):
19-
super().__init__('Dense', 'Conv1D', 'Conv2D')
19+
super().__init__('Dense', 'Conv1D', 'Conv2D', 'PointwiseConv1D', 'PointwiseConv2D')
2020

2121
def get_stream_type_name(self, name: str) -> str:
2222
return f'{name}::value_type'
@@ -66,7 +66,12 @@ class VitisConvPreTemplate(OptimizerPass):
6666
def match(self, node: Layer):
6767
if node.get_attr('implementation') != 'linebuffer':
6868
return False
69-
return node.get_attr('unrolled_codegen') and node.class_name in ('Conv1D', 'Conv2D')
69+
return node.get_attr('unrolled_codegen') and node.class_name in (
70+
'Conv1D',
71+
'Conv2D',
72+
'PointwiseConv1D',
73+
'PointwiseConv2D',
74+
)
7075

7176
def transform(self, model: ModelGraph, node: Layer):
7277
io_type = model.config.get_config_value("IOType")
@@ -94,9 +99,13 @@ def latency_transform(self, model: ModelGraph, node: Layer):
9499
node.attributes.attributes['dense_config'] = config_cpp
95100

96101
# override function_cpp
97-
if node.class_name == 'Conv1D':
102+
class_name = node.class_name
103+
if class_name.startswith('Pointwise'):
104+
class_name = class_name[9:]
105+
106+
if class_name == 'Conv1D':
98107
fn_name = f'conv_1d<config{node.index}>'
99-
elif node.class_name == 'Conv2D':
108+
elif class_name == 'Conv2D':
100109
fn_name = f'conv_2d<config{node.index}>'
101110
else:
102111
raise ValueError(f'Unsupported layer type {node.class_name}')
@@ -107,7 +116,7 @@ def latency_transform(self, model: ModelGraph, node: Layer):
107116
include_headers = [
108117
'nnet_utils/nnet_unrolled.h',
109118
'nnet_utils/nnet_dense_latency.h',
110-
f'nnet_utils/nnet_{node.class_name.lower()}.h',
119+
f'nnet_utils/nnet_{class_name.lower()}.h',
111120
'nnet_utils/nnet_conv_stream.h', # some properties defined in config need this
112121
]
113122
node.attributes.attributes['include_header'] = include_headers

0 commit comments

Comments
 (0)