Skip to content

Commit a046605

Browse files
committed
Refine CUDA Related libraries
1 parent 67bbcbb commit a046605

File tree

10 files changed

+201
-150
lines changed

10 files changed

+201
-150
lines changed

paddle/platform/CMakeLists.txt

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
add_subdirectory(dynload)
22

3-
nv_test(cuda_test SRCS cuda_test.cu)
3+
nv_test(cuda_test SRCS cuda_test.cu DEPS dyload_cuda)
44

55
cc_library(place SRCS place.cc)
66
cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
7+
IF(WITH_GPU)
8+
set(GPU_CTX_DEPS dyload_cuda dynamic_loader )
9+
ELSE()
10+
set(GPU_CTX_DEPS)
11+
ENDIF()
712

8-
nv_test(device_context_test SRCS device_context_test.cc DEPS dynamic_loader place eigen3 glog gflags)
13+
cc_library(device_context SRCS device_context.cc DEPS place eigen3 ${GPU_CTX_DEPS})
14+
nv_test(device_context_test SRCS device_context_test.cc DEPS device_context glog gflags)

paddle/platform/cuda.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,19 +28,19 @@ inline void throw_on_error(cudaError_t e, const char* message) {
2828
}
2929
}
3030

31-
int GetDeviceCount(void) {
31+
inline int GetDeviceCount(void) {
3232
int count;
3333
throw_on_error(cudaGetDeviceCount(&count), "cudaGetDeviceCount failed");
3434
return count;
3535
}
3636

37-
int GetCurrentDeviceId(void) {
37+
inline int GetCurrentDeviceId(void) {
3838
int device_id;
3939
throw_on_error(cudaGetDevice(&device_id), "cudaGetDevice failed");
4040
return device_id;
4141
}
4242

43-
void SetDeviceId(int device_id) {
43+
inline void SetDeviceId(int device_id) {
4444
throw_on_error(cudaSetDevice(device_id), "cudaSetDevice failed");
4545
}
4646

paddle/platform/device_context.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#include <paddle/platform/device_context.h>
2+
3+
namespace paddle {
4+
namespace platform {
5+
namespace dynload {
6+
namespace dummy {
7+
// Make DeviceContext A library.
8+
int DUMMY_VAR_FOR_DEV_CTX = 0;
9+
10+
} // namespace dummy
11+
} // namespace dynload
12+
} // namespace platform
13+
} // namespace paddle
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags)
2+
nv_library(dyload_cuda SRCS cublas.cc cudnn.cc curand.cc)

paddle/platform/dynload/cublas.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#include <paddle/platform/dynload/cublas.h>
2+
3+
namespace paddle {
4+
namespace platform {
5+
namespace dynload {
6+
std::once_flag cublas_dso_flag;
7+
void *cublas_dso_handle = nullptr;
8+
9+
#define DEFINE_WRAP(__name) DynLoad__##__name __name;
10+
11+
CUBLAS_BLAS_ROUTINE_EACH(DEFINE_WRAP);
12+
13+
} // namespace dynload
14+
} // namespace platform
15+
} // namespace paddle

paddle/platform/dynload/cublas.h

Lines changed: 35 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ namespace paddle {
2323
namespace platform {
2424
namespace dynload {
2525

26-
std::once_flag cublas_dso_flag;
27-
void *cublas_dso_handle = nullptr;
26+
extern std::once_flag cublas_dso_flag;
27+
extern void *cublas_dso_handle;
2828

2929
/**
3030
* The following macro definition can generate structs
@@ -34,73 +34,54 @@ void *cublas_dso_handle = nullptr;
3434
* note: default dynamic linked libs
3535
*/
3636
#ifdef PADDLE_USE_DSO
37-
#define DYNAMIC_LOAD_CUBLAS_WRAP(__name) \
37+
#define DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name) \
3838
struct DynLoad__##__name { \
3939
template <typename... Args> \
40-
cublasStatus_t operator()(Args... args) { \
40+
inline cublasStatus_t operator()(Args... args) { \
4141
typedef cublasStatus_t (*cublasFunc)(Args...); \
4242
std::call_once(cublas_dso_flag, \
4343
paddle::platform::dynload::GetCublasDsoHandle, \
4444
&cublas_dso_handle); \
4545
void *p_##__name = dlsym(cublas_dso_handle, #__name); \
4646
return reinterpret_cast<cublasFunc>(p_##__name)(args...); \
4747
} \
48-
} __name; // struct DynLoad__##__name
48+
}; \
49+
extern DynLoad__##__name __name
4950
#else
50-
#define DYNAMIC_LOAD_CUBLAS_WRAP(__name) \
51-
struct DynLoad__##__name { \
52-
template <typename... Args> \
53-
cublasStatus_t operator()(Args... args) { \
54-
return __name(args...); \
55-
} \
56-
} __name; // struct DynLoad__##__name
51+
#define DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name) \
52+
struct DynLoad__##__name { \
53+
inline template <typename... Args> \
54+
cublasStatus_t operator()(Args... args) { \
55+
return __name(args...); \
56+
} \
57+
}; \
58+
extern DynLoad__##__name __name
5759
#endif
5860

59-
#define DYNAMIC_LOAD_CUBLAS_V2_WRAP(__name) DYNAMIC_LOAD_CUBLAS_WRAP(__name)
60-
61-
// include all needed cublas functions in HPPL
62-
// clang-format off
6361
#define CUBLAS_BLAS_ROUTINE_EACH(__macro) \
64-
__macro(cublasSgemv) \
65-
__macro(cublasDgemv) \
66-
__macro(cublasSgemm) \
67-
__macro(cublasDgemm) \
68-
__macro(cublasSgeam) \
69-
__macro(cublasDgeam) \
70-
71-
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasCreate)
72-
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasDestroy)
73-
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetStream)
74-
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetPointerMode)
75-
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasGetPointerMode)
76-
DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgemmBatched)
77-
DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgemmBatched)
78-
DYNAMIC_LOAD_CUBLAS_WRAP(cublasCgemmBatched)
79-
DYNAMIC_LOAD_CUBLAS_WRAP(cublasZgemmBatched)
80-
DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetrfBatched)
81-
DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetriBatched)
82-
DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgetrfBatched)
83-
DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgetriBatched)
84-
CUBLAS_BLAS_ROUTINE_EACH(DYNAMIC_LOAD_CUBLAS_V2_WRAP)
62+
__macro(cublasSgemv); \
63+
__macro(cublasDgemv); \
64+
__macro(cublasSgemm); \
65+
__macro(cublasDgemm); \
66+
__macro(cublasSgeam); \
67+
__macro(cublasDgeam); \
68+
__macro(cublasCreate); \
69+
__macro(cublasDestroy); \
70+
__macro(cublasSetStream); \
71+
__macro(cublasSetPointerMode); \
72+
__macro(cublasGetPointerMode); \
73+
__macro(cublasSgemmBatched); \
74+
__macro(cublasDgemmBatched); \
75+
__macro(cublasCgemmBatched); \
76+
__macro(cublasZgemmBatched); \
77+
__macro(cublasSgetrfBatched); \
78+
__macro(cublasSgetriBatched); \
79+
__macro(cublasDgetrfBatched); \
80+
__macro(cublasDgetriBatched)
8581

