|
| 1 | +import pytest |
| 2 | +from tensorflow.keras.models import Sequential |
| 3 | +from tensorflow.keras.layers import GlobalAveragePooling1D, GlobalMaxPooling1D |
| 4 | +import numpy as np |
| 5 | +import hls4ml |
| 6 | + |
| 7 | + |
| 8 | +in_shape = 8 |
| 9 | +in_feat = 4 |
| 10 | +atol = 5e-3 |
| 11 | + |
| 12 | +@pytest.fixture(scope='module') |
| 13 | +def data(): |
| 14 | + X = np.random.rand(100, in_shape, in_feat) |
| 15 | + return X |
| 16 | + |
| 17 | + |
| 18 | +@pytest.fixture(scope='module') |
| 19 | +def keras_model_max(): |
| 20 | + model = Sequential() |
| 21 | + model.add(GlobalMaxPooling1D(input_shape=(in_shape, in_feat))) |
| 22 | + model.compile() |
| 23 | + return model |
| 24 | + |
| 25 | +@pytest.fixture(scope='module') |
| 26 | +def keras_model_ave(): |
| 27 | + model = Sequential() |
| 28 | + model.add(GlobalAveragePooling1D(input_shape=(in_shape, in_feat))) |
| 29 | + model.compile() |
| 30 | + return model |
| 31 | + |
| 32 | + |
| 33 | +@pytest.mark.parametrize("configuration", [("ave", 'io_parallel'), |
| 34 | + ("ave", 'io_stream'), |
| 35 | + ("max", 'io_parallel'), |
| 36 | + ("max", 'io_stream')]) |
| 37 | +def test_global_pool1d(keras_model_max, keras_model_ave, data, configuration): |
| 38 | + model_type, io_type = configuration |
| 39 | + if model_type == "ave": |
| 40 | + model = keras_model_ave |
| 41 | + else: |
| 42 | + model = keras_model_max |
| 43 | + config = hls4ml.utils.config_from_keras_model(model, default_precision='ap_fixed<32,1>', |
| 44 | + granularity='name') |
| 45 | + if model_type == "ave": |
| 46 | + config['LayerName']['global_average_pooling1d']['Precision']='ap_fixed<32,6>' |
| 47 | + |
| 48 | + hls_model = hls4ml.converters.convert_from_keras_model(model, |
| 49 | + hls_config=config, |
| 50 | + io_type=io_type, |
| 51 | + output_dir=f'hls4ml_globalplool1d_{model_type}_{io_type}', |
| 52 | + part='xcvu9p-flgb2104-2-i') |
| 53 | + hls_model.compile() |
| 54 | + |
| 55 | + |
| 56 | + # Predict |
| 57 | + y_keras = np.squeeze(model.predict(data)) |
| 58 | + y_hls = hls_model.predict(data) |
| 59 | + np.testing.assert_allclose(y_keras, y_hls, rtol=0, atol=atol, verbose=True) |
0 commit comments