Skip to content

Conversation

@kexinzhao
Copy link
Contributor

@kexinzhao kexinzhao commented Mar 19, 2018

fix #9222

Added device function for multiplying two float16 numbers on GPU device, which is needed in the following code:

Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);

@kexinzhao kexinzhao requested a review from chengduoZH March 19, 2018 23:54
@kexinzhao kexinzhao added the 预测 原名Inference,包含Capi预测问题等 label Mar 20, 2018

def test_check_output(self):
if core.is_compiled_with_cuda() and core.op_support_gpu("dropout"):
self.check_output_with_place(core.CUDAPlace(0), atol=1e-3)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TestFP16DropoutOp1 and TestFP16DropoutOp2 are very similar, and they can be inherited relationships, which can reduce the amount of code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right! Done.

Copy link
Contributor

@chengduoZH chengduoZH left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

// Arithmetic operators for float16, software emulated on other CPU
#else
HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
HOST inline float16 operator+(const float16& a, const float16& b) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe HOST is unnecessary.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess so. Let me fix that in the next PR.

@kexinzhao kexinzhao merged commit 5271c32 into PaddlePaddle:develop Mar 20, 2018
@kexinzhao kexinzhao deleted the dropout_fp16 branch March 20, 2018 05:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

预测 原名Inference,包含Capi预测问题等

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Need float16 support in dropout operator

2 participants