Skip to content

Commit c49ffc0

Browse files
committed
Test case for pointwise conv1d/2d
1 parent 5ecab69 commit c49ffc0

File tree

1 file changed

+79
-0
lines changed

1 file changed

+79
-0
lines changed

test/pytest/test_pointwiseconv.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import pytest
2+
import hls4ml
3+
import tensorflow as tf
4+
import numpy as np
5+
from pathlib import Path
6+
from tensorflow.keras.layers import Conv1D, Conv2D
7+
from tensorflow.keras import backend as K
8+
9+
test_root_path = Path(__file__).parent
10+
11+
padds_options = ['same', 'valid']
12+
chans_options = ['channels_last']
13+
io_type_options = ['io_parallel', 'io_stream']
14+
strides1d_options = [(1,), (2,)]
15+
strides2d_options = [(1, 1), (2, 2)]
16+
strategy_options = ['Latency', 'Resource']
17+
18+
@pytest.mark.parametrize("chans", chans_options)
19+
@pytest.mark.parametrize("padds", padds_options)
20+
@pytest.mark.parametrize("strides", strides1d_options)
21+
@pytest.mark.parametrize("io_type", io_type_options)
22+
@pytest.mark.parametrize("strategy", strategy_options)
23+
def test_pointwiseconv1d(chans, padds, strides, io_type, strategy):
24+
model = tf.keras.models.Sequential()
25+
input_shape = (28, 3)
26+
model.add(Conv1D(filters=32,
27+
kernel_size=(1,),
28+
strides=strides,
29+
padding=padds,
30+
input_shape=input_shape,
31+
kernel_initializer='normal',
32+
use_bias=False,
33+
data_format=chans
34+
))
35+
36+
model.compile(optimizer='adam', loss='mse')
37+
X_input = np.random.rand(100, *input_shape)
38+
keras_prediction = model.predict(X_input)
39+
config = hls4ml.utils.config_from_keras_model(model, default_precision='ap_fixed<32,16>')
40+
config['Model']['Strategy'] = strategy
41+
output_dir = str(test_root_path / 'hls4mlprj_pointwise1d_{}_strides_{}_{}_padding_{}_{}'.format(chans, strides[0], padds, io_type, strategy))
42+
hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=config, output_dir=output_dir, io_type=io_type)
43+
hls_model.compile()
44+
hls_prediction = hls_model.predict(X_input).reshape(keras_prediction.shape)
45+
46+
assert 'Pointwise' in list(hls_model.graph.values())[1].class_name
47+
np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=0.001)
48+
49+
@pytest.mark.parametrize("chans", chans_options)
50+
@pytest.mark.parametrize("padds", padds_options)
51+
@pytest.mark.parametrize("strides", strides2d_options)
52+
@pytest.mark.parametrize("io_type", io_type_options)
53+
@pytest.mark.parametrize("strategy", strategy_options)
54+
def test_pointwiseconv2d(chans, padds, strides, io_type, strategy):
55+
model = tf.keras.models.Sequential()
56+
input_shape = (28, 28, 3)
57+
model.add(Conv2D(filters=32,
58+
kernel_size=(1, 1),
59+
strides=strides,
60+
padding=padds,
61+
input_shape=input_shape,
62+
kernel_initializer='normal',
63+
use_bias=False,
64+
data_format=chans
65+
))
66+
67+
model.compile(optimizer='adam', loss='mse')
68+
X_input = np.random.rand(100, *input_shape)
69+
keras_prediction = model.predict(X_input)
70+
config = hls4ml.utils.config_from_keras_model(model, default_precision='ap_fixed<32,16>')
71+
config['Model']['Strategy'] = strategy
72+
stride_cfg = str(strides).replace(', ', '_').replace('(', '').replace(')', '')
73+
output_dir = str(test_root_path / 'hls4mlprj_pointwise2d_{}_strides_{}_{}_padding_{}_{}'.format(chans, stride_cfg, padds, io_type, strategy))
74+
hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=config, output_dir=output_dir, io_type=io_type)
75+
hls_model.compile()
76+
hls_prediction = hls_model.predict(X_input).reshape(keras_prediction.shape)
77+
78+
assert 'Pointwise' in list(hls_model.graph.values())[1].class_name
79+
np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=0.001)

0 commit comments

Comments
 (0)