[pnorm] fix bug in fp16 & optimize memory #39011
Conversation
Thanks for your contribution!
struct UnsignedPowFunctor {
  HOSTDEVICE explicit inline UnsignedPowFunctor(float porder) {
    this->porder = porder;
  }
  HOSTDEVICE inline Ty operator()(const Tx x) const {
Were the two types here originally meant to support the fp16 scenario? Why are they not used anymore?
The previous implementation was buggy: the fp16 type is converted to float for computation inside the inline function, so there is no need to do the conversion externally.
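A minimal sketch of the fix described above, not the PR's exact code: it assumes Paddle's HOSTDEVICE macro, powf/fabsf stand in for whatever device math helpers the real kernel uses, and a single template parameter T replaces the old Tx/Ty pair.

template <typename T>
struct UnsignedPowFunctor {
  HOSTDEVICE explicit inline UnsignedPowFunctor(float porder)
      : porder(porder) {}
  HOSTDEVICE inline T operator()(const T x) const {
    // Promote to float inside the functor: for T = float16 this is exactly
    // the in-function conversion described above, so callers never need a
    // separate fp16 -> float pass or a second output type.
    return static_cast<T>(powf(fabsf(static_cast<float>(x)), porder));
  }
  float porder;
};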
TensorReduceFunctorImpl<T, T, kps::AddFunctor, UnsignedPowFunctor<T>>(
    *in_x, out_norm, UnsignedPowFunctor<T>(porder), reduce_axis, stream);

const framework::Tensor* tmp_norm = out_norm;
tmp_norm should no longer be needed here, right? out_norm could be used directly.
The type of ins here is std::vector<const framework::Tensor*>, so the pointer needs to be converted to const.
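A self-contained sketch of the pattern in question (Tensor here is a stand-in for framework::Tensor, for illustration only):

#include <vector>

struct Tensor {};  // stand-in for framework::Tensor

int main() {
  Tensor t;
  Tensor* out_norm = &t;              // mutable pointer, as in the kernel
  std::vector<const Tensor*> ins;     // element type is pointer-to-const
  const Tensor* tmp_norm = out_norm;  // makes the Tensor* -> const Tensor*
  ins.push_back(tmp_norm);            // qualification conversion explicit
  return 0;
}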
paddle/fluid/operators/p_norm_op.cu
Outdated
dx->device(place) = dy->broadcast(dim) * equals.select(ones, zeros) *
                    positives.select(ones, negs);
dx->device(place) = dy->broadcast(dim) * (*x).sign() *
                    ((*x).abs() == y->broadcast(dim)).select(ones, zeros);
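For reference, this hunk appears to implement the infinity-norm backward branch: in scalar form, with $y = \max_i |x_i|$,

$\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial y}\,\operatorname{sign}(x_i)\,\mathbf{1}[\,|x_i| = y\,],$

so gradient flows only to the entries that attain the maximum.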
dx.device(*place) =
    (x.abs() == norm.broadcast(bcast)).template cast<T>() * x.sign() *
    norm_dy.broadcast(bcast);
cast can also be used here.
Do you mean that, written this way, cast alone is enough to obtain the non-zero values?
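A standalone Eigen sketch (illustrative, not Paddle code) answering this question: casting a boolean comparison to a numeric type yields exactly 1 where the condition holds and 0 elsewhere, which is why select(ones, zeros) and the ones/negs/zeros constants in the next hunk become unnecessary.

#include <Eigen/Dense>
#include <iostream>

int main() {
  Eigen::Array3f x(1.f, -2.f, 3.f);
  Eigen::Array3f norm(1.f, 1.f, 3.f);
  // (x.abs() == norm) is a boolean expression; cast<float>() maps
  // true -> 1.0f and false -> 0.0f, replacing select(ones, zeros).
  Eigen::Array3f mask = (x.abs() == norm).cast<float>();
  std::cout << mask.transpose() << std::endl;  // prints: 1 0 1
  return 0;
}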
paddle/fluid/operators/p_norm_op.cu
Outdated
template <typename DeviceContext, typename X, typename Y, typename DX,
          typename DY, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
                const Dim& dim, int size) {
  auto ones = dx->constant(static_cast<T>(1.));
  auto negs = dx->constant(static_cast<T>(-1.));
  auto zeros = dx->constant(static_cast<T>(0.));
auto ones = dx->constant(static_cast<T>(1.));
auto negs = dx->constant(static_cast<T>(-1.));
auto zeros = dx->constant(static_cast<T>(0.));
None of these should be needed anymore, right?
LGTM
LGTM
LGTM
LGTM
LGTM for the change of atol in the float16 unittest (changing it to 1e-3 is acceptable for float16).
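(For context: float16 has 10 explicit mantissa bits, so its machine epsilon is 2^-10 ≈ 9.77e-4; an atol of 1e-3 is therefore roughly one ulp at magnitude 1, about as tight as float16 results can reasonably be checked.)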
PR types
Bug fixes
PR changes
OPs
Describe