@@ -59,14 +59,14 @@ struct NonZeroFunctor {
 template <typename InT, typename OutT, int VecSize, int IsBoundary>
 __device__ void GetBlockCountImpl(const InT *in,
                                   OutT *out,
-                                  int num,
-                                  int repeat) {
+                                  int64_t num,
+                                  int64_t repeat) {
   InT in_data[VecSize];
   OutT temp[VecSize];
   OutT result = static_cast<OutT>(0.0f);
   using Add = kps::AddFunctor<OutT>;
   using Cast = NonZeroFunctor<InT>;
-  int store_fix = BLOCK_ID_X + repeat * GRID_NUM_X;
+  int64_t store_fix = BLOCK_ID_X + repeat * GRID_NUM_X;

   kps::Init<InT, VecSize>(&in_data[0], static_cast<InT>(0.0f));
   kps::ReadData<InT, VecSize, 1, IsBoundary>(&in_data[0], in, num);
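Note: why the widening above matters. When every operand of an expression is a
32-bit int, C++ evaluates the whole expression in 32 bits, so a product such as
repeat * GRID_NUM_X wraps before it is ever assigned, even to an int64_t. A
minimal host-side sketch of the failure mode and the fix; the values are
illustrative, not taken from the kernel:

    #include <cstdint>
    #include <cstdio>

    int main() {
      int grid = 1 << 16;   // stand-in for GRID_NUM_X
      int repeat = 40000;   // enough passes over a > 2^31-element tensor
      // Product computed in 32 bits: signed overflow, undefined behavior
      // (wraps to a negative value on common targets), even though the
      // destination is 64-bit:
      int64_t bad = repeat * grid;
      // Widening one operand first keeps the whole expression in 64 bits:
      int64_t good = static_cast<int64_t>(repeat) * grid;
      std::printf("bad=%lld good=%lld\n", (long long)bad, (long long)good);
      return 0;
    }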
@@ -92,16 +92,17 @@ __global__ void GetBlockCountKernel(const InT *in,
                                     OutT *out,
                                     int64_t numel,
                                     int64_t main_offset) {
-  int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
-  int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
-  int repeat = 0;
+  int64_t size = static_cast<int64_t>(BLOCK_NUM_X) * VecSize;
+  int64_t data_offset = size * BLOCK_ID_X;
+  int64_t stride = size * GRID_NUM_X;
+  int64_t repeat = 0;
   for (; data_offset < main_offset; data_offset += stride) {
     GetBlockCountImpl<InT, OutT, VecSize, false>(
-        in + data_offset, out, BLOCK_NUM_X * VecSize, repeat);
+        in + data_offset, out, size, repeat);
     repeat++;  // to get the real blockIdx
   }

-  int num = numel - data_offset;
+  int64_t num = numel - data_offset;
   if (num > 0) {
     GetBlockCountImpl<InT, OutT, VecSize, true>(
         in + data_offset, out, num, repeat);
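The hunk above also factors the per-iteration element count into a single
`size` variable, so the cast is written once and `data_offset`, `stride`, and
the count passed to GetBlockCountImpl all inherit the 64-bit type. A generic
sketch of the same 64-bit-safe grid-stride pattern (the names and the
`process` placeholder are illustrative, not from the file):

    constexpr int kVecSizeSketch = 4;  // illustrative vector width

    __global__ void GridStrideSketch(const float *in, int64_t numel) {
      // One 64-bit factor makes every derived offset 64-bit as well.
      int64_t size = static_cast<int64_t>(blockDim.x) * kVecSizeSketch;
      int64_t offset = size * blockIdx.x;
      int64_t stride = size * gridDim.x;
      for (; offset + size <= numel; offset += stride) {
        // full tile: process(in + offset, size);
      }
      if (offset < numel) {
        // boundary tile: process(in + offset, numel - offset);
      }
    }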
@@ -150,14 +151,17 @@ __device__ void CumsumImpl(

 // Compute this store_offset of this block
 template <typename InT, typename OutT, typename Functor, int VecSize>
-__global__ void CumsumOneBlock(
-    const InT *in, OutT *out, int numel, int main_offset, Functor func) {
-  int stride = BLOCK_NUM_X * VecSize;
-  int offset = 0;
+__global__ void CumsumOneBlock(const InT *in,
+                               OutT *out,
+                               int64_t numel,
+                               int64_t main_offset,
+                               Functor func) {
+  int64_t stride = BLOCK_NUM_X * VecSize;
+  int64_t offset = 0;
   OutT pre_cumsum = static_cast<OutT>(0);
   for (; offset < main_offset; offset += stride) {
     CumsumImpl<InT, OutT, Functor, VecSize, false>(
-        in + offset, out + offset, &pre_cumsum, BLOCK_NUM_X * VecSize, func);
+        in + offset, out + offset, &pre_cumsum, stride, func);
   }

   int num = numel - offset;
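Note that `int num = numel - offset;` just below the changed lines stays
32-bit, and that appears deliberate: the loop exits with offset == main_offset,
and (assuming Floor rounds its first argument down to a multiple of its second,
which is how it is used in the launcher below) the remainder is smaller than
stride = BLOCK_NUM_X * VecSize, so it fits comfortably in an int.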
@@ -180,10 +184,10 @@ struct SelectCaller {
                                     const MT *mask_data,
                                     const InT *in,
                                     Functor func,
-                                    int data_offset,
-                                    int store_num,
-                                    int thread_fix,
-                                    int num) {
+                                    int64_t data_offset,
+                                    int64_t store_num,
+                                    int64_t thread_fix,
+                                    int64_t num) {
     int64_t in_data[VecSize];
     OutT store_data[VecSize * phi::DDim::kMaxRank];
     // set index
@@ -260,9 +264,9 @@ __device__ void SelectKernelImpl(OutT *out,
                                  const MT *mask,
                                  const InT *in,
                                  Functor func,
-                                 int num,
-                                 int data_offset,
-                                 int store_rank) {
+                                 int64_t num,
+                                 int64_t data_offset,
+                                 int64_t store_rank) {
   const int kCVecSize = 2;
   // each thread cumsum 2 data
   using IdT = int64_t;
@@ -294,10 +298,9 @@ __device__ void SelectKernelImpl(OutT *out,
   // thread_fix
   kps::Cumsum<IdT, IdT, Add>(&cumsum_thread[0], &num_thread[0], Add());
   // get thread_fix
-  int thread_fix =
-      (static_cast<int>(cumsum_thread[0] - num_thread[0]) * store_rank);
+  IdT thread_fix = (cumsum_thread[0] - num_thread[0]) * store_rank;
   // get how many data need to store
-  int store_num = static_cast<int>(num_thread[0]) * store_rank;
+  IdT store_num = num_thread[0] * store_rank;
   // thread store num data, each thread may has different num
   // Get store data(index) according to mask_idt
   SelectCaller<OutT, MT, InT, Functor, VecSize, IsBoundary, MaskData> select;
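The rewrite above keeps the scan results in IdT (int64_t) instead of truncating
them through static_cast<int>. For orientation (an inference from how the
values are used here, not a statement of kps::Cumsum's documented contract):
the scan appears inclusive, so subtracting the thread's own count yields the
exclusive prefix, i.e. the slot where this thread starts storing, and
multiplying by store_rank converts slots into per-rank index components. A
host-side sketch of that arithmetic with made-up counts:

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <numeric>
    #include <vector>

    int main() {
      const int64_t store_rank = 2;                // e.g. 2-D indices
      std::vector<int64_t> counts = {2, 0, 3, 1};  // per-thread store counts
      std::vector<int64_t> inclusive(counts.size());
      std::partial_sum(counts.begin(), counts.end(), inclusive.begin());
      for (std::size_t i = 0; i < counts.size(); ++i) {
        // Exclusive prefix = inclusive sum minus the thread's own count.
        int64_t thread_fix = (inclusive[i] - counts[i]) * store_rank;
        int64_t store_num = counts[i] * store_rank;
        std::printf("thread %zu: start=%lld num=%lld\n",
                    i, (long long)thread_fix, (long long)store_num);
      }
      return 0;
    }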
@@ -318,18 +321,20 @@ __global__ void SelectKernel(OutT *out,
                              Functor func,
                              const int64_t numel,
                              int64_t main_offset,
-                             int store_rank) {
-  int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
-  int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
-  int repeat = 0;
-  int size = VecSize * BLOCK_ID_X;
+                             int64_t store_rank) {
+  int64_t size = static_cast<int64_t>(BLOCK_ID_X) * VecSize;
+  int64_t data_offset = size * BLOCK_NUM_X;
+  int64_t stride = static_cast<int64_t>(BLOCK_NUM_X) * GRID_NUM_X * VecSize;
+  int64_t repeat = 0;
   CT block_store_offset = 0;
   for (; data_offset < main_offset; data_offset += stride) {
     // Cumsum index
-    int idx_cumsum = repeat * GRID_NUM_X + BLOCK_ID_X;
+    int64_t idx_cumsum = repeat * GRID_NUM_X + BLOCK_ID_X;
     kps::details::ReadData<CT>(&block_store_offset, cumsum + idx_cumsum, 1);
-    int out_fix = MaskData < 2 ? block_store_offset * store_rank : data_offset;
-    int in_fix = MaskData < 2 ? data_offset : block_store_offset * store_rank;
+    int64_t out_fix =
+        MaskData < 2 ? block_store_offset * store_rank : data_offset;
+    int64_t in_fix =
+        MaskData < 2 ? data_offset : block_store_offset * store_rank;
     SelectKernelImpl<InT, MT, OutT, Functor, VecSize, MaskData, false>(
         out + out_fix,
         mask + data_offset,
@@ -341,13 +346,15 @@ __global__ void SelectKernel(OutT *out,
     repeat++;
   }

-  int num = numel - data_offset;
+  int64_t num = numel - data_offset;
   if (num > 0) {
     // Cumsum index
-    int idx_cumsum = repeat * GRID_NUM_X + BLOCK_ID_X;
+    int64_t idx_cumsum = repeat * GRID_NUM_X + BLOCK_ID_X;
     kps::details::ReadData<CT>(&block_store_offset, cumsum + idx_cumsum, 1);
-    int out_fix = MaskData < 2 ? block_store_offset * store_rank : data_offset;
-    int in_fix = MaskData < 2 ? data_offset : block_store_offset * store_rank;
+    int64_t out_fix =
+        MaskData < 2 ? block_store_offset * store_rank : data_offset;
+    int64_t in_fix =
+        MaskData < 2 ? data_offset : block_store_offset * store_rank;
     SelectKernelImpl<InT, MT, OutT, Functor, VecSize, MaskData, true>(
         out + out_fix,
         mask + data_offset,
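As in GetBlockCountKernel, the tail after the main loop mirrors the loop body
with IsBoundary = true. The widening of out_fix and in_fix matters most here:
when the cumulative number of selected elements grows large, the assignment of
block_store_offset * store_rank to a 32-bit variable is exactly where the
store offset would previously have been truncated.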
@@ -398,19 +405,19 @@ void SelectKernel(const KPDevice &dev_ctx,
   int block = 64;
   auto stream = dev_ctx.x_context()->xpu_stream;
   const int num_per_block = kVecSize * block;
-  const int need_grids = (numel + num_per_block - 1) / num_per_block;
-  const int grid = std::min(need_grids, 8);
+  const int64_t need_grids = (numel + num_per_block - 1) / num_per_block;
+  const int64_t grid = std::min(need_grids, static_cast<int64_t>(8));
 #else
   const int block = 256;
   const int num_per_block = kVecSize * block;
-  const int need_grids = (numel + num_per_block - 1) / num_per_block;
-  const int grid = std::min(need_grids, 256);
+  const int64_t need_grids = (numel + num_per_block - 1) / num_per_block;
+  const int64_t grid = std::min(need_grids, static_cast<int64_t>(256));
   auto stream = dev_ctx.stream();
 #endif
   const int64_t main_offset = Floor(numel, num_per_block);
   // 1.2 alloc tmp data for CoutBlock
-  const int size_count_block = need_grids + 1;
-  std::vector<int> dims_vec = {size_count_block * 2};
+  const int64_t size_count_block = need_grids + 1;
+  std::vector<int64_t> dims_vec = {size_count_block * 2};
   IntArray dims_array(dims_vec);
   DenseTensor count_mem = phi::Empty<CT, KPDevice>(dev_ctx, dims_array);
   CT *count_data = count_mem.data<CT>();
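One non-obvious detail in the hunk above: std::min deduces a single template
parameter from both arguments, so once need_grids becomes int64_t the bare
literals 8 and 256 no longer match, hence the static_cast<int64_t>(...). A
minimal sketch of the two equivalent spellings (the numbers are illustrative):

    #include <algorithm>
    #include <cstdint>

    int main() {
      int64_t need_grids = (int64_t{5} << 32) / 1024;  // illustrative
      // Mixed int/int64_t arguments would fail template deduction:
      int64_t grid1 = std::min(need_grids, static_cast<int64_t>(256));
      // Equivalent: pin the template parameter explicitly.
      int64_t grid2 = std::min<int64_t>(need_grids, 256);
      return static_cast<int>(grid1 == grid2 ? 0 : 1);
    }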
@@ -424,7 +431,7 @@ void SelectKernel(const KPDevice &dev_ctx,
   CT total_true_num = static_cast<CT>(0);  // init
   const int kCumVesize = 2;
   const int block_c = 256;
-  const int main_offset_c = Floor(size_count_block, (kCumVesize * block_c));
+  const int64_t main_offset_c = Floor(size_count_block, (kCumVesize * block_c));

   using Add = kps::AddFunctor<CT>;
   CumsumOneBlock<CT, CT, Add, kCumVesize><<<1, block_c, 0, stream>>>(