86-
#undef DYNAMIC_LOAD_CUBLAS_WRAP
87-
#undef DYNAMIC_LOAD_CUBLAS_V2_WRAP
88-
#undef CUBLAS_BLAS_ROUTINE_EACH
82+
CUBLAS_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP);
8983

90-
// clang-format on
91-
#ifndef PADDLE_TYPE_DOUBLE
92-
#define CUBLAS_GEAM paddle::platform::dynload::cublasSgeam
93-
#define CUBLAS_GEMV paddle::platform::dynload::cublasSgemv
94-
#define CUBLAS_GEMM paddle::platform::dynload::cublasSgemm
95-
#define CUBLAS_GETRF paddle::platform::dynload::cublasSgetrfBatched
96-
#define CUBLAS_GETRI paddle::platform::dynload::cublasSgetriBatched
97-
#else
98-
#define CUBLAS_GEAM paddle::platform::dynload::cublasDgeam
99-
#define CUBLAS_GEMV paddle::platform::dynload::cublasDgemv
100-
#define CUBLAS_GEMM paddle::platform::dynload::cublasDgemm
101-
#define CUBLAS_GETRF paddle::platform::dynload::cublasDgetrfBatched
102-
#define CUBLAS_GETRI paddle::platform::dynload::cublasDgetriBatched
103-
#endif
84+
#undef DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP
10485
} // namespace dynload
10586
} // namespace platform
10687
} // namespace paddle

