
Commit 91aa30a

Quartus CNN test fixes
1 parent 2f73598 commit 91aa30a

3 files changed: +5 −5 lines changed


hls4ml/templates/quartus/firmware/nnet_utils/nnet_conv1d.h

Lines changed: 0 additions & 2 deletions

@@ -6,8 +6,6 @@
 
 namespace nnet {
 
-enum class conv1d_implementation {combination, im2col, winograd};
-
 struct conv1d_config {
     // I/O sizes
     static const unsigned in_width = 10;

hls4ml/templates/quartus/firmware/nnet_utils/nnet_conv1d_resource.h

Lines changed: 2 additions & 0 deletions

@@ -6,6 +6,8 @@
 
 namespace nnet {
 
+enum class conv1d_implementation {combination, im2col, winograd};
+
 // ****************************************************************
 // im2col - General-purpose 1D Convolution algorithm
 // ****************************************************************
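
Note on the two header hunks above: the conv1d_implementation enum moves out of nnet_conv1d.h and into nnet_conv1d_resource.h. A plausible reading (the commit message does not spell it out) is that the resource header is the code that actually dispatches on the enum, so it must carry the definition to compile when included on its own. The sketch below restates that self-containment rule with hypothetical Python module names, not hls4ml code: a name must be defined in, or imported into, the module that uses it.

# Hypothetical sketch (not hls4ml code): conv1d_resource.py dispatches on the
# implementation names, so it owns their definition, just as
# nnet_conv1d_resource.h now owns conv1d_implementation.

CONV1D_IMPLEMENTATIONS = ("combination", "im2col", "winograd")

def conv_1d_resource(implementation: str) -> str:
    if implementation not in CONV1D_IMPLEMENTATIONS:
        raise ValueError(f"unknown conv1d implementation: {implementation!r}")
    return f"dispatching to the {implementation} kernel"

# conv1d.py would then import from conv1d_resource, never the reverse.
print(conv_1d_resource("im2col"))    # dispatching to the im2col kernel
print(conv_1d_resource("winograd"))  # dispatching to the winograd kernel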

hls4ml/writer/quartus_writer.py

Lines changed: 3 additions & 3 deletions

@@ -1,6 +1,6 @@
 from __future__ import print_function
 import tarfile
-from hls4ml.model.layers import Conv1D, Conv2D, Conv2DBatchnorm
+from hls4ml.model.layers import Conv1D, Conv2D, Conv2DBatchnorm, Dense
 import yaml
 from shutil import copyfile, copytree, rmtree
 import numpy as np
@@ -78,7 +78,7 @@ def print_array_to_cpp(self, var, layer, odir):
             weight_size = layer.get_attr('impl_filt_height') * layer.get_attr('impl_filt_width') * layer.get_attr('n_filt') * layer.get_attr('n_chan')
         elif isinstance(layer, (Conv1D)):
             weight_size = layer.get_attr('impl_filt_width') * layer.get_attr('n_filt') * layer.get_attr('n_chan')
-        else:
+        elif isinstance(layer, (Dense)):
             weight_size = layer.get_attr('n_in') * layer.get_attr('n_out')
 
         if (rf == 1 or var.name[0] == 'b' or weight_size <= 2048
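
The change from a bare else: to elif isinstance(layer, (Dense)):, together with the new Dense import, restricts the n_in * n_out arithmetic to layers that actually define those attributes; previously every non-convolutional layer fell into that branch. A simplified stand-in for the selection logic (assuming, as in hls4ml's Layer.get_attr, that a missing attribute comes back as None):

# Simplified stand-in for the writer's weight_size selection; layer_class and
# the attrs dict are illustrative, not the actual QuartusWriter interface.

def weight_size(layer_class, attrs):
    get_attr = attrs.get  # missing attributes come back as None, as assumed above
    if layer_class == 'Conv1D':
        return get_attr('impl_filt_width') * get_attr('n_filt') * get_attr('n_chan')
    elif layer_class == 'Dense':  # the fix: an explicit check instead of a bare else
        return get_attr('n_in') * get_attr('n_out')
    return None  # other layer types no longer reach the dense arithmetic

print(weight_size('Dense', {'n_in': 16, 'n_out': 8}))                           # 128
print(weight_size('Conv1D', {'impl_filt_width': 3, 'n_filt': 4, 'n_chan': 2}))  # 24
print(weight_size('GRU', {}))  # None; the old bare else would raise TypeError here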
@@ -831,7 +831,7 @@ def write_nnet_utils(self, model):
 
     def __get_table_size(self, model, activation):
         for layer in model.get_layers():
-            if layer.get_attr('activation') == activation or layer.get_attr('recurrent_activation') == activation and layer.get_attr('table_size') is not None:
+            if (layer.get_attr('activation') == activation or layer.get_attr('recurrent_activation') == activation) and layer.get_attr('table_size') is not None:
                 return int(layer.get_attr('table_size'))
         return 1024
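
The last hunk is a pure operator-precedence fix: Python's and binds tighter than or, so the old condition parsed as activation_match or (recurrent_match and has_table_size). The guard therefore passed for any layer whose activation matched even when its table_size was unset, and int(None) then raises. A minimal demonstration:

# Minimal demonstration of the precedence bug, using plain booleans.
activation_match, recurrent_match, has_table_size = True, False, False

old_cond = activation_match or recurrent_match and has_table_size
new_cond = (activation_match or recurrent_match) and has_table_size

print(old_cond)  # True: matched on activation, table_size never consulted
print(new_cond)  # False: the parenthesized form requires table_size as well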
