My unpool max 2d #5826
Merged: sweetsky0901 merged 38 commits into PaddlePaddle:develop from sweetsky0901:my_unpool_max_2d on Nov 29, 2017.
Commits (38, all changes shown):
- fef617a for resolve conflicts
- 9127840 Merge branch 'sweetsky0901-my_maxout_op' into develop
- 4748073 paddle/operators/math/CMakeLists.txt maybe del sequence_pooling and a…
- 6665c49 Merge branch 'sweetsky0901-my_maxout_op2' into develop
- bc45335 add unpool
- f638f91 merge cmakelist
- 45a8c9d add unpool2d make ok
- ab03daa Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
- 200f07c add test
- 822f283 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
- 90f664d test unpool ok cpu
- abb3357 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
- e2a5905 gpu test ok unpool2dmax
- 8ba8237 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
- 47bd0bb del printf
- 0112c5d format code
- e553d57 format test code
- 66b8436 modify for code review by wangyi
- ee4a5d2 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
- c218961 modify for code review by qingqing
- a38bbc8 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
- cfd7721 add unpool_op.h modify
- 27cf7f3 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
- 20654cf modify for type check rewrite
- 022b48e Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
- f9c2a5c modify for code review zcd
- 57e68e5 modify for code review by qingqing 2nd
- ee0a794 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
- 6fc9a9f modify for del T2 and doc update
- 821899c Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
- d9673ca format code
- bd56138 format code
- c52ed8d format code
- 2d42fa7 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
- d2ee3c9 format code
- 3206094 format code
- 5b449b6 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
- 4ffb73f format ..
New file, CPU implementation (+91 lines):

```cpp
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/unpooling.h"

namespace paddle {
namespace operators {
namespace math {

// Forward pass: for each (batch, channel) slice, scatter every pooled
// input value to the output position recorded in `indices`.
template <typename T>
class Unpool2dMaxFunctor<platform::CPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices, framework::Tensor* output) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];
    int input_feasize = input_height * input_width;
    int output_feasize = output_height * output_width;
    const T* input_data = input.data<T>();
    const int* indices_data = indices.data<int>();
    T* output_data = output->mutable_data<T>(context.GetPlace());
    for (int b = 0; b < batch_size; ++b) {
      for (int c = 0; c < output_channels; ++c) {
        for (int i = 0; i < input_feasize; ++i) {
          int index = indices_data[i];
          PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!");
          output_data[index] = input_data[i];
        }
        // Advance all pointers to the next (batch, channel) slice.
        input_data += input_feasize;
        indices_data += input_feasize;
        output_data += output_feasize;
      }
    }
  }
};

// Backward pass: gather each output gradient back to the input position
// that held the maximum, the inverse of the forward scatter.
template <class T>
class Unpool2dMaxGradFunctor<platform::CPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  framework::Tensor* input_grad) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output.dims()[1];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    int input_feasize = input_height * input_width;
    int output_feasize = output_height * output_width;
    const int* indices_data = indices.data<int>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    for (int b = 0; b < batch_size; ++b) {
      for (int c = 0; c < output_channels; ++c) {
        for (int i = 0; i < input_feasize; ++i) {
          int index = indices_data[i];
          PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!");
          input_grad_data[i] = output_grad_data[index];
        }
        input_grad_data += input_feasize;
        indices_data += input_feasize;
        output_grad_data += output_feasize;
      }
    }
  }
};

template class Unpool2dMaxGradFunctor<platform::CPUPlace, float>;
template class Unpool2dMaxGradFunctor<platform::CPUPlace, double>;
template class Unpool2dMaxFunctor<platform::CPUPlace, float>;
template class Unpool2dMaxFunctor<platform::CPUPlace, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle
```
New file, GPU implementation (+134 lines):

