Skip to content

Commit c7bdbdb

Browse files
committed
follow comments
1 parent 0c13b23 commit c7bdbdb

File tree

2 files changed

+25
-36
lines changed

2 files changed

+25
-36
lines changed

paddle/platform/device_context.h

Lines changed: 14 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,14 @@ limitations under the License. */
1414

1515
#pragma once
1616

17+
#include "paddle/framework/enforce.h"
1718
#ifndef PADDLE_ONLY_CPU
1819
#include "paddle/platform/cuda.h"
19-
#define EIGEN_USE_GPU
20-
#endif
21-
22-
#include "paddle/framework/enforce.h"
2320
#include "paddle/platform/dynload/cublas.h"
2421
#include "paddle/platform/dynload/cudnn.h"
2522
#include "paddle/platform/dynload/curand.h"
23+
#define EIGEN_USE_GPU
24+
#endif
2625
#include "paddle/platform/place.h"
2726
#include "unsupported/Eigen/CXX11/Tensor"
2827

@@ -34,37 +33,27 @@ class DeviceContext {
3433
virtual ~DeviceContext() {}
3534
};
3635

37-
class CpuDeviceContext : public DeviceContext {
38-
Eigen::DefaultDevice eigen_device() {
39-
if (!eigen_device_) {
40-
eigen_device_ = new Eigen::DefaultDevice();
41-
}
42-
return *eigen_device_;
43-
}
44-
45-
private:
46-
Eigen::DefaultDevice* eigen_device_{nullptr};
47-
};
36+
class CPUDeviceContext : public DeviceContext {};
4837

4938
#ifndef PADDLE_ONLY_CPU
50-
class DeviceGuard {
39+
class GPUPlaceGuard {
5140
public:
52-
explicit DeviceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) {
41+
explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) {
5342
if (previous_ != new_place) {
5443
paddle::platform::SetDeviceId(new_place.device);
5544
}
5645
}
5746

58-
~DeviceGuard() { paddle::platform::SetDeviceId(previous_.device); }
47+
~GPUPlaceGuard() { paddle::platform::SetDeviceId(previous_.device); }
5948

6049
private:
6150
GPUPlace previous_;
6251
};
6352

64-
class CudaDeviceContext : public DeviceContext {
53+
class CUDADeviceContext : public DeviceContext {
6554
public:
66-
explicit CudaDeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
67-
DeviceGuard guard(gpu_place_);
55+
explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
56+
GPUPlaceGuard guard(gpu_place_);
6857
paddle::platform::throw_on_error(cudaStreamCreate(&stream_),
6958
"cudaStreamCreate failed");
7059
eigen_stream_ = new Eigen::CudaStreamDevice(&stream_);
@@ -82,7 +71,7 @@ class CudaDeviceContext : public DeviceContext {
8271

8372
cublasHandle_t cublas_handle() {
8473
if (!blas_handle_) {
85-
DeviceGuard guard(gpu_place_);
74+
GPUPlaceGuard guard(gpu_place_);
8675
PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_) ==
8776
CUBLAS_STATUS_SUCCESS,
8877
"cublasCreate failed");
@@ -95,7 +84,7 @@ class CudaDeviceContext : public DeviceContext {
9584

9685
cudnnHandle_t cudnn_handle() {
9786
if (!dnn_handle_) {
98-
DeviceGuard guard(gpu_place_);
87+
GPUPlaceGuard guard(gpu_place_);
9988
PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_) ==
10089
CUDNN_STATUS_SUCCESS,
10190
"cudnnCreate failed");
@@ -108,7 +97,7 @@ class CudaDeviceContext : public DeviceContext {
10897

10998
curandGenerator_t curand_generator() {
11099
if (!rand_generator_) {
111-
DeviceGuard guard(gpu_place_);
100+
GPUPlaceGuard guard(gpu_place_);
112101
PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator(
113102
&rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) ==
114103
CURAND_STATUS_SUCCESS,
@@ -124,7 +113,7 @@ class CudaDeviceContext : public DeviceContext {
124113
return rand_generator_;
125114
}
126115

127-
~CudaDeviceContext() {
116+
~CUDADeviceContext() {
128117
Wait();
129118
if (blas_handle_) {
130119
PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_) ==

paddle/platform/device_context_test.cc

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,19 @@ limitations under the License. */
1515
#include "paddle/platform/device_context.h"
1616
#include "gtest/gtest.h"
1717

18-
TEST(DeviceContext, CudaDevice) {
18+
TEST(CUDADeviceContext, Init) {
1919
int count = paddle::platform::GetDeviceCount();
2020
for (int i = 0; i < count; i++) {
21-
paddle::platform::CudaDeviceContext* device_context =
22-
new paddle::platform::CudaDeviceContext(i);
23-
__attribute__((unused)) Eigen::GpuDevice gpu_device =
24-
device_context->eigen_device();
25-
__attribute__((unused)) cudnnHandle_t cudnn_handle =
26-
device_context->cudnn_handle();
27-
__attribute__((unused)) cublasHandle_t cublas_handle =
28-
device_context->cublas_handle();
29-
__attribute__((unused)) curandGenerator_t curand_handle =
30-
device_context->curand_generator();
21+
paddle::platform::CUDADeviceContext* device_context =
22+
new paddle::platform::CUDADeviceContext(i);
23+
Eigen::GpuDevice gpu_device = device_context->eigen_device();
24+
ASSERT_NE(nullptr, gpu_device.stream());
25+
cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
26+
ASSERT_NE(nullptr, cudnn_handle);
27+
cublasHandle_t cublas_handle = device_context->cublas_handle();
28+
ASSERT_NE(nullptr, cublas_handle);
29+
curandGenerator_t curand_handle = device_context->curand_generator();
30+
ASSERT_NE(nullptr, curand_handle);
3131
delete device_context;
3232
}
3333
}

0 commit comments

Comments
 (0)