Skip to content

Commit 080a899

Browse files
committed
use shared memory
1 parent 144e1e0 commit 080a899

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

paddle/fluid/operators/sgd_group_op.cu

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,7 @@ __global__ void SGDGroupKernel(T** grads, T** params, T** learning_rate,
59 59
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
60 60
extern __shared__ int s_p_numbers[];
61 61

62-
if (threadIdx.x < para_num) {
62+
if (threadIdx.x < para_num + 1) {
63 63
s_p_numbers[threadIdx.x] = p_numbers[threadIdx.x];
64 64
}
65 65
__syncthreads();
@@ -149,7 +149,7 @@ class SGDGroupOpCUDAKernel : public framework::OpKernel<T> {
149 149

150 150
int grid = std::min((p_ele_num + block - 1) / block, max_blocks);
151 151

152-
SGDGroupKernel<T><<<grid, block, p_num * sizeof(int),
152+
SGDGroupKernel<T><<<grid, block, (p_num + 1) * sizeof(int),
153 153
ctx.cuda_device_context().stream()>>>(
154 154
grads_gpu, params_gpu, lrs_data_gpu, param_num_gpu, p_num, p_ele_num,
155 155
param_out_gpu);

0 commit comments

Comments
 (0)