[pnorm] fix bug in fp16 & optimize memory #39011

Merged
5 commits merged into PaddlePaddle:develop on Jan 25, 2022

Conversation

@LemonNoel (Contributor) commented Jan 18, 2022

PR types

Bug fixes

PR changes

OPs

Describe

  • Fix a bug in the pnorm op for float16 (the computation involved is sketched below).
  • Optimize memory usage of the pnorm op.
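
For context, a minimal standalone sketch (not Paddle's kernel) of what the pnorm op computes; accumulating |x|^p directly in half precision loses accuracy, which is why the fixed kernel does the arithmetic in float:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Minimal sketch, not Paddle's code: the p-norm is sum(|x_i|^p)^(1/p).
    // Doing the pow/sum in float is the same idea the PR applies to fp16.
    float PNorm(const std::vector<float>& x, float porder) {
      float sum = 0.f;
      for (float v : x) sum += std::pow(std::fabs(v), porder);
      return std::pow(sum, 1.f / porder);
    }

    int main() {
      std::vector<float> x = {1.f, -2.f, 3.f};
      std::printf("%f\n", PNorm(x, 2.f));  // sqrt(1 + 4 + 9) ~ 3.741657
    }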

@paddle-bot-old

Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@ZHUI requested review from ZHUI and wawltor on January 20, 2022
template <typename Tx, typename Ty = Tx>
struct UnsignedPowFunctor {
  HOSTDEVICE explicit inline UnsignedPowFunctor(float porder) {
    this->porder = porder;
  }
  HOSTDEVICE inline Ty operator()(const Tx x) const {
    return static_cast<Ty>(inline_pow(inline_abs(x), static_cast<Tx>(porder)));
  }
  float porder;
};
Collaborator

Were these two types originally there to support the fp16 case? Why are they no longer used?

Contributor Author

The previous implementation was buggy: fp16 values are converted to float for computation inside the inline function, so there is no need to convert them externally.
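
To illustrate that point, a hypothetical functor (AbsPowInFloat is a made-up name, not Paddle's code): the value is promoted to float inside the call operator and cast back afterwards, so a single template type suffices:

    #include <cmath>
    #include <cstdio>

    // Hypothetical functor for illustration: the fp16-like input is
    // promoted to float inside the inline function, so no external
    // <Tx, Ty> conversion pair is needed.
    struct AbsPowInFloat {
      explicit AbsPowInFloat(float porder) : porder(porder) {}
      template <typename T>
      T operator()(T x) const {
        float xf = static_cast<float>(x);  // promote (e.g. fp16 -> float)
        return static_cast<T>(std::pow(std::fabs(xf), porder));  // cast back
      }
      float porder;
    };

    int main() { std::printf("%f\n", AbsPowInFloat(2.f)(-3.0f)); }  // ~9.0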

TensorReduceFunctorImpl<T, T, kps::AddFunctor, UnsignedPowFunctor<T>>(
    *in_x, out_norm, UnsignedPowFunctor<T>(porder), reduce_axis, stream);

const framework::Tensor* tmp_norm = out_norm;
Collaborator

tmp_norm isn't needed here, is it? You can use out_norm directly.

Contributor Author

The type of ins here is std::vector<const framework::Tensor*>, so the pointer needs to be converted to const.
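
A minimal sketch of the const-pointer detail under discussion (Tensor here is a stand-in struct, not framework::Tensor):

    #include <vector>

    struct Tensor {};  // stand-in for framework::Tensor

    int main() {
      Tensor out;
      Tensor* out_norm = &out;
      // The kernel inputs are collected as pointers-to-const, hence the
      // const alias before building the input vector.
      const Tensor* tmp_norm = out_norm;
      std::vector<const Tensor*> ins = {tmp_norm};
      return static_cast<int>(ins.size()) - 1;  // exits 0
    }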

dx->device(place) = dy->broadcast(dim) * equals.select(ones, zeros) *
                    positives.select(ones, negs);
dx->device(place) = dy->broadcast(dim) * (*x).sign() *
                    ((*x).abs() == y->broadcast(dim)).select(ones, zeros);
Collaborator

https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/p_norm_op.h#L126-L128

      dx.device(*place) =
          (x.abs() == norm.broadcast(bcast)).template cast<T>() * x.sign() *
          norm_dy.broadcast(bcast);

You could also use cast here.

Contributor Author

Do you mean that, written this way, the cast is enough to pick out the non-zero values?
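
For reference, a small Eigen sketch (assuming Eigen is available; this is not Paddle code) showing that casting a boolean mask to a numeric type yields the same 1/0 values that select(ones, zeros) produced:

    #include <Eigen/Dense>
    #include <iostream>

    int main() {
      Eigen::ArrayXf x(3);
      x << 1.f, -2.f, 2.f;
      // (|x| == 2) is a boolean array; cast<float>() maps true/false to 1/0.
      Eigen::ArrayXf mask = (x.abs() == 2.f).cast<float>();
      std::cout << mask.transpose() << "\n";  // prints: 0 1 1
    }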

template <typename DeviceContext, typename X, typename Y, typename DX,
          typename DY, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
                const Dim& dim, int size) {
  auto ones = dx->constant(static_cast<T>(1.));
  auto negs = dx->constant(static_cast<T>(-1.));
  auto zeros = dx->constant(static_cast<T>(0.));
Collaborator

    auto ones = dx->constant(static_cast<T>(1.));
    auto negs = dx->constant(static_cast<T>(-1.));
    auto zeros = dx->constant(static_cast<T>(0.));

None of these are needed anymore, right?

ZHUI previously approved these changes Jan 21, 2022
@ZHUI (Collaborator) left a comment

LGTM

wanghuancoder previously approved these changes Jan 21, 2022
@wanghuancoder (Contributor) left a comment

LGTM

@LemonNoel dismissed stale reviews from wanghuancoder and ZHUI via 9b4d030 on January 24, 2022
@ZHUI (Collaborator) left a comment

LGTM

@ZzSean (Contributor) left a comment

LGTM

@Xreki (Contributor) left a comment

LGTM for the change of atol in the float16 unit test (relaxing it to 1e-3 is acceptable for float16).
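
As a rough illustration of what that tolerance means (plain C++, not the actual unit test): float16 carries roughly three decimal digits of precision, so an absolute tolerance of 1e-3 against a float reference matches what the type can represent:

    #include <cmath>
    #include <cstdio>

    // Hedged sketch of an atol check, not Paddle's test harness.
    bool AllClose(float a, float b, float atol) {
      return std::fabs(a - b) <= atol;
    }

    int main() {
      std::printf("%d\n", AllClose(1.0005f, 1.0f, 1e-3f));  // 1: within atol
      std::printf("%d\n", AllClose(1.0020f, 1.0f, 1e-3f));  // 0: outside atol
    }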

@ZHUI ZHUI merged commit 3825b40 into PaddlePaddle:develop Jan 25, 2022