```cuda
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/unpooling.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {

// Forward kernel: each thread handles input elements in a grid-stride
// loop, scattering every value to the output position in `indices_data`.
template <typename T>
__global__ void KernelUnpool2dMax(const int nthreads, const T* input_data,
                                  const int* indices_data,
                                  const int input_height, const int input_width,
                                  const int channels, T* output_data,
                                  const int output_height,
                                  const int output_width) {
  int in_n_stride = input_height * input_width * channels;
  int in_c_stride = input_height * input_width;
  int out_n_stride = output_height * output_width * channels;
  int out_c_stride = output_height * output_width;
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    int bidx = i / in_n_stride;      // batch of element i
    int boffset = i % in_n_stride;
    int cidx = boffset / in_c_stride;  // channel of element i
    int out_offset = bidx * out_n_stride + cidx * out_c_stride;
    int out_index = indices_data[i];
    PADDLE_ASSERT(out_index < out_c_stride);
    output_data[out_offset + out_index] = input_data[i];
  }
}

// Backward kernel: the inverse gather, reading each gradient back from
// the output position that received the maximum.
template <typename T>
__global__ void KernelUnpool2dMaxGrad(
    const int nthreads, const T* input_data, const int* indices_data,
    const int input_height, const int input_width, const int channels,
    const T* output_data, const T* output_grad, const int output_height,
    const int output_width, T* input_grad) {
  int in_n_stride = input_height * input_width * channels;
  int in_c_stride = input_height * input_width;
  int out_n_stride = output_height * output_width * channels;
  int out_c_stride = output_height * output_width;
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    int bidx = i / in_n_stride;
    int boffset = i % in_n_stride;
    int cidx = boffset / in_c_stride;
    int out_offset = bidx * out_n_stride + cidx * out_c_stride;
    int out_index = indices_data[i];
    PADDLE_ASSERT(out_index < out_c_stride);
    input_grad[i] = output_grad[out_offset + out_index];
  }
}

/*
 * All tensors are in NCHW format.
 */
template <typename T>
class Unpool2dMaxFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices, framework::Tensor* output) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];
    const T* input_data = input.data<T>();
    const int* indices_data = indices.data<int>();
    T* output_data = output->mutable_data<T>(context.GetPlace());
    int threads = 1024;
    int grid = (input.numel() + threads - 1) / threads;
    KernelUnpool2dMax<
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
                 .stream()>>>(input.numel(), input_data, indices_data,
                              input_height, input_width, output_channels,
                              output_data, output_height, output_width);
  }
};

/*
 * All tensors are in NCHW format.
 */
template <typename T>
class Unpool2dMaxGradFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  framework::Tensor* input_grad) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output.dims()[1];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    const T* input_data = input.data<T>();
    const int* indices_data = indices.data<int>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
    int threads = 1024;
    int grid = (input.numel() + threads - 1) / threads;
    KernelUnpool2dMaxGrad<
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
                 .stream()>>>(input.numel(), input_data, indices_data,
                              input_height, input_width, output_channels,
                              output_data, output_grad_data, output_height,
                              output_width, input_grad_data);
  }
};

template class Unpool2dMaxGradFunctor<platform::GPUPlace, float>;
template class Unpool2dMaxGradFunctor<platform::GPUPlace, double>;
template class Unpool2dMaxFunctor<platform::GPUPlace, float>;
template class Unpool2dMaxFunctor<platform::GPUPlace, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle
```
New file, shared header (+40 lines):

```cpp
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/tensor.h"

namespace paddle {
namespace operators {
namespace math {

// Declarations only; the CPU and GPU specializations live in the
// corresponding source files.
template <typename Place, typename T>
class Unpool2dMaxFunctor {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices, framework::Tensor* output);
};

template <typename Place, class T>
class Unpool2dMaxGradFunctor {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  framework::Tensor* input_grad);
};

}  // namespace math
}  // namespace operators
}  // namespace paddle
```
Review comments:

Comment: Same as on the CPU side, MaxPool2dWithIndexGradFunctor could be reused here as well, which would let Unpool2dMaxFunctor be removed.

Reply: Wouldn't that make the code structure a bit messy? I would not recommend doing it that way.