Skip to content

Commit b2ae891

Browse files
authored
[PHI] Fix paddle.where api for big tensor (PaddlePaddle#72717)
1 parent 195dd45 commit b2ae891

File tree

3 files changed

+71
-57
lines changed

3 files changed

+71
-57
lines changed

paddle/phi/kernels/funcs/select_impl.cu.h

Lines changed: 49 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,14 @@ struct NonZeroFunctor {
5959
template <typename InT, typename OutT, int VecSize, int IsBoundary>
6060
__device__ void GetBlockCountImpl(const InT *in,
6161
OutT *out,
62-
int num,
63-
int repeat) {
62+
int64_t num,
63+
int64_t repeat) {
6464
InT in_data[VecSize];
6565
OutT temp[VecSize];
6666
OutT result = static_cast<OutT>(0.0f);
6767
using Add = kps::AddFunctor<OutT>;
6868
using Cast = NonZeroFunctor<InT>;
69-
int store_fix = BLOCK_ID_X + repeat * GRID_NUM_X;
69+
int64_t store_fix = BLOCK_ID_X + repeat * GRID_NUM_X;
7070

7171
kps::Init<InT, VecSize>(&in_data[0], static_cast<InT>(0.0f));
7272
kps::ReadData<InT, VecSize, 1, IsBoundary>(&in_data[0], in, num);
@@ -92,16 +92,17 @@ __global__ void GetBlockCountKernel(const InT *in,
9292
OutT *out,
9393
int64_t numel,
9494
int64_t main_offset) {
95-
int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
96-
int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
97-
int repeat = 0;
95+
int64_t size = static_cast<int64_t>(BLOCK_NUM_X) * VecSize;
96+
int64_t data_offset = size * BLOCK_ID_X;
97+
int64_t stride = size * GRID_NUM_X;
98+
int64_t repeat = 0;
9899
for (; data_offset < main_offset; data_offset += stride) {
99100
GetBlockCountImpl<InT, OutT, VecSize, false>(
100-
in + data_offset, out, BLOCK_NUM_X * VecSize, repeat);
101+
in + data_offset, out, size, repeat);
101102
repeat++; // to get the real blockIdx
102103
}
103104

104-
int num = numel - data_offset;
105+
int64_t num = numel - data_offset;
105106
if (num > 0) {
106107
GetBlockCountImpl<InT, OutT, VecSize, true>(
107108
in + data_offset, out, num, repeat);
@@ -150,14 +151,17 @@ __device__ void CumsumImpl(
150151

151152
// Compute this store_offset of this block
152153
template <typename InT, typename OutT, typename Functor, int VecSize>
153-
__global__ void CumsumOneBlock(
154-
const InT *in, OutT *out, int numel, int main_offset, Functor func) {
155-
int stride = BLOCK_NUM_X * VecSize;
156-
int offset = 0;
154+
__global__ void CumsumOneBlock(const InT *in,
155+
OutT *out,
156+
int64_t numel,
157+
int64_t main_offset,
158+
Functor func) {
159+
int64_t stride = BLOCK_NUM_X * VecSize;
160+
int64_t offset = 0;
157161
OutT pre_cumsum = static_cast<OutT>(0);
158162
for (; offset < main_offset; offset += stride) {
159163
CumsumImpl<InT, OutT, Functor, VecSize, false>(
160-
in + offset, out + offset, &pre_cumsum, BLOCK_NUM_X * VecSize, func);
164+
in + offset, out + offset, &pre_cumsum, stride, func);
161165
}
162166

163167
int num = numel - offset;
@@ -180,10 +184,10 @@ struct SelectCaller {
180184
const MT *mask_data,
181185
const InT *in,
182186
Functor func,
183-
int data_offset,
184-
int store_num,
185-
int thread_fix,
186-
int num) {
187+
int64_t data_offset,
188+
int64_t store_num,
189+
int64_t thread_fix,
190+
int64_t num) {
187191
int64_t in_data[VecSize];
188192
OutT store_data[VecSize * phi::DDim::kMaxRank];
189193
// set index
@@ -260,9 +264,9 @@ __device__ void SelectKernelImpl(OutT *out,
260264
const MT *mask,
261265
const InT *in,
262266
Functor func,
263-
int num,
264-
int data_offset,
265-
int store_rank) {
267+
int64_t num,
268+
int64_t data_offset,
269+
int64_t store_rank) {
266270
const int kCVecSize = 2;
267271
// each thread cumsum 2 data
268272
using IdT = int64_t;
@@ -294,10 +298,9 @@ __device__ void SelectKernelImpl(OutT *out,
294298
// thread_fix
295299
kps::Cumsum<IdT, IdT, Add>(&cumsum_thread[0], &num_thread[0], Add());
296300
// get thread_fix
297-
int thread_fix =
298-
(static_cast<int>(cumsum_thread[0] - num_thread[0]) * store_rank);
301+
IdT thread_fix = (cumsum_thread[0] - num_thread[0]) * store_rank;
299302
// get how many data need to store
300-
int store_num = static_cast<int>(num_thread[0]) * store_rank;
303+
IdT store_num = num_thread[0] * store_rank;
301304
// thread store num data, each thread may has different num
302305
// Get store data(index) according to mask_idt
303306
SelectCaller<OutT, MT, InT, Functor, VecSize, IsBoundary, MaskData> select;
@@ -318,18 +321,20 @@ __global__ void SelectKernel(OutT *out,
318321
Functor func,
319322
const int64_t numel,
320323
int64_t main_offset,
321-
int store_rank) {
322-
int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
323-
int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
324-
int repeat = 0;
325-
int size = VecSize * BLOCK_ID_X;
324+
int64_t store_rank) {
325+
int64_t size = static_cast<int64_t>(BLOCK_ID_X) * VecSize;
326+
int64_t data_offset = size * BLOCK_NUM_X;
327+
int64_t stride = static_cast<int64_t>(BLOCK_NUM_X) * GRID_NUM_X * VecSize;
328+
int64_t repeat = 0;
326329
CT block_store_offset = 0;
327330
for (; data_offset < main_offset; data_offset += stride) {
328331
// Cumsum index
329-
int idx_cumsum = repeat * GRID_NUM_X + BLOCK_ID_X;
332+
int64_t idx_cumsum = repeat * GRID_NUM_X + BLOCK_ID_X;
330333
kps::details::ReadData<CT>(&block_store_offset, cumsum + idx_cumsum, 1);
331-
int out_fix = MaskData < 2 ? block_store_offset * store_rank : data_offset;
332-
int in_fix = MaskData < 2 ? data_offset : block_store_offset * store_rank;
334+
int64_t out_fix =
335+
MaskData < 2 ? block_store_offset * store_rank : data_offset;
336+
int64_t in_fix =
337+
MaskData < 2 ? data_offset : block_store_offset * store_rank;
333338
SelectKernelImpl<InT, MT, OutT, Functor, VecSize, MaskData, false>(
334339
out + out_fix,
335340
mask + data_offset,
@@ -341,13 +346,15 @@ __global__ void SelectKernel(OutT *out,
341346
repeat++;
342347
}
343348

344-
int num = numel - data_offset;
349+
int64_t num = numel - data_offset;
345350
if (num > 0) {
346351
// Cumsum index
347-
int idx_cumsum = repeat * GRID_NUM_X + BLOCK_ID_X;
352+
int64_t idx_cumsum = repeat * GRID_NUM_X + BLOCK_ID_X;
348353
kps::details::ReadData<CT>(&block_store_offset, cumsum + idx_cumsum, 1);
349-
int out_fix = MaskData < 2 ? block_store_offset * store_rank : data_offset;
350-
int in_fix = MaskData < 2 ? data_offset : block_store_offset * store_rank;
354+
int64_t out_fix =
355+
MaskData < 2 ? block_store_offset * store_rank : data_offset;
356+
int64_t in_fix =
357+
MaskData < 2 ? data_offset : block_store_offset * store_rank;
351358
SelectKernelImpl<InT, MT, OutT, Functor, VecSize, MaskData, true>(
352359
out + out_fix,
353360
mask + data_offset,
@@ -398,19 +405,19 @@ void SelectKernel(const KPDevice &dev_ctx,
398405
int block = 64;
399406
auto stream = dev_ctx.x_context()->xpu_stream;
400407
const int num_per_block = kVecSize * block;
401-
const int need_grids = (numel + num_per_block - 1) / num_per_block;
402-
const int grid = std::min(need_grids, 8);
408+
const int64_t need_grids = (numel + num_per_block - 1) / num_per_block;
409+
const int64_t grid = std::min(need_grids, static_cast<int64_t>(8));
403410
#else
404411
const int block = 256;
405412
const int num_per_block = kVecSize * block;
406-
const int need_grids = (numel + num_per_block - 1) / num_per_block;
407-
const int grid = std::min(need_grids, 256);
413+
const int64_t need_grids = (numel + num_per_block - 1) / num_per_block;
414+
const int64_t grid = std::min(need_grids, static_cast<int64_t>(256));
408415
auto stream = dev_ctx.stream();
409416
#endif
410417
const int64_t main_offset = Floor(numel, num_per_block);
411418
// 1.2 alloc tmp data for CoutBlock
412-
const int size_count_block = need_grids + 1;
413-
std::vector<int> dims_vec = {size_count_block * 2};
419+
const int64_t size_count_block = need_grids + 1;
420+
std::vector<int64_t> dims_vec = {size_count_block * 2};
414421
IntArray dims_array(dims_vec);
415422
DenseTensor count_mem = phi::Empty<CT, KPDevice>(dev_ctx, dims_array);
416423
CT *count_data = count_mem.data<CT>();
@@ -424,7 +431,7 @@ void SelectKernel(const KPDevice &dev_ctx,
424431
CT total_true_num = static_cast<CT>(0); // init
425432
const int kCumVesize = 2;
426433
const int block_c = 256;
427-
const int main_offset_c = Floor(size_count_block, (kCumVesize * block_c));
434+
const int64_t main_offset_c = Floor(size_count_block, (kCumVesize * block_c));
428435

429436
using Add = kps::AddFunctor<CT>;
430437
CumsumOneBlock<CT, CT, Add, kCumVesize><<<1, block_c, 0, stream>>>(

paddle/phi/kernels/gpu/where_grad_kernel.cu

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919

2020
namespace phi {
2121

22-
template <typename T>
22+
template <typename T, typename IndexT>
2323
__global__ void WhereGradCUDAKernel(
24-
const int N, const T* dout, const bool* cond, T* dx, T* dy) {
25-
int idx = blockDim.x * blockIdx.x + threadIdx.x;
24+
const IndexT N, const T* dout, const bool* cond, T* dx, T* dy) {
25+
IndexT idx = blockDim.x * blockIdx.x + threadIdx.x;
2626
for (; idx < N; idx += blockDim.x * gridDim.x) {
2727
if (dx != nullptr) {
2828
dx[idx] = cond[idx] ? dout[idx] : static_cast<T>(0.);
@@ -50,9 +50,15 @@ void WhereGradKernel(const Context& ctx,
5050

5151
auto stream = ctx.stream();
5252
auto config = backends::gpu::GetGpuLaunchConfig1D(ctx, numel);
53-
WhereGradCUDAKernel<T>
54-
<<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
55-
numel, dout, cond_data, dx, dy);
53+
if (numel <= std::numeric_limits<int>::max()) {
54+
WhereGradCUDAKernel<T, int>
55+
<<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
56+
numel, dout, cond_data, dx, dy);
57+
} else {
58+
WhereGradCUDAKernel<T, int64_t>
59+
<<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
60+
numel, dout, cond_data, dx, dy);
61+
}
5662
}
5763

5864
} // namespace phi

paddle/phi/kernels/primitive/datamover_primitives.h

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -70,17 +70,17 @@ struct BroadcastConfig {
7070
template <typename T>
7171
__device__ __forceinline__ void WriteData(T* dst,
7272
T* __restrict__ src,
73-
int num) {
74-
for (int i = 0; i < num; i++) {
73+
int64_t num) {
74+
for (int64_t i = 0; i < num; i++) {
7575
dst[i] = src[i];
7676
}
7777
}
7878

7979
template <typename T>
8080
__device__ __forceinline__ void ReadData(T* dst,
8181
const T* __restrict__ src,
82-
int num) {
83-
for (int i = 0; i < num; i++) {
82+
int64_t num) {
83+
for (int64_t i = 0; i < num; i++) {
8484
dst[i] = src[i];
8585
}
8686
}
@@ -247,9 +247,9 @@ __device__ __forceinline__ void Init(ArgsT* dst, T init_data) {
247247
template <typename T, int NX, int NY, bool IsBoundary = false>
248248
__device__ __forceinline__ void ReadData(T* dst,
249249
const T* __restrict__ src,
250-
int num) {
250+
int64_t num) {
251251
if (IsBoundary) { // blockDim.x * NX > num
252-
int thread_offset = threadIdx.x * NX;
252+
int64_t thread_offset = threadIdx.x * NX;
253253
#pragma unroll
254254
for (int idx = 0; idx < NX; ++idx) {
255255
if (idx + thread_offset < num) {
@@ -259,7 +259,7 @@ __device__ __forceinline__ void ReadData(T* dst,
259259
} else { // blockDim,x * NX < num
260260
constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
261261
constexpr int kVectorsPerThread = NX / kVectorSize;
262-
int thread_offset = threadIdx.x * kVectorsPerThread;
262+
int64_t thread_offset = threadIdx.x * kVectorsPerThread;
263263

264264
using VecType = details::VectorType<T, kVectorSize>;
265265
const VecType* vec_input = reinterpret_cast<const VecType*>(src);
@@ -848,8 +848,9 @@ __device__ __forceinline__ void ReadDataBc(
848848
* init_data: The register pointer of init data, the size is NX.
849849
*/
850850
template <typename T, int NX, int NY>
851-
__device__ __forceinline__ void InitWithDataIndex(T* dst, int block_offset) {
852-
int thread_offset = block_offset + threadIdx.x * NX;
851+
__device__ __forceinline__ void InitWithDataIndex(T* dst,
852+
int64_t block_offset) {
853+
int64_t thread_offset = block_offset + threadIdx.x * NX;
853854
#pragma unroll
854855
for (int nx = 0; nx < NX; ++nx) {
855856
dst[nx] = static_cast<T>(thread_offset + nx);

0 commit comments

Comments
 (0)