My unpool max 2d #5826
Changes from 23 commits
New file (+100 lines): the CPU implementation of the max-unpooling functors.

```cpp
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/unpooling.h"

namespace paddle {
namespace operators {
namespace math {

// All tensors are in NCHW format.
template <typename T>
class Unpool2dMaxFunctor<platform::CPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices,
                  framework::Tensor* output) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];
    int input_feasize = input_height * input_width;
    int output_feasize = output_height * output_width;
    const T* input_data = input.data<T>();
    const T* indices_data = indices.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());
    for (int b = 0; b < batch_size; ++b) {
      for (int c = 0; c < output_channels; ++c) {
        for (int i = 0; i < input_feasize; ++i) {
          int index = indices_data[i];
          PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!");
```
Contributor: This check could be moved outside the loop: convert `indices` to an EigenMatrix or EigenVector, call `maximum` to get the largest element of `indices`, and then check once whether it is out of range.

Contributor (Author): Calling `maximum` actually adds many extra comparisons; I am not in favor of writing it that way.
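For concreteness, the hoisted check the reviewer describes might look roughly like the sketch below. This is only a sketch, assuming plain Eigen is available (Paddle's EigenVector wrapper maps tensor memory the same way); `CheckIndicesOnce` is a hypothetical helper, not part of this PR.

```cpp
#include <Eigen/Dense>
#include <cassert>

// Hypothetical helper: one reduction over the whole index buffer replaces
// the per-element PADDLE_ENFORCE inside the scatter loop.
template <typename T>
void CheckIndicesOnce(const T* indices_data, int n, int output_feasize) {
  // Map the raw buffer as an Eigen vector (no copy is made).
  Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>> idx(indices_data, n);
  // maxCoeff() is Eigen's reduction returning the largest coefficient.
  assert(static_cast<int>(idx.maxCoeff()) < output_feasize &&
         "index out of range in unpooling");
}
```

The trade-off the author raises is real: the reduction itself still performs one comparison per element, so it saves work only if the per-element enforcement is more expensive than a plain compare.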
```cpp
          output_data[index] = input_data[i];
```
Contributor: Why is this not ...? Also, `Unpool2dMaxFunctor` does not have to be written from scratch; ... could be reused directly. (The referenced snippets were inline links in the original comment.)

Contributor (Author): Wouldn't that make the code structure a bit messy? I would rather not do it that way.
```cpp
        }
        input_data += input_feasize;
        indices_data += input_feasize;
        output_data += output_feasize;
      }
    }
  }
};

template <class T>
class Unpool2dMaxGradFunctor<platform::CPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  framework::Tensor* input_grad) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output.dims()[1];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    int input_feasize = input_height * input_width;
    int output_feasize = output_height * output_width;
    const T* indices_data = indices.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    for (int b = 0; b < batch_size; ++b) {
      for (int c = 0; c < output_channels; ++c) {
        for (int i = 0; i < input_feasize; ++i) {
          int index = indices_data[i];
          PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!");
          input_grad_data[i] = output_grad_data[index];
        }
        input_grad_data += input_feasize;
        indices_data += input_feasize;
        output_grad_data += output_feasize;
      }
    }
  }
};

template class Unpool2dMaxGradFunctor<platform::CPUPlace, float>;
template class Unpool2dMaxGradFunctor<platform::CPUPlace, double>;
template class Unpool2dMaxFunctor<platform::CPUPlace, float>;
template class Unpool2dMaxFunctor<platform::CPUPlace, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle
```
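To make the scatter semantics concrete: the forward pass writes each pooled value to the flattened position recorded by max pooling and leaves every other output element untouched (zero, assuming the output buffer starts zero-filled), and the backward pass gathers gradients back from exactly those positions. A minimal, self-contained single-channel sketch in plain C++, with made-up values for illustration:

```cpp
#include <iostream>
#include <vector>

int main() {
  // 2x2 pooled values and their argmax positions, flattened into a 4x4 map.
  const int out_h = 4, out_w = 4;
  std::vector<float> input = {5.f, 8.f, 3.f, 9.f};
  std::vector<int> indices = {0, 7, 9, 14};

  // Forward unpool: scatter; everything not written stays zero.
  std::vector<float> output(out_h * out_w, 0.f);
  for (size_t i = 0; i < input.size(); ++i) output[indices[i]] = input[i];

  for (int r = 0; r < out_h; ++r) {
    for (int c = 0; c < out_w; ++c) std::cout << output[r * out_w + c] << ' ';
    std::cout << '\n';
  }
  // Prints:
  // 5 0 0 0
  // 0 0 0 8
  // 0 3 0 0
  // 0 0 9 0
  return 0;
}
```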
New file (+156 lines): the corresponding GPU (CUDA) implementation.

```cpp
/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/unpooling.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {

template <typename T>
__global__ void KernelUnpool2dMax(const int nthreads,
                                  const T* input_data,
                                  const T* indices_data,
                                  const int input_height,
                                  const int input_width,
                                  const int channels,
                                  T* output_data,
                                  const int output_height,
                                  const int output_width) {
```
Contributor: Please note the order of parameters.
```cpp
  int bsize = input_height * input_width * channels;
  int csize = input_height * input_width;
  int out_bsize = output_height * output_width * channels;
  int out_csize = output_height * output_width;

  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    int bidx = i / bsize;
    int boffset = i % bsize;
    int cidx = boffset / csize;
    int out_offset = bidx * out_bsize + cidx * out_csize;
    int out_index = indices_data[i];
    PADDLE_ASSERT(out_index < (output_height * output_width));
    output_data[out_offset + out_index] = input_data[i];
  }
}

template <typename T>
__global__ void KernelUnpool2dMaxGrad(const int nthreads,
                                      const T* input_data,
                                      const T* indices_data,
                                      const int input_height,
                                      const int input_width,
                                      const int channels,
                                      const T* output_data,
                                      const T* output_grad,
                                      const int output_height,
                                      const int output_width,
                                      T* input_grad) {
  int bsize = input_height * input_width * channels;
  int csize = input_height * input_width;
  int out_bsize = output_height * output_width * channels;
  int out_csize = output_height * output_width;

  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    int bidx = i / bsize;
    int boffset = i % bsize;
    int cidx = boffset / csize;
    int out_offset = bidx * out_bsize + cidx * out_csize;
    int out_index = indices_data[i];
    PADDLE_ASSERT(out_index < (output_height * output_width));
    input_grad[i] = output_grad[out_offset + out_index];
  }
}

/*
 * All tensors are in NCHW format.
 */
template <typename T>
class Unpool2dMaxFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices,
                  framework::Tensor* output) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];
    const T* input_data = input.data<T>();
    const T* indices_data = indices.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());
    int nthreads = batch_size * output_channels * input_height * input_width;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelUnpool2dMax<
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
                 .stream()>>>(nthreads, input_data, indices_data,
                              input_height, input_width, output_channels,
                              output_data, output_height, output_width);
  }
};

/*
 * All tensors are in NCHW format.
 */
template <typename T>
class Unpool2dMaxGradFunctor<platform::GPUPlace, T> {
```
Contributor: Same as on the CPU side; this could be reused as well.

Contributor (Author): Wouldn't that make the code structure a bit messy? I would rather not do it that way.
```cpp
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  framework::Tensor* input_grad) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output.dims()[1];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    const T* input_data = input.data<T>();
    const T* indices_data = indices.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
    int nthreads = batch_size * output_channels * input_height * input_width;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelUnpool2dMaxGrad<
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
                 .stream()>>>(nthreads, input_data, indices_data,
                              input_height, input_width, output_channels,
                              output_data, output_grad_data,
                              output_height, output_width, input_grad_data);
  }
};

template class Unpool2dMaxGradFunctor<platform::GPUPlace, float>;
template class Unpool2dMaxGradFunctor<platform::GPUPlace, double>;

template class Unpool2dMaxFunctor<platform::GPUPlace, float>;
template class Unpool2dMaxFunctor<platform::GPUPlace, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle
```
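Both kernels above use the grid-stride loop idiom: each thread starts at its global index and advances by the total number of launched threads, so a launch of `(nthreads + 1024 - 1) / 1024` blocks of 1024 threads covers every element even when `nthreads` is not a multiple of the block size. A minimal standalone sketch of the idiom (the `Scale` kernel is hypothetical, not part of this PR):

```cpp
#include <cuda_runtime.h>

// Grid-stride loop: correct for any grid size relative to n, because each
// thread processes elements index, index + stride, index + 2*stride, ...
__global__ void Scale(const float* in, float* out, float alpha, int n) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;  // total number of threads launched
  for (int i = index; i < n; i += stride) {
    out[i] = alpha * in[i];
  }
}

// Launch with ceil-division, mirroring the functors above:
//   int blocks = (n + 1024 - 1) / 1024;
//   Scale<<<blocks, 1024>>>(d_in, d_out, 2.0f, n);
```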
New file (+44 lines): evidently paddle/operators/math/unpooling.h, the header included by both implementation files above, declaring the functor templates they specialize.

```cpp
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/tensor.h"

namespace paddle {
namespace operators {
namespace math {

template <typename Place, typename T>
class Unpool2dMaxFunctor {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices,
                  framework::Tensor* output);
};

template <typename Place, class T>
class Unpool2dMaxGradFunctor {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  framework::Tensor* input_grad);
};

}  // namespace math
}  // namespace operators
}  // namespace paddle
```
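These functors are the device-agnostic building blocks that the op's compute kernels would call into. A rough sketch of that wiring under the fluid-era API; the kernel class name and the input/output names "X", "Indices", "Out" are assumptions, not taken from this diff:

```cpp
// Hypothetical kernel that delegates to the functor; names are placeholders.
template <typename Place, typename T>
class UnpoolKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const framework::Tensor* x = ctx.Input<framework::Tensor>("X");
    const framework::Tensor* indices = ctx.Input<framework::Tensor>("Indices");
    framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
    out->mutable_data<T>(ctx.GetPlace());
    math::Unpool2dMaxFunctor<Place, T> unpool2d_max_forward;
    unpool2d_max_forward(ctx.device_context(), *x, *indices, out);
  }
};
```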
Contributor: You specify the indices type as int16 in https://github.com/sweetsky0901/Paddle/blob/27cf7f3376e2005522dbba8c964cb2e449b46380/python/paddle/v2/fluid/tests/test_unpool_op.py#L56, so using `const T* indices_data = indices.data<T>();` here must be wrong. Also, Paddle does not currently support int16, so you must specify int32 in the test and use `const int* indices_data = indices.data<int>();`, or add a new type parameter to the template, like here.
Contributor (Author): Done.

By default the framework checks that the input types are all the same; if they differ, you need to override:

```cpp
 protected:
  framework::OpKernelType GetKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
        ctx.device_context());
  }
```

Previously, to get around this check, I passed the indices as float and then cast them back where the actual computation happens.
Contributor: That is not quite right. When an op is registered, two kernels are usually registered (double and float). At run time, which one should be used? GetKernelType is called to decide: if X is float, the float kernel is invoked; if X is double, the double kernel.
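A rough sketch of the registration being described, assuming the fluid-era macro; the op name `unpool` and the kernel class `UnpoolKernel` are placeholders:

```cpp
// Two CPU kernels are registered for the same op; at run time,
// GetKernelType selects one based on the data type of input "X".
REGISTER_OP_CPU_KERNEL(
    unpool, ops::UnpoolKernel<paddle::platform::CPUPlace, float>,
    ops::UnpoolKernel<paddle::platform::CPUPlace, double>);
```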
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
请问这些有文档吗?