Skip to content

Commit 4762f6a

Browse files
authored
【Paddle Tensor 第二期 API支持 0-size Tensor】Fix frobenius_norm to support 0-size tensor (#72570)
* [Tensor] Fix frobenius_norm to support 0-size tensor * [Tensor] Drop resize of fro-norm * [Tensor] Import utils with non-relative path * drop resize in gpu
1 parent a04c07a commit 4762f6a

File tree

3 files changed

+410
-355
lines changed

3 files changed

+410
-355
lines changed

paddle/phi/kernels/gpu/frobenius_norm_kernel.cu

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@ void FrobeniusNormKernel(const Context& dev_ctx,
2828
bool keep_dim,
2929
bool reduce_all,
3030
DenseTensor* out) {
31+
if (x.numel() == 0) {
32+
dev_ctx.template Alloc<T>(out);
33+
phi::funcs::SetConstant<Context, T>()(dev_ctx, out, static_cast<T>(0));
34+
return;
35+
}
3136
reduce_all = recompute_reduce_all(x, dims.GetData(), reduce_all);
3237
auto out_dtype = x.dtype();
3338
phi::Reduce<T, kps::AddFunctor, kps::SquareFunctor>(

paddle/phi/kernels/impl/frobenius_norm_kernel_impl.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ void FrobeniusNormKernel(const Context& ctx,
2727
bool keep_dim,
2828
bool reduce_all,
2929
DenseTensor* out) {
30+
if (x.numel() == 0) {
31+
ctx.template Alloc<T>(out);
32+
phi::funcs::SetConstant<Context, T>()(ctx, out, 0);
33+
return;
34+
}
3035
reduce_all = recompute_reduce_all(x, axis.GetData(), reduce_all);
3136
Reduce<Context, T, funcs::FrobeniusNormFunctor>(
3237
ctx, x, reduce_all, axis.GetData(), keep_dim, x.dtype(), out);

0 commit comments

Comments
 (0)