-
Notifications
You must be signed in to change notification settings - Fork 5.9k
implement DeviceContext #2709
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
implement DeviceContext #2709
Changes from 8 commits
e876477
1ba4cb8
cdfa098
5acaffb
ab56c96
abbed1d
f9ae741
0c13b23
c7bdbdb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,170 @@ | ||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
#pragma once

#include <memory>

#include "paddle/framework/enforce.h"

#ifndef PADDLE_ONLY_CPU
#define EIGEN_USE_GPU
#include "paddle/platform/cuda.h"
#include "paddle/platform/dynload/cublas.h"
#include "paddle/platform/dynload/cudnn.h"
#include "paddle/platform/dynload/curand.h"
#endif
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The above three header lines should also be included between `#ifndef PADDLE_ONLY_CPU` and `#endif`, since they are only needed for the GPU build.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, logically the above three header files should be included between those macros. |
||
| #include "paddle/platform/place.h" | ||
| #include "unsupported/Eigen/CXX11/Tensor" | ||
|
|
||
| namespace paddle { | ||
| namespace platform { | ||
|
|
||
/// Base class for per-place execution contexts.  A concrete DeviceContext
/// owns the device-specific resources (streams, library handles, Eigen
/// devices) that operators need in order to run on a particular place.
class DeviceContext {
 public:
  // Virtual so derived contexts are torn down correctly when deleted through
  // a base pointer; `= default` since the base itself owns no resources.
  virtual ~DeviceContext() = default;
};
|
|
||
| class CpuDeviceContext : public DeviceContext { | ||
|
||
| Eigen::DefaultDevice eigen_device() { | ||
| if (!eigen_device_) { | ||
| eigen_device_ = new Eigen::DefaultDevice(); | ||
| } | ||
|
||
| return *eigen_device_; | ||
| } | ||
|
|
||
| private: | ||
| Eigen::DefaultDevice* eigen_device_{nullptr}; | ||
| }; | ||
|
|
||
| #ifndef PADDLE_ONLY_CPU | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here to add
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The code that calls DeviceGuard or CudaDeviceContext must be built with WITH_GPU set to 1.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think there are a few things to consider when dealing with GPU and CPU mixed code.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this suggestion is useful.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Merge this pr temporarily. And I will consider the design of DeviceContext combining with Operator interface. And I will follow advices of @hedaoyuan later. |
||
| class DeviceGuard { | ||
| public: | ||
| explicit DeviceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) { | ||
| if (previous_ != new_place) { | ||
| paddle::platform::SetDeviceId(new_place.device); | ||
| } | ||
| } | ||
|
|
||
| ~DeviceGuard() { paddle::platform::SetDeviceId(previous_.device); } | ||
|
|
||
| private: | ||
| GPUPlace previous_; | ||
| }; | ||
|
|
||
| class CudaDeviceContext : public DeviceContext { | ||
|
||
| public: | ||
| explicit CudaDeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) { | ||
| DeviceGuard guard(gpu_place_); | ||
| paddle::platform::throw_on_error(cudaStreamCreate(&stream_), | ||
| "cudaStreamCreate failed"); | ||
| eigen_stream_ = new Eigen::CudaStreamDevice(&stream_); | ||
| eigen_device_ = new Eigen::GpuDevice(eigen_stream_); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure we will use the CUDA implementation in Eigen. If not decide to use it, I think the
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we do not use CUDA implementation in Eigen, then we will write CUDA kernels for every operators. Just like caffe2. |
||
| } | ||
|
|
||
| void Wait() { | ||
| paddle::platform::throw_on_error(cudaStreamSynchronize(stream_), | ||
| "cudaStreamSynchronize failed"); | ||
| } | ||
|
|
||
| cudaStream_t stream() { return stream_; } | ||
|
||
|
|
||
| Eigen::GpuDevice eigen_device() { return *eigen_device_; } | ||
|
|
||
| cublasHandle_t cublas_handle() { | ||
| if (!blas_handle_) { | ||
| DeviceGuard guard(gpu_place_); | ||
| PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_) == | ||
|
||
| CUBLAS_STATUS_SUCCESS, | ||
| "cublasCreate failed"); | ||
| PADDLE_ENFORCE(paddle::platform::dynload::cublasSetStream( | ||
| blas_handle_, stream_) == CUBLAS_STATUS_SUCCESS, | ||
| "cublasSetStream failed"); | ||
| } | ||
| return blas_handle_; | ||
| } | ||
|
|
||
| cudnnHandle_t cudnn_handle() { | ||
| if (!dnn_handle_) { | ||
| DeviceGuard guard(gpu_place_); | ||
| PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_) == | ||
| CUDNN_STATUS_SUCCESS, | ||
| "cudnnCreate failed"); | ||
| PADDLE_ENFORCE(paddle::platform::dynload::cudnnSetStream( | ||
| dnn_handle_, stream_) == CUDNN_STATUS_SUCCESS, | ||
| "cudnnSetStream failed"); | ||
| } | ||
| return dnn_handle_; | ||
| } | ||
|
|
||
| curandGenerator_t curand_generator() { | ||
| if (!rand_generator_) { | ||
| DeviceGuard guard(gpu_place_); | ||
| PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator( | ||
| &rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) == | ||
| CURAND_STATUS_SUCCESS, | ||
| "curandCreateGenerator failed"); | ||
| PADDLE_ENFORCE( | ||
| paddle::platform::dynload::curandSetPseudoRandomGeneratorSeed( | ||
| rand_generator_, random_seed_) == CURAND_STATUS_SUCCESS, | ||
| "curandSetPseudoRandomGeneratorSeed failed"); | ||
| PADDLE_ENFORCE(paddle::platform::dynload::curandSetStream( | ||
| rand_generator_, stream_) == CURAND_STATUS_SUCCESS, | ||
| "curandSetStream failed"); | ||
| } | ||
| return rand_generator_; | ||
| } | ||
|
|
||
| ~CudaDeviceContext() { | ||
| Wait(); | ||
| if (blas_handle_) { | ||
| PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_) == | ||
| CUBLAS_STATUS_SUCCESS, | ||
| "cublasDestroy failed"); | ||
| } | ||
|
|
||
| if (dnn_handle_) { | ||
| PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_) == | ||
| CUDNN_STATUS_SUCCESS, | ||
| "cudnnDestroy failed"); | ||
| } | ||
|
|
||
| if (rand_generator_) { | ||
| PADDLE_ENFORCE(paddle::platform::dynload::curandDestroyGenerator( | ||
| rand_generator_) == CURAND_STATUS_SUCCESS, | ||
| "curandDestroyGenerator failed"); | ||
| } | ||
|
|
||
| delete eigen_stream_; | ||
| delete eigen_device_; | ||
|
|
||
| paddle::platform::throw_on_error(cudaStreamDestroy(stream_), | ||
| "cudaStreamDestroy failed"); | ||
| } | ||
|
|
||
| private: | ||
| GPUPlace gpu_place_; | ||
| cudaStream_t stream_; | ||
|
|
||
| Eigen::CudaStreamDevice* eigen_stream_; | ||
| Eigen::GpuDevice* eigen_device_; | ||
|
|
||
| cublasHandle_t blas_handle_{nullptr}; | ||
|
|
||
| cudnnHandle_t dnn_handle_{nullptr}; | ||
|
|
||
| int random_seed_; | ||
| curandGenerator_t rand_generator_{nullptr}; | ||
| }; | ||
| #endif | ||
| } // namespace platform | ||
| } // namespace paddle | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,33 @@ | ||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #include "paddle/platform/device_context.h" | ||
| #include "gtest/gtest.h" | ||
|
|
||
| TEST(DeviceContext, CudaDevice) { | ||
| int count = paddle::platform::GetDeviceCount(); | ||
| for (int i = 0; i < count; i++) { | ||
| paddle::platform::CudaDeviceContext* device_context = | ||
| new paddle::platform::CudaDeviceContext(i); | ||
| __attribute__((unused)) Eigen::GpuDevice gpu_device = | ||
| device_context->eigen_device(); | ||
| __attribute__((unused)) cudnnHandle_t cudnn_handle = | ||
| device_context->cudnn_handle(); | ||
| __attribute__((unused)) cublasHandle_t cublas_handle = | ||
| device_context->cublas_handle(); | ||
| __attribute__((unused)) curandGenerator_t curand_handle = | ||
| device_context->curand_generator(); | ||
| delete device_context; | ||
| } | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where is the
`EIGEN_USE_GPU` macro used? There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
`EIGEN_USE_GPU` macro is used by the Eigen library. If we want to use Eigen's Tensor Expressions on the GPU, we have to define this macro.