-
Notifications
You must be signed in to change notification settings - Fork 462
Fix pooling layers when padding is applied from the left/top #757
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
To me it looks like this test failure is unrelated to the changes in this PR. It looks like tensorflow was not correctly set up:
|
Hm it seems So I think this means we should explicitly add |
@JanFSchulte Can you add a test case (test_pooling.py) that checks for this? You can use the existing test for global pooling as a starting point, just chenge a few things. Test it with all three backends and both io_parallel and io_stream. You may find the results interesting 😉 |
I think we didn't do it was because in old times we were thinking that we would want to support having pytorch and onnx converters if TF wasn't installed and the other way around. This was never fully implemented, but perhaps it is not a nonsense idea. I'm surprised that qkeras and its dependencies (tfmot and keras-tuner) don't explicitly require tensorflow. Maybe there's some reason for that? I also see that in the hls4ml-tutorial environment we install tensorflow explicitly. Can we do the same for sphinx environment? Or we revert the cleanup change that caused it. |
I have added the fixes to Vitis and also fixed another bug in the case of average pooling. Tests are added, but only for the case of |
…hinelearning#757) * fix pooling layers when padding is applied from the left/top * run pre-commit * add fixes for vitis, fix average pooling, add tests * remove changes from pytorch parser * diff clean * Update requirements.txt --------- Co-authored-by: Javier Duarte <[email protected]>
…hinelearning#757) * fix pooling layers when padding is applied from the left/top * run pre-commit * add fixes for vitis, fix average pooling, add tests * remove changes from pytorch parser * diff clean * Update requirements.txt --------- Co-authored-by: Javier Duarte <[email protected]>
When padding is applied from the left (or top and left in the 2D case), the indices to the entries of the data array currently point to the wrong elements when the pools are filled because the shift in indices due to the padding is not taken into account. In Keras this only happens when the pool_size is set above 2 so that the padding in the case of
same
padding is not only applied to the right but also to the left side of the input.This PR has a simple fix by correctly shifting the indices to take into account the offset. Is transparent in case there is no padding from the left/top required.
Additionally, when the zero padding is removed from the result in case of AveragePooling, this was not done correctly because the division of 2 integers always resulted in an integer result, so that the result was always rescaled with a factor of 1. This has been fixed by casting the numerator to
data_T
.Type of change
For a new feature or function, please create an issue first to discuss it
with us before submitting a pull request.
Note: Please delete options that are not relevant.
Tests
Problem can be reproduced and fix verified with this small script:
A test was added in
test/pytest/test_pooling.h
to verify the fixes.Checklist
pre-commit
on the files I edited or added.