Skip to content

Commit 1b42183

Browse files
committed
update test
1 parent e8f9c8d commit 1b42183

File tree

1 file changed

+8
-9
lines changed

1 file changed

+8
-9
lines changed

test/pytest/test_globalpooling1d.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,20 +30,19 @@ def keras_model_ave():
3030
return model
3131

3232

33-
@pytest.mark.parametrize("configuration", [("ave", 'io_parallel'),
34-
("ave", 'io_stream'),
35-
("max", 'io_parallel'),
36-
("max", 'io_stream')])
37-
def test_global_pool1d(keras_model_max, keras_model_ave, data, configuration):
33+
@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
34+
@pytest.mark.parametrize('model_type', ['max', 'ave'])
35+
def test_global_pool1d(keras_model_max, keras_model_ave, data, model_type, io_type):
3836
model_type, io_type = configuration
39-
if model_type == "ave":
37+
if model_type == 'ave':
4038
model = keras_model_ave
4139
else:
4240
model = keras_model_max
43-
config = hls4ml.utils.config_from_keras_model(model, default_precision='ap_fixed<32,1>',
41+
config = hls4ml.utils.config_from_keras_model(model,
42+
default_precision='ap_fixed<32,1>',
4443
granularity='name')
45-
if model_type == "ave":
46-
config['LayerName']['global_average_pooling1d']['Precision']='ap_fixed<32,6>'
44+
if model_type == 'ave':
45+
config['LayerName']['global_average_pooling1d']['accum_t'] = 'ap_fixed<32,6>'
4746

4847
hls_model = hls4ml.converters.convert_from_keras_model(model,
4948
hls_config=config,

0 commit comments

Comments
 (0)