16 changes: 11 additions & 5 deletions paddle/fluid/operators/math/math_function.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/math/math_function.h"
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/math/math_function_impl.h"
#include "paddle/fluid/platform/float16.h"
@@ -161,7 +162,8 @@ void batched_gemm<platform::CPUDeviceContext, float16>(
const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const float16 alpha, const float16* A, const float16* B, const float16 beta,
float16* C, const int batchCount, const int strideA, const int strideB) {
float16* C, const int batchCount, const int64_t strideA,
const int64_t strideB) {
PADDLE_THROW("float16 batched_gemm not supported on CPU");
}

@@ -172,7 +174,8 @@ void batched_gemm<platform::CPUDeviceContext, float>(
const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const float alpha, const float* A, const float* B, const float beta,
float* C, const int batchCount, const int strideA, const int strideB) {
float* C, const int batchCount, const int64_t strideA,
const int64_t strideB) {
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
@@ -194,7 +197,8 @@ void batched_gemm<platform::CPUDeviceContext, double>(
const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const double alpha, const double* A, const double* B, const double beta,
double* C, const int batchCount, const int strideA, const int strideB) {
double* C, const int batchCount, const int64_t strideA,
const int64_t strideB) {
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
@@ -220,7 +224,8 @@ void batched_gemm<platform::CPUDeviceContext, float>(
const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const float alpha, const float* A, const float* B, const float beta,
float* C, const int batchCount, const int strideA, const int strideB) {
float* C, const int batchCount, const int64_t strideA,
const int64_t strideB) {
for (int k = 0; k < batchCount; ++k) {
const float* Ak = &A[k * strideA];
const float* Bk = &B[k * strideB];
@@ -235,7 +240,8 @@ void batched_gemm<platform::CPUDeviceContext, double>(
const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const double alpha, const double* A, const double* B, const double beta,
double* C, const int batchCount, const int strideA, const int strideB) {
double* C, const int batchCount, const int64_t strideA,
const int64_t strideB) {
for (int k = 0; k < batchCount; ++k) {
const double* Ak = &A[k * strideA];
const double* Bk = &B[k * strideB];
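The non-MKL CPU fallback above addresses each batch element through the pointer offsets k * strideA and k * strideB. Below is a minimal reference sketch of that loop, not Paddle's cblas-backed implementation (gemm_ref is a naive stand-in): with 32-bit strides the offset product is evaluated in int and can overflow once a tensor grows past 2^31 elements, which is what the widened int64_t parameters avoid.

#include <cstdint>

// Naive row-major GEMM: C = alpha * A(MxK) * B(KxN) + beta * C(MxN).
void gemm_ref(int M, int N, int K, float alpha, const float* A, const float* B,
              float beta, float* C) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += A[i * K + k] * B[k * N + j];
      }
      C[i * N + j] = alpha * acc + beta * C[i * N + j];
    }
  }
}

// Per-batch loop mirroring the fallback path above; k * strideA is evaluated
// in 64-bit arithmetic because the strides are int64_t.
void batched_gemm_ref(int M, int N, int K, float alpha, const float* A,
                      const float* B, float beta, float* C, int batchCount,
                      int64_t strideA, int64_t strideB) {
  const int64_t strideC = static_cast<int64_t>(M) * N;
  for (int k = 0; k < batchCount; ++k) {
    gemm_ref(M, N, K, alpha, A + k * strideA, B + k * strideB, beta,
             C + k * strideC);
  }
}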
16 changes: 10 additions & 6 deletions paddle/fluid/operators/math/math_function.cu
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function_impl.h"
@@ -267,7 +268,8 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const float16 alpha, const float16* A, const float16* B, const float16 beta,
float16* C, const int batchCount, const int strideA, const int strideB) {
float16* C, const int batchCount, const int64_t strideA,
const int64_t strideB) {
#if CUDA_VERSION >= 8000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
@@ -278,7 +280,7 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
const int strideC = M * N;
const int64_t strideC = M * N;

const half h_alpha = static_cast<const half>(alpha);
const half h_beta = static_cast<const half>(beta);
@@ -303,7 +305,8 @@ void batched_gemm<platform::CUDADeviceContext, float>(
const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const float alpha, const float* A, const float* B, const float beta,
float* C, const int batchCount, const int strideA, const int strideB) {
float* C, const int batchCount, const int64_t strideA,
const int64_t strideB) {
#if CUDA_VERSION >= 8000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
@@ -314,7 +317,7 @@ void batched_gemm<platform::CUDADeviceContext, float>(
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
const int strideC = M * N;
const int64_t strideC = M * N;

PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
@@ -329,7 +332,8 @@ void batched_gemm<platform::CUDADeviceContext, double>(
const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const double alpha, const double* A, const double* B, const double beta,
double* C, const int batchCount, const int strideA, const int strideB) {
double* C, const int batchCount, const int64_t strideA,
const int64_t strideB) {
#if CUDA_VERSION >= 8000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
@@ -340,7 +344,7 @@ void batched_gemm<platform::CUDADeviceContext, double>(
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
const int strideC = M * N;
const int64_t strideC = M * N;

PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
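cuBLAS expects column-major storage, so the calls above compute the row-major product C = A * B as C^T = B^T * A^T: the B and A arguments and the M and N dimensions are swapped, and strideC = M * N advances the output by one batch element. The strided-batched cuBLAS routines take long long strides, which the widened int64_t parameters match. A hedged sketch of that mapping for the no-transpose case (assuming CUDA >= 8.0, a valid cublasHandle_t, and device pointers; error handling omitted):

#include <cstdint>
#include <cublas_v2.h>

// Row-major batched C = alpha * A * B + beta * C via the column-major
// strided-batched cuBLAS API: pass B first and swap M and N so cuBLAS
// effectively computes C^T = B^T * A^T. A, B, C are device pointers.
cublasStatus_t strided_batched_sgemm_rowmajor(
    cublasHandle_t handle, int M, int N, int K, float alpha, const float* A,
    int64_t strideA, const float* B, int64_t strideB, float beta, float* C,
    int batchCount) {
  const int lda = K;  // row-major A is M x K
  const int ldb = N;  // row-major B is K x N
  const int ldc = N;  // row-major C is M x N
  const int64_t strideC = static_cast<int64_t>(M) * N;
  return cublasSgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K,
                                   &alpha, B, ldb, strideB, A, lda, strideA,
                                   &beta, C, ldc, strideC, batchCount);
}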
7 changes: 4 additions & 3 deletions paddle/fluid/operators/math/math_function.h
@@ -26,7 +26,7 @@ limitations under the License. */

#ifndef LAPACK_FOUND
extern "C" {
#include <cblas.h>
#include <cblas.h> // NOLINT
int LAPACKE_sgetrf(int matrix_layout, int m, int n, float* a, int lda,
int* ipiv);
int LAPACKE_dgetrf(int matrix_layout, int m, int n, double* a, int lda,
@@ -39,6 +39,7 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda,
#endif

#include <cmath>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"
@@ -78,8 +79,8 @@ template <typename DeviceContext, typename T>
void batched_gemm(const DeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N,
const int K, const T alpha, const T* A, const T* B,
const T beta, T* C, const int batchCount, const int strideA,
const int strideB);
const T beta, T* C, const int batchCount,
const int64_t strideA, const int64_t strideB);

template <typename DeviceContext, typename T>
void gemv(const DeviceContext& context, const bool trans_a, const int M,
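The declaration above is the only part of the interface this change touches: batchCount stays int while the per-batch element strides become int64_t. A hypothetical caller sketch (names and shapes are assumptions, not an actual Paddle operator) showing how those strides are typically derived for contiguous 3-D inputs A: [batch, M, K] and B: [batch, K, N]:

#include <cstdint>

// Hypothetical helper: per-batch element strides for contiguous 3-D inputs.
// Promoting to int64_t before multiplying keeps the per-slice offsets correct
// even when batch * M * K no longer fits in 32 bits.
inline void make_batched_strides(int M, int N, int K, int64_t* strideA,
                                 int64_t* strideB) {
  *strideA = static_cast<int64_t>(M) * K;  // elements per A slice
  *strideB = static_cast<int64_t>(K) * N;  // elements per B slice
}

// A call would then look roughly like (context and device pointers assumed):
//   batched_gemm<platform::CPUDeviceContext, float>(
//       ctx, CblasNoTrans, CblasNoTrans, M, N, K, 1.0f, A, B, 0.0f, C,
//       batchCount, strideA, strideB);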