
Commit d4aaad3

[API] Fix int overflow and float16 support for paddle.frac (#72815)

Parent: 7869937

File tree: 2 files changed (+11, -9 lines)

  paddle/phi/kernels/gpu/trunc_kernel.cu
  python/paddle/tensor/math.py

paddle/phi/kernels/gpu/trunc_kernel.cu (5 additions, 6 deletions)
@@ -13,9 +13,9 @@
 // limitations under the License.
 
 #include "paddle/phi/kernels/trunc_kernel.h"
-
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_info.h"
+#include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/backends/gpu/gpu_primitives.h"
 #include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
@@ -59,7 +59,7 @@ class TruncFunctor<int64_t> {
 
 template <typename T>
 __global__ void Trunc(const T* x, T* out, int64_t N) {
-  CUDA_KERNEL_LOOP(index, N) {
+  CUDA_KERNEL_LOOP_TYPE(index, N, int64_t) {
     TruncFunctor<T> functor(x[index]);
     out[index] = functor();
   }
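
Why the loop macro changed: CUDA_KERNEL_LOOP iterates with a 32-bit int index, so for tensors with more than 2^31 - 1 elements the index overflows; CUDA_KERNEL_LOOP_TYPE(index, N, int64_t) keeps the index 64-bit. A simplified sketch of the pattern (illustrative only, not Paddle's exact macro expansion):

    #include <cstdint>

    // 64-bit grid-stride loop: the pattern CUDA_KERNEL_LOOP_TYPE expands to.
    __global__ void TruncSketch(const float* x, float* out, int64_t n) {
      // Compute the start index in 64-bit; a plain `int` index wraps around
      // once n exceeds INT_MAX (2^31 - 1 elements), corrupting the indexing.
      int64_t start = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
      int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
      for (int64_t i = start; i < n; i += stride) {
        out[i] = truncf(x[i]);  // drop the fractional part of each element
      }
    }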
@@ -73,11 +73,10 @@ void TruncKernel(const Context& dev_ctx,
   auto* out_data = dev_ctx.template Alloc<T>(out);
 
   int64_t numel = x.numel();
+  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel);
 
-  int threads = PADDLE_CUDA_NUM_THREADS;
-  int blocks = (numel + threads - 1) / threads;
-
-  Trunc<<<blocks, threads>>>(x_data, out_data, numel);
+  Trunc<<<config.block_per_grid, config.thread_per_block>>>(
+      x_data, out_data, numel);
 }
 
 }  // namespace phi
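
Why the launch sizing changed: the removed code stored the block count in a 32-bit int, which truncates for very large element counts; GetGpuLaunchConfig1D sizes the grid safely from the device limits. A minimal sketch of overflow-safe 1-D launch sizing (the names LaunchConfig1D / MakeLaunchConfig1D and the 512-thread default are illustrative assumptions, not Paddle's GetGpuLaunchConfig1D API):

    #include <algorithm>
    #include <cstdint>

    struct LaunchConfig1D {
      int64_t block_per_grid;
      int thread_per_block;
    };

    LaunchConfig1D MakeLaunchConfig1D(int64_t numel, int threads = 512) {
      // Ceil-divide in 64-bit; narrowing the result into a 32-bit int (as
      // the removed code did) silently truncates for huge element counts.
      int64_t blocks = (numel + threads - 1) / threads;
      // gridDim.x is capped at 2^31 - 1; the grid-stride loop inside the
      // kernel picks up any elements beyond blocks * threads.
      blocks = std::min<int64_t>(blocks, INT32_MAX);
      return {blocks, threads};
    }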

python/paddle/tensor/math.py (6 additions, 3 deletions)
@@ -6946,13 +6946,15 @@ def frac(x: Tensor, name: str | None = None) -> Tensor:
         paddle.int64,
         paddle.float32,
         paddle.float64,
+        paddle.float16,
         DataType.INT32,
         DataType.INT64,
         DataType.FLOAT32,
         DataType.FLOAT64,
+        DataType.FLOAT16,
     ]:
         raise TypeError(
-            f"The data type of input must be one of ['int32', 'int64', 'float32', 'float64'], but got {x.dtype}"
+            f"The data type of input must be one of ['int32', 'int64', 'float32', 'float64', 'float16'], but got {x.dtype}"
         )
     if in_dynamic_or_pir_mode():
         y = _C_ops.trunc(x)
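
With float16 added to the allowed dtypes, frac now accepts half-precision tensors. A usage sketch, assuming a build that contains this change (frac computes x - trunc(x)):

    import paddle

    # frac keeps only the fractional part: frac(x) = x - trunc(x)
    x = paddle.to_tensor([1.5, -2.25, 3.0], dtype='float16')
    print(paddle.frac(x))  # values [0.5, -0.25, 0.0], dtype float16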
@@ -6963,7 +6965,7 @@ def frac(x: Tensor, name: str | None = None) -> Tensor:
 
     helper = LayerHelper("trunc", **locals())
     check_variable_and_dtype(
-        x, "X", ['int32', 'int64', 'float32', 'float64'], 'trunc'
+        x, "X", ['int32', 'int64', 'float32', 'float64', 'float16'], 'trunc'
     )
     y = helper.create_variable_for_type_inference(dtype=x.dtype)
     helper.append_op(
@@ -6984,9 +6986,10 @@ def frac_(x: Tensor, name: str | None = None) -> Tensor:
         paddle.int64,
         paddle.float32,
         paddle.float64,
+        paddle.float16,
     ]:
         raise TypeError(
-            f"The data type of input must be one of ['int32', 'int64', 'float32', 'float64'], but got {x.dtype}"
+            f"The data type of input must be one of ['int32', 'int64', 'float32', 'float64', 'float16'], but got {x.dtype}"
         )
     if in_dynamic_mode():
         y = _C_ops.trunc(x)
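
frac_ is the in-place variant and gets the same float16 allowance. A usage sketch, under the same build assumption:

    import paddle

    x = paddle.to_tensor([1.5, -2.25], dtype='float16')
    paddle.frac_(x)  # mutates x in place instead of allocating a new tensor
    print(x)         # values [0.5, -0.25]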
