@@ -224,7 +224,8 @@ void batched_gemm<platform::CPUDeviceContext, float>(
224224 const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
225225 const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
226226 const float alpha, const float * A, const float * B, const float beta,
227- float * C, const int batchCount, const int strideA, const int strideB) {
227+ float * C, const int batchCount, const int64_t strideA,
228+ const int64_t strideB) {
228229 for (int k = 0 ; k < batchCount; ++k) {
229230 const float * Ak = &A[k * strideA];
230231 const float * Bk = &B[k * strideB];
@@ -239,7 +240,8 @@ void batched_gemm<platform::CPUDeviceContext, double>(
239240 const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
240241 const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
241242 const double alpha, const double * A, const double * B, const double beta,
242- double * C, const int batchCount, const int strideA, const int strideB) {
243+ double * C, const int batchCount, const int64_t strideA,
244+ const int64_t strideB) {
243245 for (int k = 0 ; k < batchCount; ++k) {
244246 const double * Ak = &A[k * strideA];
245247 const double * Bk = &B[k * strideB];
0 commit comments