Skip to content

Commit ba3902f

Browse files
jmduarte authored and thesps committed
Reshape fixes: don't repack stream for flatten; remove final reshape (#443)
* fix 2 reshape issues: don't reshape streams for flatten and remove final reshape
* Add a test for a model with Reshape as the final layer
* swap
* only remove for io_parallel; warn for both io_parallel and io_stream

Co-authored-by: Sioni Summers <[email protected]>
1 parent d0ff8ba commit ba3902f

File tree

3 files changed

+52
-2
lines changed

3 files changed

+52
-2
lines changed

hls4ml/model/optimizer/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from hls4ml.model.optimizer.passes.conv_same_pad import InsertZeroPaddingBeforeConv2D
1313
from hls4ml.model.optimizer.passes.pointwise import OptimizePointwiseConv
1414
from hls4ml.model.optimizer.passes.clone import CloneOutput
15-
from hls4ml.model.optimizer.passes.repack_stream import ReshapeStream, BroadcastStream
15+
from hls4ml.model.optimizer.passes.repack_stream import ReshapeStream, BroadcastStream, RemoveFinalReshape
1616
from hls4ml.model.optimizer.passes.transpose_opt import RemoveUselessTranspose
1717
from hls4ml.model.optimizer.passes.multi_dense import ReplaceMultidimensionalDenseWithConv
1818

@@ -40,6 +40,7 @@
4040
register_pass('conv2d_same_pad', InsertZeroPaddingBeforeConv2D)
4141
register_pass('optimize_pointwise_conv', OptimizePointwiseConv)
4242
register_pass('clone_output', CloneOutput)
43+
register_pass('remove_final_reshape', RemoveFinalReshape)
4344
register_pass('reshape_stream', ReshapeStream)
4445
register_pass('remove_useless_transpose', RemoveUselessTranspose)
4546
register_pass('replace_multidense_conv', ReplaceMultidimensionalDenseWithConv)

hls4ml/model/optimizer/passes/repack_stream.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@ def config_cpp(self):
7171
class ReshapeStream(OptimizerPass):
7272
''' Repacks stream for Reshape layer '''
7373
def match(self, node):
74-
return node.__class__.__name__ == 'Reshape'
74+
# do not run optimizer pass for a flatten layer (1 output dimension)
75+
return node.__class__.__name__ == 'Reshape' and len(node.get_output_variable().shape) > 1
7576

7677
def transform(self, model, node):
7778
if model.config.backend.name not in ['Vivado', 'VivadoAccelerator'] or \
@@ -121,3 +122,19 @@ def transform(self, model, node):
121122
node.inputs[idx] = brdcst_out
122123

123124
return True
125+
126+
class RemoveFinalReshape(OptimizerPass):
    ''' Remove reshape if final layer '''

    def match(self, node):
        # Fires only for a Reshape with no downstream consumers,
        # i.e. the final node of the model graph.
        is_reshape = node.__class__.__name__ == 'Reshape'
        return is_reshape and not node.get_output_nodes()

    def transform(self, model, node):
        io_type = model.config.get_config_value('IOType')
        if io_type == 'io_parallel':
            # A trailing Reshape cannot change the flat io_parallel output,
            # so it is safe to drop outright.
            print('WARNING: Final layer is a Reshape, which does not affect the output for io_parallel; removing it')
            # remove, but don't rewire because it's the output layer
            model.remove_node(node, rewire=False)
            return True
        elif io_type == 'io_stream':
            # For streams a repack would be inserted, which costs resources;
            # warn the user but leave the graph untouched.
            print('WARNING: Final layer is a Reshape, which may incur a large resource cost for io_stream; consider removing it')
            return False

test/pytest/test_graph.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import hls4ml
22
import numpy as np
33
import pytest
4+
import tensorflow as tf
45

56
class Reader:
67
def get_weights_data(self, name, var):
@@ -107,3 +108,34 @@ def test_graph_branch(iotype, batch):
107108
y = model.predict([X0, X1]).reshape(y_expected.shape)
108109
# check the output
109110
np.testing.assert_allclose(y, y_expected, rtol=1, atol=2**-16)
111+
112+
@pytest.mark.parametrize('iotype', ['io_parallel', 'io_stream'])
def test_final_reshape(iotype):
    ''' Test case for a model with a Reshape as the final layer '''
    # Minimal Keras model: a single-pixel input through a 1x1 convolution,
    # then reshape the (1, 1, 6) feature map to (3, 2).
    inputs = tf.keras.layers.Input(shape=(1, 1, 1))  # 1 input pixel
    conv = tf.keras.layers.Conv2D(6, 1)  # 6 filters, 1x1 kernel
    x = conv(inputs)
    # ascending int weights, 0 bias
    conv.set_weights([np.linspace(1, 6, 6).reshape(1, 1, 1, 6), np.zeros(6)])
    x = tf.keras.layers.Reshape((3, 2))(x)  # reshape the (1,1,6) output to (3,2)
    model = tf.keras.models.Model(inputs=inputs, outputs=x)

    # create the HLSModel
    config = hls4ml.utils.config_from_keras_model(model, granularity='model')
    hls_model = hls4ml.converters.convert_from_keras_model(
        model,
        output_dir=f'hls4mlprj_graph_final_reshape_{iotype}',
        backend='Vivado',
        io_type=iotype,
        hls_config=config)
    hls_model.compile()

    # Test on ascending integers. The weights mean that each output
    # pixel/neuron has a different value.
    X = np.linspace(-4, 4, 9).reshape(9, 1, 1, 1)
    y = model.predict(X)
    y_hls = hls_model.predict(X).reshape(y.shape)
    # because of integer inputs and integer weights, we can expect exact matching
    np.testing.assert_allclose(y, y_hls, rtol=0)
138+
139+
140+
141+

0 commit comments

Comments
 (0)