Merged
9 changes: 3 additions & 6 deletions paddle/fluid/operators/dropout_op.cc
@@ -35,7 +35,6 @@ class DropoutOp : public framework::OperatorWithKernel {
}
};

-template <typename AttrType>
class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
public:
DropoutOpMaker(OpProto* proto, OpAttrChecker* op_checker)
@@ -73,7 +72,6 @@ are set equal to their corresponding inputs.
}
};

-template <typename AttrType>
class DropoutOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -103,11 +101,10 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
} // namespace paddle

namespace ops = paddle::operators;
-REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker<float>, dropout_grad,
-            ops::DropoutOpGrad<float>);
+REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker, dropout_grad,
+            ops::DropoutOpGrad);
REGISTER_OP_CPU_KERNEL(
-    dropout,
-    ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, float, float>);
+    dropout, ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
dropout_grad,
ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, float>);
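
The simplification in this file rests on one fact: dropout_prob is a float attribute by definition (the GPU kernel below reads it with context.Attr<float>("dropout_prob")), so its C++ type never varies with the tensor element type T, and the AttrType template parameter carried no information. A minimal standalone sketch of the resulting pattern, using a hypothetical DropoutInference helper that is not part of the Paddle API:

#include <iostream>
#include <vector>

// The probability is always a float; only the element type T varies.
// Cast the scale once at the boundary, then work purely in T.
template <typename T>
std::vector<T> DropoutInference(const std::vector<T>& x, float dropout_prob) {
  const T scale = static_cast<T>(1.0f - dropout_prob);
  std::vector<T> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    out[i] = x[i] * scale;  // same math as Y = X * (1 - dropout_prob)
  }
  return out;
}

int main() {
  for (float v : DropoutInference<float>({1.0f, 2.0f, 3.0f}, 0.35f)) {
    std::cout << v << " ";  // prints: 0.65 1.3 1.95
  }
}

The same one-time cast shows up in the GPU kernel's inference branch below, X * static_cast<T>(1.0f - dropout_prob), which is what lets the float16 registration reuse the kernel unchanged.
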
27 changes: 14 additions & 13 deletions paddle/fluid/operators/dropout_op.cu
@@ -18,17 +18,18 @@ limitations under the License. */
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

-template <typename T, typename AttrType>
+template <typename T>
__global__ void RandomGenerator(const size_t n, const int seed,
-                                const AttrType dropout_prob, const T* src,
+                                const float dropout_prob, const T* src,
T* mask_data, T* dst) {
thrust::minstd_rand rng;
rng.seed(seed);
-  thrust::uniform_real_distribution<AttrType> dist(0, 1);
+  thrust::uniform_real_distribution<float> dist(0, 1);

int idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < n; idx += blockDim.x * gridDim.x) {
@@ -44,14 +45,14 @@ __global__ void RandomGenerator(const size_t n, const int seed,
// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
-template <typename Place, typename T, typename AttrType>
+template <typename Place, typename T>
class GPUDropoutKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* y = context.Output<Tensor>("Out");
y->mutable_data<T>(context.GetPlace());
-    AttrType dropout_prob = context.Attr<AttrType>("dropout_prob");
+    float dropout_prob = context.Attr<float>("dropout_prob");

auto X = EigenMatrix<T>::Reshape(*x, 1);
auto Y = EigenMatrix<T>::Reshape(*y, 1);
@@ -70,11 +71,11 @@ class GPUDropoutKernel : public framework::OpKernel<T> {

int threads = 512;
int grid = (x->numel() + threads - 1) / threads;
-      RandomGenerator<T, AttrType><<<grid, threads, 0,
-                                     context.cuda_device_context().stream()>>>(
+      RandomGenerator<
+          T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
size, seed, dropout_prob, x_data, mask_data, y_data);
} else {
-      Y.device(place) = X * (1.0f - dropout_prob);
+      Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
}
}
};
@@ -83,9 +84,9 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
} // namespace paddle

namespace ops = paddle::operators;
+namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
-    dropout,
-    ops::GPUDropoutKernel<paddle::platform::CUDADeviceContext, float, float>);
-REGISTER_OP_CUDA_KERNEL(
-    dropout_grad,
-    ops::DropoutGradKernel<paddle::platform::CUDADeviceContext, float>);
+    dropout, ops::GPUDropoutKernel<plat::CUDADeviceContext, float>,
+    ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>);
+REGISTER_OP_CUDA_KERNEL(dropout_grad,
+                        ops::DropoutGradKernel<plat::CUDADeviceContext, float>);
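
The comment in this file explains the unusual RNG choice: Eigen::Tensor::setRandom can segfault on GPU, so the kernel draws uniforms with thrust's device-side engines instead. Below is a standalone sketch of that technique with a hypothetical DropoutMaskKernel. The diff truncates RandomGenerator's loop body, so the discard-based skip-ahead here is an assumption about one common way to decorrelate threads that share a seed, not a claim about the exact Paddle code:

#include <thrust/random.h>

// Each thread builds the same cheap engine, seeds it identically, and
// skips ahead by its element index so every element gets an independent
// draw from one logical random stream.
__global__ void DropoutMaskKernel(size_t n, int seed, float dropout_prob,
                                  const float* src, float* mask, float* dst) {
  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  for (; idx < n; idx += blockDim.x * gridDim.x) {
    thrust::minstd_rand rng;
    rng.seed(seed);
    rng.discard(idx);  // assumed skip-ahead: decorrelate threads sharing a seed
    thrust::uniform_real_distribution<float> dist(0, 1);
    float m = dist(rng) < dropout_prob ? 0.0f : 1.0f;
    mask[idx] = m;
    dst[idx] = src[idx] * m;
  }
}

// Launched the same way as the kernel above, e.g.:
//   int threads = 512;
//   int grid = (n + threads - 1) / threads;
//   DropoutMaskKernel<<<grid, threads>>>(n, seed, 0.5f, d_src, d_mask, d_dst);
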
2 changes: 1 addition & 1 deletion paddle/fluid/operators/dropout_op.h
@@ -25,7 +25,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

-template <typename DeviceContext, typename T, typename AttrType>
+template <typename DeviceContext, typename T>
class CPUDropoutKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
151 changes: 133 additions & 18 deletions paddle/fluid/platform/float16.h
@@ -483,8 +483,123 @@ DEVICE inline bool operator>=(const half& a, const half& b) {

#endif // PADDLE_CUDA_FP16

-// Arithmetic operators on ARMv8.2-A CPU
-#if defined(PADDLE_WITH_NATIVE_FP16)
+// Arithmetic operators for float16 on GPU
+#if defined(PADDLE_CUDA_FP16)
+HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
+  return float16(__hadd(half(a), half(b)));
+#else
+  return float16(float(a) + float(b));
+#endif
+}
+
+HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
+  return float16(__hsub(half(a), half(b)));
+#else
+  return float16(float(a) - float(b));
+#endif
+}
+
+HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
+  return float16(__hmul(half(a), half(b)));
+#else
+  return float16(float(a) * float(b));
+#endif
+}
+
+HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
+  // TODO(kexinzhao): check which cuda version starts to support __hdiv
+  float num = __half2float(half(a));
+  float denom = __half2float(half(b));
+  return float16(num / denom);
+#else
+  return float16(float(a) / float(b));
+#endif
+}
+
+HOSTDEVICE inline float16 operator-(const float16& a) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
+  return float16(__hneg(half(a)));
+#else
+  float16 res;
+  res.x = a.x ^ 0x8000;
+  return res;
+#endif
+}
+
+HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) {
+  a = a + b;
+  return a;
+}
+
+HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) {
+  a = a - b;
+  return a;
+}
+
+HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) {
+  a = a * b;
+  return a;
+}
+
+HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) {
+  a = a / b;
+  return a;
+}
+
+HOSTDEVICE inline bool operator==(const float16& a, const float16& b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
+  return __heq(half(a), half(b));
+#else
+  return float(a) == float(b);
+#endif
+}
+
+HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
+  return __hne(half(a), half(b));
+#else
+  return float(a) != float(b);
+#endif
+}
+
+HOSTDEVICE inline bool operator<(const float16& a, const float16& b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
+  return __hlt(half(a), half(b));
+#else
+  return float(a) < float(b);
+#endif
+}
+
+HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
+  return __hle(half(a), half(b));
+#else
+  return float(a) <= float(b);
+#endif
+}
+
+HOSTDEVICE inline bool operator>(const float16& a, const float16& b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
+  return __hgt(half(a), half(b));
+#else
+  return float(a) > float(b);
+#endif
+}
+
+HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
+  return __hge(half(a), half(b));
+#else
+  return float(a) >= float(b);
+#endif
+}
+
+// Arithmetic operators for float16 on ARMv8.2-A CPU
+#elif defined(PADDLE_WITH_NATIVE_FP16)
HOST inline float16 operator+(const float16& a, const float16& b) {
float16 res;
asm volatile(
@@ -668,71 +783,71 @@ HOST inline bool operator>=(const float16& a, const float16& b) {
return (res & 0xffff) != 0;
}

-// Arithmetic operators, software emulated on other CPU
+// Arithmetic operators for float16, software emulated on other CPU
#else
-HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
+HOST inline float16 operator+(const float16& a, const float16& b) {
Review comment (Contributor): Maybe HOST is unnecessary.
Reply (Author): I guess so. Let me fix that in the next PR.
return float16(float(a) + float(b));
}

-HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) {
+HOST inline float16 operator-(const float16& a, const float16& b) {
return float16(float(a) - float(b));
}

-HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) {
+HOST inline float16 operator*(const float16& a, const float16& b) {
return float16(float(a) * float(b));
}

-HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) {
+HOST inline float16 operator/(const float16& a, const float16& b) {
return float16(float(a) / float(b));
}

-HOSTDEVICE inline float16 operator-(const float16& a) {
+HOST inline float16 operator-(const float16& a) {
float16 res;
res.x = a.x ^ 0x8000;
return res;
}

-HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) {
+HOST inline float16& operator+=(float16& a, const float16& b) {
a = float16(float(a) + float(b));
return a;
}

-HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) {
+HOST inline float16& operator-=(float16& a, const float16& b) {
a = float16(float(a) - float(b));
return a;
}

-HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) {
+HOST inline float16& operator*=(float16& a, const float16& b) {
a = float16(float(a) * float(b));
return a;
}

-HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) {
+HOST inline float16& operator/=(float16& a, const float16& b) {
a = float16(float(a) / float(b));
return a;
}

-HOSTDEVICE inline bool operator==(const float16& a, const float16& b) {
+HOST inline bool operator==(const float16& a, const float16& b) {
return float(a) == float(b);
}

-HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) {
+HOST inline bool operator!=(const float16& a, const float16& b) {
return float(a) != float(b);
}

-HOSTDEVICE inline bool operator<(const float16& a, const float16& b) {
+HOST inline bool operator<(const float16& a, const float16& b) {
return float(a) < float(b);
}

-HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) {
+HOST inline bool operator<=(const float16& a, const float16& b) {
return float(a) <= float(b);
}

-HOSTDEVICE inline bool operator>(const float16& a, const float16& b) {
+HOST inline bool operator>(const float16& a, const float16& b) {
return float(a) > float(b);
}

-HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
+HOST inline bool operator>=(const float16& a, const float16& b) {
return float(a) >= float(b);
}
#endif
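
Two details in the new float16 operators deserve a note. The __CUDA_ARCH__ >= 530 guards exist because the half-precision intrinsics (__hadd, __hsub, __hmul, __heq, and friends) require compute capability 5.3 or newer; on older GPUs, and on the host side of a HOSTDEVICE function, the operators fall back to float math. And unary negation in the fallback paths never decodes the number at all: IEEE 754 binary16 keeps its sign in bit 15, so XORing the raw bits with 0x8000 flips the sign directly. A standalone sketch of that bit trick, with a hypothetical NegateHalfBits helper for illustration:

#include <cstdint>
#include <cstdio>

// binary16 layout: 1 sign bit (bit 15), 5 exponent bits, 10 mantissa bits.
// XOR with 0x8000 toggles only the sign bit, which is exactly negation.
uint16_t NegateHalfBits(uint16_t bits) { return bits ^ 0x8000; }

int main() {
  uint16_t one = 0x3C00;  // 1.0 in binary16
  uint16_t minus_one = NegateHalfBits(one);
  std::printf("0x%04X -> 0x%04X\n", one, minus_one);  // 0x3C00 -> 0xBC00
}
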
33 changes: 33 additions & 0 deletions python/paddle/fluid/tests/unittests/test_dropout_op.py
@@ -14,6 +14,7 @@

import unittest
import numpy as np
+import paddle.fluid.core as core
from op_test import OpTest


@@ -82,5 +83,37 @@ def test_check_output(self):
self.check_output()


+class TestFP16DropoutOp(OpTest):
+    def setUp(self):
+        self.op_type = "dropout"
+        self.init_test_case()
+
+        x = np.random.random(self.input_size).astype("float16")
+        out = x * (1.0 - self.prob)
+        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
+        self.attrs = {
+            'dropout_prob': self.prob,
+            'fix_seed': self.fix_seed,
+            'is_test': True
+        }
+        self.outputs = {'Out': out}
+
+    def init_test_case(self):
+        self.input_size = [32, 64]
+        self.prob = 0.35
+        self.fix_seed = True
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda() and core.op_support_gpu("dropout"):
+            self.check_output_with_place(core.CUDAPlace(0), atol=1e-3)
+
+
+class TestFP16DropoutOp2(TestFP16DropoutOp):
+    def init_test_case(self):
+        self.input_size = [32, 64, 3]
+        self.prob = 0.75
+        self.fix_seed = False
+
+
if __name__ == '__main__':
unittest.main()