paddle/platform/dynload/cudnn.cc

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#include <paddle/platform/dynload/cudnn.h>
2+
3+
namespace paddle {
4+
namespace platform {
5+
namespace dynload {
6+
std::once_flag cudnn_dso_flag;
7+
void* cudnn_dso_handle = nullptr;
8+
9+
#define DEFINE_WRAP(__name) DynLoad__##__name __name
10+
11+
CUDNN_DNN_ROUTINE_EACH(DEFINE_WRAP);
12+
CUDNN_DNN_ROUTINE_EACH_R2(DEFINE_WRAP);
13+
14+
#ifdef CUDNN_DNN_ROUTINE_EACH_AFTER_R3
15+
CUDNN_DNN_ROUTINE_EACH_AFTER_R3(DEFINE_WRAP);
16+
#endif
17+
18+
#ifdef CUDNN_DNN_ROUTINE_EACH_AFTER_R4
19+
CUDNN_DNN_ROUTINE_EACH_AFTER_R4(DEFINE_WRAP);
20+
#endif
21+
22+
#ifdef CUDNN_DNN_ROUTINE_EACH_R5
23+
CUDNN_DNN_ROUTINE_EACH_R5(DEFINE_WRAP);
24+
#endif
25+
26+
} // namespace dynload
27+
} // namespace platform
28+
} // namespace paddle

paddle/platform/dynload/cudnn.h

Lines changed: 62 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@ namespace paddle {
2323
namespace platform {
2424
namespace dynload {
2525

26-
std::once_flag cudnn_dso_flag;
27-
void* cudnn_dso_handle = nullptr;
26+
extern std::once_flag cudnn_dso_flag;
27+
extern void* cudnn_dso_handle;
2828

2929
#ifdef PADDLE_USE_DSO
3030

31-
#define DYNAMIC_LOAD_CUDNN_WRAP(__name) \
31+
#define DECLARE_DYNAMIC_LOAD_CUDNN_WRAP(__name) \
3232
struct DynLoad__##__name { \
3333
template <typename... Args> \
3434
auto operator()(Args... args) -> decltype(__name(args...)) { \
@@ -39,98 +39,93 @@ void* cudnn_dso_handle = nullptr;
3939
void* p_##__name = dlsym(cudnn_dso_handle, #__name); \
4040
return reinterpret_cast<cudnn_func>(p_##__name)(args...); \
4141
} \
42-
} __name; /* struct DynLoad__##__name */
42+
}; \
43+
extern struct DynLoad__##__name __name
4344

4445
#else
4546

46-
#define DYNAMIC_LOAD_CUDNN_WRAP(__name) \
47+
#define DECLARE_DYNAMIC_LOAD_CUDNN_WRAP(__name) \
4748
struct DynLoad__##__name { \
4849
template <typename... Args> \
4950
auto operator()(Args... args) -> decltype(__name(args...)) { \
5051
return __name(args...); \
5152
} \
52-
} __name; /* struct DynLoad__##__name */
53+
}; \
54+
extern DynLoad__##__name __name
5355

5456
#endif
5557

