Skip to content

Fix pooling layers when padding is applied from the left/top #757

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 14, 2023

Conversation

JanFSchulte
Copy link
Contributor

@JanFSchulte JanFSchulte commented Apr 10, 2023

When padding is applied from the left (or top and left in the 2D case), the indices to the entries of the data array currently point to the wrong elements when the pools are filled because the shift in indices due to the padding is not taken into account. In Keras this only happens when the pool_size is set above 2 so that the padding in the case of same padding is not only applied to the right but also to the left side of the input.

This PR has a simple fix by correctly shifting the indices to take into account the offset. Is transparent in case there is no padding from the left/top required.

Additionally, when the zero padding is removed from the result in case of AveragePooling, this was not done correctly because the division of 2 integers always resulted in an integer result, so that the result was always rescaled with a factor of 1. This has been fixed by casting the numerator to data_T.

Type of change

For a new feature or function, please create an issue first to discuss it
with us before submitting a pull request.

Note: Please delete options that are not relevant.

  • Bug fix (non-breaking change that fixes an issue)

Tests

Problem can be reproduced and fix verified with this small script:

import math
from pathlib import Path

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import MaxPooling1D

import hls4ml

input_shape = (124, 5)
X_input = np.random.rand(100, *input_shape)

keras_model = tf.keras.models.Sequential()
keras_model.add(MaxPooling1D(pool_size = 3, padding="same", input_shape=input_shape))
keras_model.compile()

hls_cfg = hls4ml.utils.config_from_keras_model(keras_model)
output_dir = "test_keras"

hls_model = hls4ml.converters.convert_from_keras_model(
    keras_model, hls_config=hls_cfg, output_dir=output_dir, backend="Vivado"
)
hls_model.compile()

# Verify accuracy
keras_prediction = keras_model.predict(X_input)
hls_prediction = hls_model.predict(X_input).reshape(keras_prediction.shape)

np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=3e-2)

A test was added in test/pytest/test_pooling.h to verify the fixes.

Checklist

  • I have read the guidelines for contributing.
  • I have commented my code, particularly in hard-to-understand areas.
  • I have made corresponding changes to the documentation.
  • My changes generate no new warnings.
  • I have installed and run pre-commit on the files I edited or added.
  • I have added tests that prove my fix is effective or that my feature works.

@jmitrevs jmitrevs added the please test Trigger testing by creating local PR branch label Apr 11, 2023
@JanFSchulte
Copy link
Contributor Author

To me it looks like this test failure is unrelated to the changes in this PR. It looks like tensorflow was not correctly set up:

File "/usr/local/lib/python3.10/site-packages/qkeras/__init__.py", line 20, in <module>
    import tensorflow as tf
ModuleNotFoundError: No module named 'tensorflow'

@jmduarte
Copy link
Member

Hm it seems tensorflow is not a requirement for qkeras? I always thought it was: https://github.com/google/qkeras/blob/v0.9.0/setup.py#L41-L50

So I think this means we should explicitly add tensorflow as a requirement to hls4ml @vloncar @jmitrevs

@vloncar
Copy link
Contributor

vloncar commented Apr 13, 2023

@JanFSchulte Can you add a test case (test_pooling.py) that checks for this? You can use the existing test for global pooling as a starting point, just chenge a few things. Test it with all three backends and both io_parallel and io_stream. You may find the results interesting 😉

@vloncar
Copy link
Contributor

vloncar commented Apr 13, 2023

So I think this means we should explicitly add tensorflow as a requirement to hls4ml @vloncar @jmitrevs

I think we didn't do it was because in old times we were thinking that we would want to support having pytorch and onnx converters if TF wasn't installed and the other way around. This was never fully implemented, but perhaps it is not a nonsense idea. I'm surprised that qkeras and its dependencies (tfmot and keras-tuner) don't explicitly require tensorflow. Maybe there's some reason for that? I also see that in the hls4ml-tutorial environment we install tensorflow explicitly. Can we do the same for sphinx environment? Or we revert the cleanup change that caused it.

@JanFSchulte
Copy link
Contributor Author

I have added the fixes to Vitis and also fixed another bug in the case of average pooling. Tests are added, but only for the case of io_parallel, since padding is not supported in case of io_stream.

@vloncar vloncar added please test Trigger testing by creating local PR branch and removed please test Trigger testing by creating local PR branch labels Apr 13, 2023
@jmduarte jmduarte added please test Trigger testing by creating local PR branch and removed please test Trigger testing by creating local PR branch labels Apr 14, 2023
@jmduarte jmduarte self-requested a review April 14, 2023 02:34
@jmduarte jmduarte merged commit 8eb29fc into fastmachinelearning:main Apr 14, 2023
JanFSchulte added a commit to JanFSchulte/hls4ml that referenced this pull request May 23, 2023
…hinelearning#757)

* fix pooling layers when padding is applied from the left/top

* run pre-commit

* add fixes for vitis, fix average pooling, add tests

* remove changes from pytorch parser

* diff clean

* Update requirements.txt

---------

Co-authored-by: Javier Duarte <[email protected]>
calad0i pushed a commit to calad0i/hls4ml that referenced this pull request Jul 1, 2023
…hinelearning#757)

* fix pooling layers when padding is applied from the left/top

* run pre-commit

* add fixes for vitis, fix average pooling, add tests

* remove changes from pytorch parser

* diff clean

* Update requirements.txt

---------

Co-authored-by: Javier Duarte <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
please test Trigger testing by creating local PR branch
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants