Merged

Changes from 23 commits (38 commits total)
fef617a  for resolve conflicts (Nov 11, 2017)
9127840  Merge branch 'sweetsky0901-my_maxout_op' into develop (Nov 11, 2017)
4748073  paddle/operators/math/CMakeLists.txt maybe del sequence_pooling and a… (Nov 11, 2017)
6665c49  Merge branch 'sweetsky0901-my_maxout_op2' into develop (Nov 11, 2017)
bc45335  add unpool (Nov 21, 2017)
f638f91  merge cmakelist (Nov 21, 2017)
45a8c9d  add unpool2d make ok (Nov 21, 2017)
ab03daa  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Nov 21, 2017)
200f07c  add test (Nov 21, 2017)
822f283  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Nov 21, 2017)
90f664d  test unpool ok cpu (Nov 22, 2017)
abb3357  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Nov 22, 2017)
e2a5905  gpu test ok unpool2dmax (Nov 22, 2017)
8ba8237  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Nov 22, 2017)
47bd0bb  del printf (Nov 22, 2017)
0112c5d  format code (Nov 22, 2017)
e553d57  format test code (Nov 22, 2017)
66b8436  modify for code review by wangyi (Nov 23, 2017)
ee4a5d2  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Nov 23, 2017)
c218961  modify for code review by qingqing (Nov 26, 2017)
a38bbc8  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Nov 26, 2017)
cfd7721  add unpool_op.h modify (Nov 27, 2017)
27cf7f3  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Nov 27, 2017)
20654cf  modify for type check rewrite (Nov 27, 2017)
022b48e  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Nov 27, 2017)
f9c2a5c  modify for code review zcd (Nov 27, 2017)
57e68e5  modify for code review by qingqing 2nd (Nov 28, 2017)
ee0a794  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Nov 28, 2017)
6fc9a9f  modify for del T2 and doc update (Nov 28, 2017)
821899c  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Nov 28, 2017)
d9673ca  format code (Nov 28, 2017)
bd56138  format code (Nov 29, 2017)
c52ed8d  format code (Nov 29, 2017)
2d42fa7  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Nov 29, 2017)
d2ee3c9  format code (Nov 29, 2017)
3206094  format code (Nov 29, 2017)
5b449b6  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Nov 29, 2017)
4ffb73f  format .. (Nov 29, 2017)
2 changes: 2 additions & 0 deletions paddle/operators/CMakeLists.txt
@@ -184,6 +184,7 @@ set(DEPS_OPS
     sum_op
     pool_op
     maxout_op
+    unpool_op
     pool_with_index_op
     conv_op
     conv_transpose_op
@@ -211,6 +212,7 @@ op_library(adagrad_op DEPS selected_rows_functor)
 op_library(conv_op DEPS vol2col)
 op_library(pool_op DEPS pooling)
 op_library(maxout_op DEPS maxouting)
+op_library(unpool_op DEPS unpooling)
 op_library(pool_with_index_op DEPS pooling)
 op_library(lod_rank_table_op SRCS lod_rank_table_op.cc DEPS lod_rank_table)
 op_library(lod_tensor_to_array_op SRCS lod_tensor_to_array_op.cc DEPS lod_rank_table_op)
6 changes: 4 additions & 2 deletions paddle/operators/math/CMakeLists.txt
@@ -13,8 +13,9 @@ if(WITH_GPU)
     nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context math_function)
     nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context)
     nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions)
-    nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function)
     nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context)
+    nv_library(unpooling SRCS unpooling.cc unpooling.cu DEPS device_context)
+    nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function)
 else()
     cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context framework_proto)
     cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function)
@@ -26,8 +27,9 @@ else()
     cc_library(context_project SRCS context_project.cc DEPS device_context math_function)
     cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context)
     cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions)
-    cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function)
     cc_library(maxouting SRCS maxouting.cc DEPS device_context)
+    cc_library(unpooling SRCS unpooling.cc DEPS device_context)
+    cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function)
 endif()

 cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
100 changes: 100 additions & 0 deletions paddle/operators/math/unpooling.cc
@@ -0,0 +1,100 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/unpooling.h"

namespace paddle {
namespace operators {
namespace math {

// All tensors are in NCHW format
template <typename T>
class Unpool2dMaxFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices,
framework::Tensor* output) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output->dims()[1];
const int output_height = output->dims()[2];
const int output_width = output->dims()[3];
int input_feasize = input_height * input_width;
int output_feasize = output_height * output_width;
const T* input_data = input.data<T>();
const T* indices_data = indices.data<T>();
[Review] NHZlX (Contributor, Nov 27, 2017):
You specify the indices type as int16 in https://github.com/sweetsky0901/Paddle/blob/27cf7f3376e2005522dbba8c964cb2e449b46380/python/paddle/v2/fluid/tests/test_unpool_op.py#L56, so using const T* indices_data = indices.data<T>(); here must be wrong. Also, Paddle does not currently support int16, so you must specify int32 in the test and use const int* indices_data = indices.data<int>();, or add a new type parameter to the template, like here.

[Author] sweetsky0901:
done
By default the framework checks that the input types are identical; if they are not, GetKernelType has to be overridden:

 protected:
  framework::OpKernelType GetKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
        ctx.device_context());
  }

Previously, to get around this check, the indices were passed as float and then cast back at computation time.

[Review] NHZlX (Contributor):
That is not quite right. When an op is registered, two kernels are usually registered (double and float); at run time GetKernelType decides which one is actually called: if X is float, the float kernel runs; if X is double, the double kernel runs.

[Author] sweetsky0901:
Is any of this documented?
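
(Editor's note: for context, a minimal sketch of the registration pattern being described, based on how sibling ops were registered in fluid at the time; the UnpoolKernel name is an assumption, not code from this diff.)

// Sketch: register one CPU kernel per element type. At run time the
// framework calls GetKernelType, inspects the type of input "X", and
// dispatches to the matching instantiation below.
REGISTER_OP_CPU_KERNEL(unpool,
                       ops::UnpoolKernel<paddle::platform::CPUPlace, float>,
                       ops::UnpoolKernel<paddle::platform::CPUPlace, double>);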

T* output_data = output->mutable_data<T>(context.GetPlace());
for (int b = 0; b < batch_size; ++b) {
for (int c = 0; c < output_channels; ++c) {
for (int i = 0; i < input_feasize; ++i) {
int index = indices_data[i];
PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!");
[Review] Contributor:
This check could be moved outside the loop: convert indices to an EigenMatrix or EigenVector, call maximum() to get the largest element of indices, and then test that single value for out-of-bounds. The CUDA implementation could do the same, which would avoid the bounds check inside the CUDA kernel.

[Author] sweetsky0901:
Calling maximum() actually adds many extra comparisons; I am not in favor of writing it that way.
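
(Editor's note: a minimal sketch of the reviewer's suggestion, assuming int32 indices; Eigen's maxCoeff() stands in for the maximum() call mentioned above. Illustrative only, not code from this PR.)

#include <Eigen/Core>

// Hoist the bounds check out of the element loop: find the largest
// index once per feature map and validate only that value.
bool indices_in_range(const int* indices_data, int input_feasize,
                      int output_feasize) {
  const int max_index =
      Eigen::Map<const Eigen::ArrayXi>(indices_data, input_feasize)
          .maxCoeff();
  return max_index < output_feasize;
}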

output_data[index] = input_data[i];
[Review] Contributor:
Why is this not output_data[index] += input_data[i];?

Also, Unpool2dMaxFunctor does not need to be written at all; MaxPool2dWithIndexGradFunctor can be reused directly.

[Author] sweetsky0901:
Wouldn't that make the code structure somewhat messy? I would rather not do it that way.

}
input_data += input_feasize;
indices_data += input_feasize;
output_data += output_feasize;
}
}
}
};



template <typename T>
class Unpool2dMaxGradFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices,
const framework::Tensor& output,
const framework::Tensor& output_grad,
framework::Tensor* input_grad) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output.dims()[1];
const int output_height = output.dims()[2];
const int output_width = output.dims()[3];
int input_feasize = input_height * input_width;
int output_feasize = output_height * output_width;
const T* indices_data = indices.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

for (int b = 0; b < batch_size; ++b) {
for (int c = 0; c < output_channels; ++c) {
for (int i = 0; i < input_feasize; ++i) {
int index = indices_data[i];
PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!");
input_grad_data[i] = output_grad_data[index];
}
input_grad_data += input_feasize;
indices_data += input_feasize;
output_grad_data += output_feasize;
}
}
}
};

template class Unpool2dMaxGradFunctor<platform::CPUPlace, float>;
template class Unpool2dMaxGradFunctor<platform::CPUPlace, double>;
template class Unpool2dMaxFunctor<platform::CPUPlace, float>;
template class Unpool2dMaxFunctor<platform::CPUPlace, double>;

} // namespace math
} // namespace operators
} // namespace paddle
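
(Editor's note: to make the forward semantics concrete, a tiny worked illustration of what Unpool2dMaxFunctor computes; values are made up, and it assumes the output tensor starts zero-filled, which the calling op kernel is expected to ensure.)

// Max-unpooling on one 2x2 feature map into a 4x4 output: each input
// value is scattered to the flattened position its pooling index
// recorded; every other output element stays zero.
//
// input (2x2):   indices (into the 4x4 map):   output (4x4):
//   6 8            5  7                          0 0 0 0
//   3 4            13 15                         0 6 0 8
//                                                0 0 0 0
//                                                0 3 0 4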
156 changes: 156 additions & 0 deletions paddle/operators/math/unpooling.cu
@@ -0,0 +1,156 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/unpooling.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {

template <typename T>
__global__ void KernelUnpool2dMax(const int nthreads,
const T* input_data,
const T* indices_data,
const int input_height,
const int input_width,
const int channels,
T* output_data,
const int output_height,
const int output_width) {
[Review] Contributor:
Please note the order of parameters.

int bsize = input_height * input_width * channels;
int csize = input_height * input_width;
int out_bsize = output_height * output_width * channels;
int out_csize = output_height * output_width;
[Review] Contributor:
The naming feels a bit arbitrary. The data is NCHW, so something like

 int in_c_stride = input_height * input_width;
 int in_n_stride = in_c_stride * channels;
 int out_c_stride = output_height * output_width;
 int out_n_stride = out_c_stride * channels;

would work here (other good names are also fine), but bsize, csize, and so on feel arbitrary.

[Author] sweetsky0901:
done

int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
int bidx = i / bsize;
int boffset = i % bsize;
int cidx = boffset / csize;
int out_offset = bidx * out_bsize + cidx * out_csize;
int out_index = indices_data[i];
PADDLE_ASSERT(out_index < (output_height * output_width));
[Review] Contributor:
PADDLE_ASSERT(out_index < out_c_stride);

[Author] sweetsky0901:
done

output_data[out_offset + out_index] = input_data[i];
}
}
template <typename T>
__global__ void KernelUnpool2dMaxGrad(const int nthreads,
const T* input_data,
const T* indices_data,
const int input_height,
const int input_width,
const int channels,
const T* output_data,
const T* output_grad,
const int output_height,
const int output_width,
T* input_grad) {
int bsize = input_height * input_width * channels;
int csize = input_height * input_width;
int out_bsize = output_height * output_width * channels;
int out_csize = output_height * output_width;
[Review] Contributor:
Same as above.

[Author] sweetsky0901:
done

int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
int bidx = i / bsize;
int boffset = i % bsize;
int cidx = boffset / csize;
int out_offset = bidx * out_bsize + cidx * out_csize;
int out_index = indices_data[i];
PADDLE_ASSERT(out_index < (output_height * output_width));
[Review] Contributor:
Same as above.

[Author] sweetsky0901:
done

input_grad[i] = output_grad[out_offset + out_index];
}
}
/*
* All tensors are in NCHW format.
*/
template <typename T>
class Unpool2dMaxFunctor<platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices,
framework::Tensor* output) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output->dims()[1];
const int output_height = output->dims()[2];
const int output_width = output->dims()[3];
const T* input_data = input.data<T>();
const T* indices_data = indices.data<T>();
T* output_data = output->mutable_data<T>(context.GetPlace());
int nthreads = batch_size * output_channels * input_height * input_width;
int blocks = (nthreads + 1024 - 1) / 1024;
[Review] Contributor:
int threads = 1024;
int grids = (input.numel() + threads - 1) / threads;

Reuse values like this across the code; otherwise, changing 1024 later would mean edits in several places.

[Author] sweetsky0901:
done

dim3 threads(1024, 1);
dim3 grid(blocks, 1);

KernelUnpool2dMax<
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(nthreads, input_data, indices_data,
input_height, input_width, output_channels,
output_data, output_height, output_width);
}
};
/*
* All tensors are in NCHW format.
*/
template <typename T>
class Unpool2dMaxGradFunctor<platform::GPUPlace, T> {
[Review] Contributor:
Same as on the CPU side: MaxPool2dWithIndexGradFunctor can be reused here too, and Unpool2dMaxFunctor can be dropped.

[Author] sweetsky0901:
Wouldn't that make the code structure somewhat messy? I would rather not do it that way.

public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices,
const framework::Tensor& output,
const framework::Tensor& output_grad,
framework::Tensor* input_grad) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output.dims()[1];
const int output_height = output.dims()[2];
const int output_width = output.dims()[3];
const T* input_data = input.data<T>();
const T* indices_data = indices.data<T>();
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
int nthreads = batch_size * output_channels * input_height * input_width;
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
[Review] Contributor:
Same as above.

[Author] sweetsky0901:
done

KernelUnpool2dMaxGrad<
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
nthreads, input_data, indices_data,
input_height, input_width, output_channels,
output_data, output_grad_data,
output_height, output_width,
input_grad_data);
}
};

template class Unpool2dMaxGradFunctor<platform::GPUPlace, float>;
template class Unpool2dMaxGradFunctor<platform::GPUPlace, double>;

template class Unpool2dMaxFunctor<platform::GPUPlace, float>;
template class Unpool2dMaxFunctor<platform::GPUPlace, double>;

} // namespace math
} // namespace operators
} // namespace paddle
44 changes: 44 additions & 0 deletions paddle/operators/math/unpooling.h
@@ -0,0 +1,44 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/tensor.h"

namespace paddle {
namespace operators {
namespace math {

template <typename Place, typename T>
class Unpool2dMaxFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices,
framework::Tensor* output);
};

template <typename Place, typename T>
class Unpool2dMaxGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices,
const framework::Tensor& output,
const framework::Tensor& output_grad,
framework::Tensor* input_grad);
};
} // namespace math
} // namespace operators
} // namespace paddle
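
(Editor's note: the unpool_op.h kernel that drives these functors is not part of this section's diff. Below is a minimal sketch of how an op kernel would call the forward functor; the input/output names "X", "Indices", "Out" and the UnpoolKernel name are assumptions, not code from this PR.)

// Sketch: an OpKernel that zero-fills the output, then scatters the
// input through the recorded indices via Unpool2dMaxFunctor.
template <typename Place, typename T>
class UnpoolKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const framework::Tensor* in_x = ctx.Input<framework::Tensor>("X");
    const framework::Tensor* in_y = ctx.Input<framework::Tensor>("Indices");
    framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
    out->mutable_data<T>(ctx.GetPlace());
    math::SetConstant<Place, T> set_zero;  // assumed helper from math_function.h
    set_zero(ctx.device_context(), out, static_cast<T>(0));
    math::Unpool2dMaxFunctor<Place, T> unpool2d_max_forward;
    unpool2d_max_forward(ctx.device_context(), *in_x, *in_y, out);
  }
};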