5658
/**
5759
* include all needed cudnn functions in HPPL
5860
* different cudnn version has different interfaces
5961
**/
60-
// clang-format off
61-
#define CUDNN_DNN_ROUTINE_EACH(__macro) \
62-
__macro(cudnnSetTensor4dDescriptor) \
63-
__macro(cudnnSetTensor4dDescriptorEx) \
64-
__macro(cudnnGetConvolutionNdForwardOutputDim) \
65-
__macro(cudnnGetConvolutionForwardAlgorithm) \
66-
__macro(cudnnCreateTensorDescriptor) \
67-
__macro(cudnnDestroyTensorDescriptor) \
68-
__macro(cudnnCreateFilterDescriptor) \
69-
__macro(cudnnSetFilter4dDescriptor) \
70-
__macro(cudnnSetPooling2dDescriptor) \
71-
__macro(cudnnDestroyFilterDescriptor) \
72-
__macro(cudnnCreateConvolutionDescriptor) \
73-
__macro(cudnnCreatePoolingDescriptor) \
74-
__macro(cudnnDestroyPoolingDescriptor) \
75-
__macro(cudnnSetConvolution2dDescriptor) \
76-
__macro(cudnnDestroyConvolutionDescriptor) \
77-
__macro(cudnnCreate) \
78-
__macro(cudnnDestroy) \
79-
__macro(cudnnSetStream) \
80-
__macro(cudnnActivationForward) \
81-
__macro(cudnnConvolutionForward) \
82-
__macro(cudnnConvolutionBackwardBias) \
83-
__macro(cudnnGetConvolutionForwardWorkspaceSize) \
84-
__macro(cudnnTransformTensor) \
85-
__macro(cudnnPoolingForward) \
86-
__macro(cudnnPoolingBackward) \
87-
__macro(cudnnSoftmaxBackward) \
88-
__macro(cudnnSoftmaxForward) \
89-
__macro(cudnnGetVersion) \
90-
__macro(cudnnGetErrorString)
91-
CUDNN_DNN_ROUTINE_EACH(DYNAMIC_LOAD_CUDNN_WRAP)
92-
93-
#define CUDNN_DNN_ROUTINE_EACH_R2(__macro) \
94-
__macro(cudnnAddTensor) \
95-
__macro(cudnnConvolutionBackwardData) \
96-
__macro(cudnnConvolutionBackwardFilter)
97-
CUDNN_DNN_ROUTINE_EACH_R2(DYNAMIC_LOAD_CUDNN_WRAP)
62+
#define CUDNN_DNN_ROUTINE_EACH(__macro) \
63+
__macro(cudnnSetTensor4dDescriptor); \
64+
__macro(cudnnSetTensor4dDescriptorEx); \
65+
__macro(cudnnGetConvolutionNdForwardOutputDim); \
66+
__macro(cudnnGetConvolutionForwardAlgorithm); \
67+
__macro(cudnnCreateTensorDescriptor); \
68+
__macro(cudnnDestroyTensorDescriptor); \
69+
__macro(cudnnCreateFilterDescriptor); \
70+
__macro(cudnnSetFilter4dDescriptor); \
71+
__macro(cudnnSetPooling2dDescriptor); \
72+
__macro(cudnnDestroyFilterDescriptor); \
73+
__macro(cudnnCreateConvolutionDescriptor); \
74+
__macro(cudnnCreatePoolingDescriptor); \
75+
__macro(cudnnDestroyPoolingDescriptor); \
76+
__macro(cudnnSetConvolution2dDescriptor); \
77+
__macro(cudnnDestroyConvolutionDescriptor); \
78+
__macro(cudnnCreate); \
79+
__macro(cudnnDestroy); \
80+
__macro(cudnnSetStream); \
81+
__macro(cudnnActivationForward); \
82+
__macro(cudnnConvolutionForward); \
83+
__macro(cudnnConvolutionBackwardBias); \
84+
__macro(cudnnGetConvolutionForwardWorkspaceSize); \
85+
__macro(cudnnTransformTensor); \
86+
__macro(cudnnPoolingForward); \
87+
__macro(cudnnPoolingBackward); \
88+
__macro(cudnnSoftmaxBackward); \
89+
__macro(cudnnSoftmaxForward); \
90+
__macro(cudnnGetVersion); \
91+
__macro(cudnnGetErrorString);
92+
CUDNN_DNN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
93+
94+
#define CUDNN_DNN_ROUTINE_EACH_R2(__macro) \
95+
__macro(cudnnAddTensor); \
96+
__macro(cudnnConvolutionBackwardData); \
97+
__macro(cudnnConvolutionBackwardFilter);
98+
CUDNN_DNN_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
9899

