Skip to content

Commit 219c7de

Browse files
vloncarjmduarte
andauthored
Support UpSampling1D (#475)
* Support UpSampling1D * Proper output directory for upsampling tests Co-authored-by: Javier Duarte <[email protected]>
1 parent bc51a44 commit 219c7de

File tree

4 files changed

+103
-5
lines changed

4 files changed

+103
-5
lines changed

hls4ml/converters/keras/reshape.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,34 @@ def parse_reshape_layer(keras_layer, input_names, input_shapes, data_reader, con
2727

2828
return layer, output_shape
2929

30+
@keras_handler('UpSampling1D')
31+
def parse_upsampling1d_layer(keras_layer, input_names, input_shapes, data_reader, config):
32+
assert('UpSampling' in keras_layer['class_name'])
33+
34+
layer = parse_default_keras_layer(keras_layer, input_names)
35+
36+
layer['in_height'] = 1
37+
(
38+
layer['in_width'],
39+
layer['n_chan']
40+
) = parse_data_format(input_shapes[0], layer['data_format'])
41+
42+
layer['algorithm'] = 'nearest'
43+
44+
layer['width_factor'] = keras_layer['config']['size']
45+
46+
layer['out_height'] = 1
47+
layer['out_width'] = layer['in_width'] * layer['width_factor']
48+
49+
if layer['data_format'] == 'channels_first':
50+
output_shape = [input_shapes[0][0], layer['n_chan'], layer['out_width']]
51+
else:
52+
output_shape = [input_shapes[0][0], layer['out_width'], layer['n_chan']]
53+
54+
return layer, output_shape
3055

3156
@keras_handler('UpSampling2D')
32-
def parse_conv2d_layer(keras_layer, input_names, input_shapes, data_reader, config):
57+
def parse_upsampling2d_layer(keras_layer, input_names, input_shapes, data_reader, config):
3358
assert('UpSampling2D' in keras_layer['class_name'])
3459

3560
layer = parse_default_keras_layer(keras_layer, input_names)

hls4ml/model/layers.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -783,8 +783,12 @@ def initialize(self):
783783
class Resize(Layer):
784784
def initialize(self):
785785
inp = self.get_input_variable()
786-
shape = [self.get_attr('out_height'), self.get_attr('out_width'), self.get_attr('n_chan')]
787-
dims = ['OUT_HEIGHT_{}'.format(self.index), 'OUT_WIDTH_{}'.format(self.index), 'N_CHAN_{}'.format(self.index)]
786+
if len(inp.shape) == 2: # 1D -> width + chan
787+
shape = [self.get_attr('out_width'), self.get_attr('n_chan')]
788+
dims = ['OUT_WIDTH_{}'.format(self.index), 'N_CHAN_{}'.format(self.index)]
789+
elif len(inp.shape) == 3: # 2D -> height + width + chan
790+
shape = [self.get_attr('out_height'), self.get_attr('out_width'), self.get_attr('n_chan')]
791+
dims = ['OUT_HEIGHT_{}'.format(self.index), 'OUT_WIDTH_{}'.format(self.index), 'N_CHAN_{}'.format(self.index)]
788792
self.add_output_variable(shape, dims, precision=inp.type.precision)
789793

790794
class Transpose(Layer):
@@ -1012,6 +1016,7 @@ def _initialize_transforms(self):
10121016
'Dot' : Dot,
10131017
'Concatenate' : Concatenate,
10141018
'Resize' : Resize,
1019+
'UpSampling1D' : Resize,
10151020
'UpSampling2D' : Resize,
10161021
'Transpose' : Transpose,
10171022
'GarNet' : GarNet,

hls4ml/utils/config.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,20 +100,21 @@ def config_from_keras_model(model, granularity='model', default_precision='ap_fi
100100
model_arch = json.loads(model.to_json())
101101

102102
#Define supported layers
103-
core_layers = ['InputLayer', 'Dropout', 'Flatten', 'Reshape', 'Permute', 'UpSampling2D']
103+
core_layers = ['InputLayer', 'Dropout', 'Flatten', 'Reshape', 'Permute']
104104
dense_layers = ['Dense', 'BinaryDense', 'TernaryDense']
105105
conv_layers = ['Conv1D', 'Conv2D', 'BinaryConv2D']
106106
pooling_layers = ['MaxPooling1D', 'MaxPooling2D', 'GlobalMaxPooling1D', 'GlobalMaxPooling2D', 'AveragePooling1D', 'AveragePooling2D', 'GlobalAveragePooling1D', 'GlobalAveragePooling2D']
107107
norm_layers = ['BatchNormalization']
108108
activation_layers = ['Activation', 'LeakyReLU', 'ThresholdedReLU', 'ELU', 'PReLU', 'Softmax', 'ReLU']
109109
merge_layers = ['Add', 'Subtract', 'Multiply', 'Average', 'Maximum', 'Minimum', 'Concatenate', 'Dot']
110110
qkeras_layers = ['QDense', 'QActivation', 'QConv1D', 'QConv2D', 'QBatchNormalization', 'QConv2DBatchnorm']
111+
upsampling_layers = ['UpSampling1D', 'UpSampling2D']
111112
reshaping_layers = ['ZeroPadding1D', 'ZeroPadding2D']
112113
graph_layers = ['GarNet', 'GarNetStack']
113114
#Define layers to skip because they're not configurable or not converted to HLS
114115
skip_layers = ['Dropout', 'Flatten', 'Reshape', 'Permute']
115116
#All supported layers
116-
supported_layers = core_layers + dense_layers + conv_layers + pooling_layers + norm_layers + activation_layers + merge_layers + qkeras_layers + reshaping_layers + graph_layers + skip_layers
117+
supported_layers = core_layers + dense_layers + conv_layers + pooling_layers + norm_layers + activation_layers + merge_layers + qkeras_layers + upsampling_layers + reshaping_layers + graph_layers + skip_layers
117118

118119
keras_layer_config = None
119120
if model_arch['class_name'] == 'Sequential':

test/pytest/test_upsampling.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import pytest
2+
from tensorflow.keras.models import Sequential
3+
from tensorflow.keras.layers import UpSampling1D, UpSampling2D
4+
import numpy as np
5+
import hls4ml
6+
from pathlib import Path
7+
8+
test_root_path = Path(__file__).parent
9+
10+
in_height = 6
11+
in_width = 8
12+
in_feat = 4
13+
14+
size = 2
15+
atol = 5e-3
16+
17+
@pytest.fixture(scope='module')
18+
def data_1d():
19+
X = np.random.rand(100, in_width, in_feat)
20+
return X
21+
22+
@pytest.fixture(scope='module')
23+
def data_2d():
24+
X = np.random.rand(100, in_height, in_width, in_feat)
25+
return X
26+
27+
28+
@pytest.fixture(scope='module')
29+
def keras_model_1d():
30+
model = Sequential()
31+
model.add(UpSampling1D(input_shape=(in_width, in_feat), size=size))
32+
model.compile()
33+
return model
34+
35+
@pytest.fixture(scope='module')
36+
def keras_model_2d():
37+
model = Sequential()
38+
model.add(UpSampling2D(input_shape=(in_height, in_width, in_feat), size=(size, size)))
39+
model.compile()
40+
return model
41+
42+
43+
@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
44+
@pytest.mark.parametrize('model_type', ['1d', '2d'])
45+
def test_upsampling(keras_model_1d, keras_model_2d, data_1d, data_2d, model_type, io_type):
46+
if model_type == '1d':
47+
model = keras_model_1d
48+
data = data_1d
49+
else:
50+
model = keras_model_2d
51+
data = data_2d
52+
53+
config = hls4ml.utils.config_from_keras_model(model,
54+
default_precision='ap_fixed<32,1>',
55+
granularity='name')
56+
odir = str(test_root_path / f'hls4mlprj_upsampling_{model_type}_{io_type}')
57+
hls_model = hls4ml.converters.convert_from_keras_model(model,
58+
hls_config=config,
59+
io_type=io_type,
60+
output_dir=odir,
61+
part='xcvu9p-flgb2104-2-i')
62+
hls_model.compile()
63+
64+
# Predict
65+
y_keras = model.predict(data).flatten()
66+
y_hls = hls_model.predict(data).flatten()
67+
np.testing.assert_allclose(y_keras, y_hls, rtol=0, atol=atol, verbose=True)

0 commit comments

Comments
 (0)