@@ -14,15 +14,14 @@ limitations under the License. */
1414
1515#pragma once
1616
17+ #include " paddle/framework/enforce.h"
1718#ifndef PADDLE_ONLY_CPU
1819#include " paddle/platform/cuda.h"
19- #define EIGEN_USE_GPU
20- #endif
21-
22- #include " paddle/framework/enforce.h"
2320#include " paddle/platform/dynload/cublas.h"
2421#include " paddle/platform/dynload/cudnn.h"
2522#include " paddle/platform/dynload/curand.h"
23+ #define EIGEN_USE_GPU
24+ #endif
2625#include " paddle/platform/place.h"
2726#include " unsupported/Eigen/CXX11/Tensor"
2827
@@ -34,37 +33,27 @@ class DeviceContext {
3433 virtual ~DeviceContext () {}
3534};
3635
37- class CpuDeviceContext : public DeviceContext {
38- Eigen::DefaultDevice eigen_device () {
39- if (!eigen_device_) {
40- eigen_device_ = new Eigen::DefaultDevice ();
41- }
42- return *eigen_device_;
43- }
44-
45- private:
46- Eigen::DefaultDevice* eigen_device_{nullptr };
47- };
36+ class CPUDeviceContext : public DeviceContext {};
4837
4938#ifndef PADDLE_ONLY_CPU
50- class DeviceGuard {
39+ class GPUPlaceGuard {
5140 public:
52- explicit DeviceGuard (GPUPlace new_place) : previous_(GetCurrentDeviceId()) {
41+ explicit GPUPlaceGuard (GPUPlace new_place) : previous_(GetCurrentDeviceId()) {
5342 if (previous_ != new_place) {
5443 paddle::platform::SetDeviceId (new_place.device );
5544 }
5645 }
5746
58- ~DeviceGuard () { paddle::platform::SetDeviceId (previous_.device ); }
47+ ~GPUPlaceGuard () { paddle::platform::SetDeviceId (previous_.device ); }
5948
6049 private:
6150 GPUPlace previous_;
6251};
6352
64- class CudaDeviceContext : public DeviceContext {
53+ class CUDADeviceContext : public DeviceContext {
6554 public:
66- explicit CudaDeviceContext (const GPUPlace gpu_place) : gpu_place_(gpu_place) {
67- DeviceGuard guard (gpu_place_);
55+ explicit CUDADeviceContext (const GPUPlace gpu_place) : gpu_place_(gpu_place) {
56+ GPUPlaceGuard guard (gpu_place_);
6857 paddle::platform::throw_on_error (cudaStreamCreate (&stream_),
6958 " cudaStreamCreate failed" );
7059 eigen_stream_ = new Eigen::CudaStreamDevice (&stream_);
@@ -82,7 +71,7 @@ class CudaDeviceContext : public DeviceContext {
8271
8372 cublasHandle_t cublas_handle () {
8473 if (!blas_handle_) {
85- DeviceGuard guard (gpu_place_);
74+ GPUPlaceGuard guard (gpu_place_);
8675 PADDLE_ENFORCE (paddle::platform::dynload::cublasCreate (&blas_handle_) ==
8776 CUBLAS_STATUS_SUCCESS,
8877 " cublasCreate failed" );
@@ -95,7 +84,7 @@ class CudaDeviceContext : public DeviceContext {
9584
9685 cudnnHandle_t cudnn_handle () {
9786 if (!dnn_handle_) {
98- DeviceGuard guard (gpu_place_);
87+ GPUPlaceGuard guard (gpu_place_);
9988 PADDLE_ENFORCE (paddle::platform::dynload::cudnnCreate (&dnn_handle_) ==
10089 CUDNN_STATUS_SUCCESS,
10190 " cudnnCreate failed" );
@@ -108,7 +97,7 @@ class CudaDeviceContext : public DeviceContext {
10897
10998 curandGenerator_t curand_generator () {
11099 if (!rand_generator_) {
111- DeviceGuard guard (gpu_place_);
100+ GPUPlaceGuard guard (gpu_place_);
112101 PADDLE_ENFORCE (paddle::platform::dynload::curandCreateGenerator (
113102 &rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) ==
114103 CURAND_STATUS_SUCCESS,
@@ -124,7 +113,7 @@ class CudaDeviceContext : public DeviceContext {
124113 return rand_generator_;
125114 }
126115
127- ~CudaDeviceContext () {
116+ ~CUDADeviceContext () {
128117 Wait ();
129118 if (blas_handle_) {
130119 PADDLE_ENFORCE (paddle::platform::dynload::cublasDestroy (blas_handle_) ==
0 commit comments