99100
// APIs available after R3:
100101
#if CUDNN_VERSION >= 3000
101-
#define CUDNN_DNN_ROUTINE_EACH_AFTER_R3(__macro) \
102-
__macro(cudnnGetConvolutionBackwardFilterWorkspaceSize) \
103-
__macro(cudnnGetConvolutionBackwardDataAlgorithm) \
104-
__macro(cudnnGetConvolutionBackwardFilterAlgorithm) \
105-
__macro(cudnnGetConvolutionBackwardDataWorkspaceSize)
106-
CUDNN_DNN_ROUTINE_EACH_AFTER_R3(DYNAMIC_LOAD_CUDNN_WRAP)
107-
#undef CUDNN_DNN_ROUTINE_EACH_AFTER_R3
102+
#define CUDNN_DNN_ROUTINE_EACH_AFTER_R3(__macro) \
103+
__macro(cudnnGetConvolutionBackwardFilterWorkspaceSize); \
104+
__macro(cudnnGetConvolutionBackwardDataAlgorithm); \
105+
__macro(cudnnGetConvolutionBackwardFilterAlgorithm); \
106+
__macro(cudnnGetConvolutionBackwardDataWorkspaceSize);
107+
CUDNN_DNN_ROUTINE_EACH_AFTER_R3(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
108108
#endif
109109

110-
111110
// APIs available after R4:
112111
#if CUDNN_VERSION >= 4007
113-
#define CUDNN_DNN_ROUTINE_EACH_AFTER_R4(__macro) \
114-
__macro(cudnnBatchNormalizationForwardTraining) \
115-
__macro(cudnnBatchNormalizationForwardInference) \
116-
__macro(cudnnBatchNormalizationBackward)
117-
CUDNN_DNN_ROUTINE_EACH_AFTER_R4(DYNAMIC_LOAD_CUDNN_WRAP)
118-
#undef CUDNN_DNN_ROUTINE_EACH_AFTER_R4
112+
#define CUDNN_DNN_ROUTINE_EACH_AFTER_R4(__macro) \
113+
__macro(cudnnBatchNormalizationForwardTraining); \
114+
__macro(cudnnBatchNormalizationForwardInference); \
115+
__macro(cudnnBatchNormalizationBackward);
116+
CUDNN_DNN_ROUTINE_EACH_AFTER_R4(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
119117
#endif
120118

121119
// APIs in R5
122120
#if CUDNN_VERSION >= 5000
123-
#define CUDNN_DNN_ROUTINE_EACH_R5(__macro) \
124-
__macro(cudnnCreateActivationDescriptor) \
125-
__macro(cudnnSetActivationDescriptor) \
126-
__macro(cudnnGetActivationDescriptor) \
127-
__macro(cudnnDestroyActivationDescriptor)
128-
CUDNN_DNN_ROUTINE_EACH_R5(DYNAMIC_LOAD_CUDNN_WRAP)
129-
#undef CUDNN_DNN_ROUTINE_EACH_R5
121+
#define CUDNN_DNN_ROUTINE_EACH_R5(__macro) \
122+
__macro(cudnnCreateActivationDescriptor); \
123+
__macro(cudnnSetActivationDescriptor); \
124+
__macro(cudnnGetActivationDescriptor); \
125+
__macro(cudnnDestroyActivationDescriptor);
126+
CUDNN_DNN_ROUTINE_EACH_R5(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
130127
#endif
131128

132-
#undef CUDNN_DNN_ROUTINE_EACH
133-
// clang-format on
134129
} // namespace dynload
135130
} // namespace platform
136131
} // namespace paddle

0 commit comments

Comments
 (0)