@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #define EIGEN_USE_GPU
+#include <vector>
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/math_function_impl.h"
@@ -267,7 +268,8 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
     const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const float16 alpha, const float16* A, const float16* B, const float16 beta,
-    float16* C, const int batchCount, const int strideA, const int strideB) {
+    float16* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
 #if CUDA_VERSION >= 8000
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
@@ -278,7 +280,7 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
       (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   cublasOperation_t cuTransB =
       (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
-  const int strideC = M * N;
+  const int64_t strideC = M * N;
 
   const half h_alpha = static_cast<const half>(alpha);
   const half h_beta = static_cast<const half>(beta);
@@ -303,7 +305,8 @@ void batched_gemm<platform::CUDADeviceContext, float>(
     const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const float alpha, const float* A, const float* B, const float beta,
-    float* C, const int batchCount, const int strideA, const int strideB) {
+    float* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
 #if CUDA_VERSION >= 8000
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
@@ -314,7 +317,7 @@ void batched_gemm<platform::CUDADeviceContext, float>(
       (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   cublasOperation_t cuTransB =
       (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
-  const int strideC = M * N;
+  const int64_t strideC = M * N;
 
   PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(
       context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
@@ -329,7 +332,8 @@ void batched_gemm<platform::CUDADeviceContext, double>(
     const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const double alpha, const double* A, const double* B, const double beta,
-    double* C, const int batchCount, const int strideA, const int strideB) {
+    double* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
 #if CUDA_VERSION >= 8000
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
@@ -340,7 +344,7 @@ void batched_gemm<platform::CUDADeviceContext, double>(
       (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   cublasOperation_t cuTransB =
       (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
-  const int strideC = M * N;
+  const int64_t strideC = M * N;
 
   PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(
       context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
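Why the stride parameters move from int to int64_t: the cuBLAS *gemmStridedBatched routines take the per-matrix strides as long long, and for large shapes the element strides (M*K, K*N, M*N) can exceed what a 32-bit int holds. A minimal standalone sketch of the same row-major trick with 64-bit strides follows; the function name, handle, and device pointers (d_A, d_B, d_C) are illustrative and not part of this patch.

#include <cublas_v2.h>
#include <cstdint>

// Sketch: C_i = alpha * A_i * B_i + beta * C_i for i in [0, batchCount),
// with row-major MxK, KxN, MxN matrices laid out back to back in each buffer.
void sgemm_strided_batched_sketch(cublasHandle_t handle, int M, int N, int K,
                                  const float* d_A, const float* d_B,
                                  float* d_C, int batchCount) {
  const float alpha = 1.0f, beta = 0.0f;
  // Compute strides in 64 bits so large M/N/K do not overflow.
  const int64_t strideA = static_cast<int64_t>(M) * K;
  const int64_t strideB = static_cast<int64_t>(K) * N;
  const int64_t strideC = static_cast<int64_t>(M) * N;
  // cuBLAS is column-major; passing B before A with swapped dimensions yields
  // the row-major product, the same ordering used in the patched wrappers.
  cublasSgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha,
                            d_B, N, strideB, d_A, K, strideA, &beta, d_C, N,
                            strideC, batchCount);
}

In the wrappers above, the same call goes through platform::dynload and is checked with PADDLE_ENFORCE, and the float16 specialization converts alpha and beta to half before calling cublasHgemmStridedBatched.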