Skip to content

Commit 0e6f7fb

Browse files
[Cherry-pick Fleety_12] Fix Bigtensor (PaddlePaddle#76363) (PaddlePaddle#76371)
* big tensor: moe_permute/moe_unpermute/repeat_interleave/fused_transpose_wlch_split_quant
* fix
* fix int64
* fix int64 to int
1 parent b314b42 commit 0e6f7fb

File tree

4 files changed

+53
-18
lines changed

4 files changed

+53
-18
lines changed

paddle/phi/kernels/fusion/gpu/fused_transpose_wlch_split_quant_kernel.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ __global__ void __launch_bounds__(512)
100100
reinterpret_cast<__nv_fp8_e4m3**>(meta + num_experts);
101101
float** scale_ptrs = reinterpret_cast<float**>(meta + num_experts * 2);
102102

103-
const size_t block_off_x = blockIdx.x * size_t(128);
104-
const size_t block_off_y = blockIdx.y * 128;
103+
const size_t block_off_x = static_cast<size_t>(blockIdx.x) * 128;
104+
const size_t block_off_y = static_cast<size_t>(blockIdx.y) * 128;
105105

106106
// 1. Load 128x128 block from input.
107107
for (uint32_t i = 0; i < 8; i++) {
@@ -156,7 +156,7 @@ __global__ void __launch_bounds__(512)
156156
off = (off / 64) * 64 + (off % 2) * 32 + (off % 64) / 2;
157157
}
158158
float scale_out = 1.0f / col_scale[off];
159-
size_t idx_y = blockIdx.x - expert_off / 128;
159+
size_t idx_y = static_cast<size_t>(blockIdx.x) - expert_off / 128;
160160
size_t idx_x = block_off_y + threadIdx.y * 32 + threadIdx.x;
161161
size_t idx = idx_y * H + idx_x;
162162
if (idx_x < H) {

paddle/phi/kernels/gpu/moe_permute_kernel.cu

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -245,8 +245,20 @@ void MoePermuteKernel(const Context &dev_ctx,
245245
DenseTensor *zipped_expertwise_rowmap,
246246
DenseTensor *token_prob_unzipped,
247247
DenseTensor *XScale_unzipped) {
248-
const int rows = X.dims()[0];
249-
const int cols = X.dims()[1];
248+
const int64_t rows = X.dims()[0];
249+
const int64_t cols = X.dims()[1];
250+
PADDLE_ENFORCE_LE(
251+
rows,
252+
std::numeric_limits<int32_t>::max(),
253+
common::errors::InvalidArgument("X.dims()[0] should be less than "
254+
"INT_MAX, received X.dims()[0]: (%ld)",
255+
rows));
256+
PADDLE_ENFORCE_LE(
257+
cols,
258+
std::numeric_limits<int32_t>::max(),
259+
common::errors::InvalidArgument("X.dims()[1] should be less than "
260+
"INT_MAX, received X.dims()[1]: (%ld)",
261+
cols));
250262
PADDLE_ENFORCE_LE(
251263
num_experts,
252264
MAX_NUM_EXPERTS,
@@ -256,7 +268,13 @@ void MoePermuteKernel(const Context &dev_ctx,
256268
"value.",
257269
MAX_NUM_EXPERTS,
258270
num_experts));
259-
const int quanted_cols = (XScale) ? XScale.get_ptr()->dims()[1] : 0;
271+
const int64_t quanted_cols = (XScale) ? XScale.get_ptr()->dims()[1] : 0;
272+
PADDLE_ENFORCE_LE(
273+
quanted_cols,
274+
std::numeric_limits<int32_t>::max(),
275+
common::errors::InvalidArgument("quanted_cols should be less than "
276+
"INT_MAX, received quanted_cols: (%ld)",
277+
quanted_cols));
260278

261279
// Expert base offset initialization, tensor numeric range [0, max_token_num]
262280
int expert_offset[MAX_NUM_EXPERTS];
@@ -281,7 +299,12 @@ void MoePermuteKernel(const Context &dev_ctx,
281299
dev_ctx.stream()));
282300
// ------------------- resource allocate -------------------------
283301
const int output_rows = tokens_cumulated;
284-
const int topk = expert_routemap_topk.dims()[1];
302+
const int64_t topk = expert_routemap_topk.dims()[1];
303+
PADDLE_ENFORCE_LE(
304+
topk,
305+
std::numeric_limits<int32_t>::max(),
306+
common::errors::InvalidArgument(
307+
"topk should be less than INT_MAX, received topk: (%ld)", topk));
285308
token_prob_unzipped->Resize({output_rows});
286309
if (do_gather) { // no gather, no resize.
287310
X_unzipped->Resize({output_rows, cols});
@@ -346,11 +369,11 @@ void MoePermuteKernel(const Context &dev_ctx,
346369
token_prob_unzipped,
347370
XScale_unzipped,
348371
&global_expertwise_block_cumsum,
349-
rows,
350-
cols,
351-
topk,
372+
static_cast<int>(rows),
373+
static_cast<int>(cols),
374+
static_cast<int>(topk),
352375
num_experts,
353-
quanted_cols,
376+
static_cast<int>(quanted_cols),
354377
do_gather);
355378
}
356379
#undef CUMSUM_BLOCK_SIZE

paddle/phi/kernels/gpu/moe_unpermute_kernel.cu

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,13 @@ void MoeUnpermuteKernel(const Context &dev_ctx,
226226
const bool MP,
227227
DenseTensor *zipped_tokens,
228228
DenseTensor *zipped_probs_topk) {
229-
const int rows = unzipped_tokens.dims()[0];
230-
const int cols = unzipped_tokens.dims()[1];
229+
const int64_t cols = unzipped_tokens.dims()[1];
230+
PADDLE_ENFORCE_LE(cols,
231+
std::numeric_limits<int32_t>::max(),
232+
common::errors::InvalidArgument(
233+
"unzipped_tokens.dims()[1] should be less than "
234+
"INT_MAX, received unzipped_tokens.dims()[1]: (%ld)",
235+
cols));
231236
PADDLE_ENFORCE_LE(
232237
num_experts,
233238
MAX_NUM_EXPERTS,
@@ -237,7 +242,12 @@ void MoeUnpermuteKernel(const Context &dev_ctx,
237242
"value.",
238243
MAX_NUM_EXPERTS,
239244
num_experts));
240-
const int topk = expert_routemap_topk.dims()[1];
245+
const int64_t topk = expert_routemap_topk.dims()[1];
246+
PADDLE_ENFORCE_LE(
247+
topk,
248+
std::numeric_limits<int32_t>::max(),
249+
common::errors::InvalidArgument(
250+
"topk should be less than INT_MAX, received topk: (%ld)", topk));
241251
dev_ctx.template Alloc<T>(zipped_tokens);
242252
dev_ctx.template Alloc<float>(zipped_probs_topk);
243253
if (unzipped_tokens.numel() == 0) return; // 0-size tensor
@@ -258,8 +268,8 @@ void MoeUnpermuteKernel(const Context &dev_ctx,
258268
zipped_probs_topk,
259269
total_zipped_tokens_num,
260270
num_experts,
261-
cols,
262-
topk,
271+
static_cast<int>(cols),
272+
static_cast<int>(topk),
263273
MP);
264274
}
265275
} // namespace phi

paddle/phi/kernels/gpu/repeat_interleave_kernel.cu

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ __global__ void index_select_cuda_kernel(const T* input,
3737
int64_t stride,
3838
int64_t size,
3939
int64_t delta) {
40-
const int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
40+
const int64_t idx =
41+
static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
4142
if (idx >= N) {
4243
return;
4344
}
@@ -198,7 +199,8 @@ __global__ void RepeatInterleaveVecKernel(const T* __restrict__ input,
198199
const int repeats) {
199200
using VecType = kps::details::VectorType<T, VecSize>;
200201

201-
const int64_t tid = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
202+
const int64_t tid =
203+
(static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x) * VecSize;
202204
if (tid >= numel) return;
203205

204206
VecType* vec_output = reinterpret_cast<VecType*>(output);

0 commit comments

Comments (0)